こしあん
2018-12-10

PandasのDataFrameでグループ別にサンプルをN個抜き出す方法


「PandasでGroupbyでグルーピングしたはいんだけど、そこからグループ別にサンプルを1個、2個…と抜き出す、SQLでよくやるやつってどうやるんだっけ?」ということが気になったので、調べました。ちゃんとした方法があります。

例題

今、中国地方と四国地方の県と面積をDataFrameにしてみました。ここから、中国地方と四国地方の面積が小さい県を選んでみます。

import pandas as pd

prefs = {"pref_name":["鳥取", "島根", "岡山", "広島", "山口", "徳島", "香川", "愛媛", "高知"],
         "category":["中国", "中国", "中国", "中国", "中国", "四国", "四国", "四国", "四国"],
         "area":[3507, 6708, 7010, 8480, 6114, 4147, 1862, 5679, 7105]}
df = pd.DataFrame(prefs)
print(df)
  pref_name category  area
0        鳥取       中国  3507
1        島根       中国  6708
2        岡山       中国  7010
3        広島       中国  8480
4        山口       中国  6114
5        徳島       四国  4147
6        香川       四国  1862
7        愛媛       四国  5679
8        高知       四国  7105

グルーピングして抜き出すときはhead()

全体の流れとして、面積(area)でソートして、区分(category)でgroupbyすればいいというのはわかります。でもグルーピングした後抜き出すのはどうやるんしょう?

答えはpandas.core.groupby.GroupBy.headです。headなので、頭からサンプルをN個取ります。後ろから取りたいときはtail()を使います。

やってみましょう。

# sort
sorted = df.sort_values(["area"], 0, [True])
print(sorted)
# group by
grouped = sorted.groupby("category").head(1)
print(grouped)
  pref_name category  area
6        香川       四国  1862
0        鳥取       中国  3507
5        徳島       四国  4147
7        愛媛       四国  5679
4        山口       中国  6114
1        島根       中国  6708
2        岡山       中国  7010
8        高知       四国  7105
3        広島       中国  8480
  pref_name category  area
6        香川       四国  1862
0        鳥取       中国  3507

この通り、四国で最も狭い香川県と、中国で最も狭い鳥取県を抜き出すことができました。

後ろから抜き出したいときはtail()

さっきは面積が小さい県を抜き出しましたが、逆に最も面積が大きい県を抜き出してみましょう。sort_valuesを降順ソートにしてhead()でもいいですが、ここではtail()を使ってみます。

grouped = sorted.groupby("category").tail(1)
print(grouped)
  pref_name category  area
8        高知       四国  7105
3        広島       中国  8480

この通り、高知と広島が出てきました。

ソートをしてグルーピングすると順番がバラバラになるので、グループ間で直したい場合は再ソートが必要

ここがちょっと曲者ですが、head()に2以上を入れるとグループ別にバラバラの値が出てくるということがおきます。

# sort
sorted = df.sort_values(["area"], 0, [True])
# group by
grouped = sorted.groupby("category").head(2)
print(grouped)
  pref_name category  area
6        香川       四国  1862
0        鳥取       中国  3507
5        徳島       四国  4147
4        山口       中国  6114

なので、グループごとに一緒にして表示したいという場合はグループで再ソートする必要がありそうです(ここもっと上手いやり方あったら教えてください)。

# さらにソート
resort = grouped.sort_values(["category"], 0, [False])
print(resort)
  pref_name category  area
6        香川       四国  1862
5        徳島       四国  4147
0        鳥取       中国  3507
4        山口       中国  6114

このように、四国で面積の小さい2県:香川・徳島と、中国で面積の小さい2県:鳥取・山口を抽出することができました。愛媛県は山口県より面積が小さいですが(5679)、四国では3番目に小さいのでここには含まれていません。

実は公式ドキュメントに載ってた

この方法探してて、あんまいい方法ないなと思っていたら公式ドキュメントに載っていました。ただかなり下の方に載っている方法なので、気をつけてみないと見落とします。詳しく知りたかったら見てみてください。

Group By: split-apply-combine
https://pandas.pydata.org/pandas-docs/stable/groupby.html

Related Posts

Python(Numpy)で画像を水平反転する方法:Data Augmentation向け... OpenCVを使わずに単純に画像を左右反転(水平反転)する方法を考えます。ディープラーニングでデータのジェネレーターを自分で実装した場合、Data Augmentationを組み込む際にも必要になります。それを見ていきましょう。 左右反転自体は実は簡単 例えばNumpyの行列を左右反転させてみ...
keras_preprocessingを使ってお手軽に画像を回転させる方法... Data Augmentationで画像を回転させたいことがあります。画像の回転は一般に「アフィン変換」と呼ばれる操作で、OpenCVやPillowのライブラリを使えば簡単にできるのですが、Numpy配列に対して1から書くとかなりめんどいのです。Kerasが裏で使っているkeras_preproc...
Chainerで画像の前処理やDataAugmentationをしたいときはDatasetMixin... Chainerにはデフォルトでランダムクロップや標準化といった、画像の前処理やDataAugmentation用の関数が用意されていません。別途のChainer CVというライブラリを使う方法もありますが、chainer.dataset.DatasetMixinを継承させて独自のデータ・セットを定...
pandasでグループ別に統計量やヒストグラムを表示する方法... 初投稿です。この記事では、pandasでグループ別に基本統計量(describe)をする方法を紹介します。 テストデータ 以下のようなデータを想定します。ある学校の3つのクラスでテストをしてみました。 A組は40人、平均点は60点、標準偏差は10(分布は正規分布に従うものとする) B組は...
PyTorchでサイズの異なる画像を読み込む方法... 実際の画像判定では、MNISTやCIFARのようにサイズが完全に整形されたデータはなかなか少ないです。例えばサイズが横幅は一定でも縦幅が異なっていたりするケースがあります。訓練画像間でサイズが異なる場合、そのまま読み込みするとエラーになります。その解決法を示します。 transforms.Ra...

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です