こしあん
2018-12-10

PandasのDataFrameでグループ別にサンプルをN個抜き出す方法

Pocket
LINEで送る


「PandasでGroupbyでグルーピングしたはいんだけど、そこからグループ別にサンプルを1個、2個…と抜き出す、SQLでよくやるやつってどうやるんだっけ?」ということが気になったので、調べました。ちゃんとした方法があります。

例題

今、中国地方と四国地方の県と面積をDataFrameにしてみました。ここから、中国地方と四国地方の面積が小さい県を選んでみます。

import pandas as pd

prefs = {"pref_name":["鳥取", "島根", "岡山", "広島", "山口", "徳島", "香川", "愛媛", "高知"],
         "category":["中国", "中国", "中国", "中国", "中国", "四国", "四国", "四国", "四国"],
         "area":[3507, 6708, 7010, 8480, 6114, 4147, 1862, 5679, 7105]}
df = pd.DataFrame(prefs)
print(df)
  pref_name category  area
0        鳥取       中国  3507
1        島根       中国  6708
2        岡山       中国  7010
3        広島       中国  8480
4        山口       中国  6114
5        徳島       四国  4147
6        香川       四国  1862
7        愛媛       四国  5679
8        高知       四国  7105

グルーピングして抜き出すときはhead()

全体の流れとして、面積(area)でソートして、区分(category)でgroupbyすればいいというのはわかります。でもグルーピングした後抜き出すのはどうやるんしょう?

答えはpandas.core.groupby.GroupBy.headです。headなので、頭からサンプルをN個取ります。後ろから取りたいときはtail()を使います。

やってみましょう。

# sort
sorted = df.sort_values(["area"], 0, [True])
print(sorted)
# group by
grouped = sorted.groupby("category").head(1)
print(grouped)
  pref_name category  area
6        香川       四国  1862
0        鳥取       中国  3507
5        徳島       四国  4147
7        愛媛       四国  5679
4        山口       中国  6114
1        島根       中国  6708
2        岡山       中国  7010
8        高知       四国  7105
3        広島       中国  8480
  pref_name category  area
6        香川       四国  1862
0        鳥取       中国  3507

この通り、四国で最も狭い香川県と、中国で最も狭い鳥取県を抜き出すことができました。

後ろから抜き出したいときはtail()

さっきは面積が小さい県を抜き出しましたが、逆に最も面積が大きい県を抜き出してみましょう。sort_valuesを降順ソートにしてhead()でもいいですが、ここではtail()を使ってみます。

grouped = sorted.groupby("category").tail(1)
print(grouped)
  pref_name category  area
8        高知       四国  7105
3        広島       中国  8480

この通り、高知と広島が出てきました。

ソートをしてグルーピングすると順番がバラバラになるので、グループ間で直したい場合は再ソートが必要

ここがちょっと曲者ですが、head()に2以上を入れるとグループ別にバラバラの値が出てくるということがおきます。

# sort
sorted = df.sort_values(["area"], 0, [True])
# group by
grouped = sorted.groupby("category").head(2)
print(grouped)
  pref_name category  area
6        香川       四国  1862
0        鳥取       中国  3507
5        徳島       四国  4147
4        山口       中国  6114

なので、グループごとに一緒にして表示したいという場合はグループで再ソートする必要がありそうです(ここもっと上手いやり方あったら教えてください)。

# さらにソート
resort = grouped.sort_values(["category"], 0, [False])
print(resort)
  pref_name category  area
6        香川       四国  1862
5        徳島       四国  4147
0        鳥取       中国  3507
4        山口       中国  6114

このように、四国で面積の小さい2県:香川・徳島と、中国で面積の小さい2県:鳥取・山口を抽出することができました。愛媛県は山口県より面積が小さいですが(5679)、四国では3番目に小さいのでここには含まれていません。

実は公式ドキュメントに載ってた

この方法探してて、あんまいい方法ないなと思っていたら公式ドキュメントに載っていました。ただかなり下の方に載っている方法なので、気をつけてみないと見落とします。詳しく知りたかったら見てみてください。

Group By: split-apply-combine
https://pandas.pydata.org/pandas-docs/stable/groupby.html

Related Posts

TensorFlow Data Validationを使ったお手軽で強力な探索的データ解析... 特にテーブルデータで、実際の分析に入る前に欠損値やデータの分布の把握といった、探索的データ解析(EDA)というのは重要なプロセスになります。TensorFlow Data Validationというツールを使うとそれがたった数行で簡単にできます。その方法を紹介します。 探索的データ解析(EDA)...
データのお気持ちを考えながらData Augmentationする... Data Augmentationの「なぜ?」に注目しながら、エラー分析をしてCIFAR-10の精度向上を目指します。その結果、オレオレAugmentationながら、Wide ResNetで97.3%という、Auto Augmentとほぼ同じ(-0.1%)精度を出すことができました。 (※すご...
PythonのMessagePack-Numpyで独自のクラスをシリアライズする方法... MessagePackを使ってシリアライズを高速化したかったのですが、独自のクラスやネストされたオブジェクトについてシリアル化する方法が全然なかったので調べてみました。Numpyのシリアライズも使えるMessagePackの拡張版、MessagePack-Numpyを使って確かめます。 Mess...
PyTorch/TorchVisionで複数の入力をモデルに渡したいケース... PyTorch/TorchVisionで入力が複数あり、それぞれの入力に対して同じ前処理(transforms)をかけるケースを考えます。デフォルトのtransformsは複数対応していないのでうまくいきません。しかし、ラッパークラスを作り、それで前処理をラップするといい感じにできたのでその方法を...
KerasのCallbackを使って継承したImageDataGeneratorに値が渡せるか確かめ... Kerasで前処理の内容をエポックごとに変えたいというケースがたまにあります。これを実装するとなると、CallbackからGeneratorに値を渡すというコードになりますが、これが本当にできるかどうか確かめてみました。 想定する状況 例えば、前処理で正則化に関係するData Augmenta...
Pocket
Delicious にシェア

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です