NumPy関数だけでTopKを求め、多次元のインデックスをスライスするための方法

NumPy関数を使って多次元配列のTopKを求める方法を検証します。PyTorchの関数を使えば簡単にできますが、NumPyだけで行う場合は工夫が必要です。いつも忘れるので自分用忘備録に
目次
PyTorchだと簡単だけどNumPyだと一筋縄ではないかないTopK
TopKというのは、上位○個の要素を求めてねという処理です。PyTorchだとこれはtorch.topkという専用の関数を使えば楽勝で、
torch.manual_seed(1234)
x = torch.randn(10, 3)
print(x)
values, indices = torch.topk(x, 2, dim=0)
print(values)
print(indices)
tensor([[-0.1117, -0.4966,  0.1631],
        [-0.8817,  0.0539,  0.6684],
        [-0.0597, -0.4675, -0.2153],
        [ 0.8840, -0.7584, -0.3689],
        [-0.3424, -1.4020,  1.4255],
        [ 0.7987, -1.4949,  0.8810],
        [-1.1786, -0.9340, -0.5675],
        [-0.2772, -0.4030,  0.4195],
        [ 0.9380,  0.0078, -0.3139],
        [-1.1567,  1.8409, -1.0174]])
tensor([[0.9380, 1.8409, 1.4255],
        [0.8840, 0.0539, 0.8810]])
tensor([[8, 9, 4],
        [3, 1, 5]])
列ごとに独立して上位2個の要素を求めているのがポイントです。
これをNumPyでも実装したいのですが、NumPyにはTopKの関数がありません。さてどう実装するでしょうか?
np.argpartition+np.take_along_axisでやる
こちらの記事を参考にしました。
np.argpartition
np.argpartitionという関数を初めて知りました。簡単にいうと、np.argsortの完全にソートしない場合で、大まかなブロック単位で大小判定できていれば良しとするときに使える関数です。具体的にはTopKをとって、合計や平均をとるときに使えます(足し算は順序入れ替えても変わらないので)。完全にソートしなくていいので、計算量的に軽いということでしょう。
この関数理解するのが難しいので、1次元の例で説明します。
x = np.array([6, 3, 1, 8, 0, -5, 2, 0])
print("Ground Truth", np.argsort(x))
for i in range(5):
    print(f"Arg partition ", i, np.argpartition(x, i))
np.argsortは文字通り完全にソートした場合でGround Truthにあたります。次にnp.argpartitionの第2引数を変えた場合に、どのように出力が変化するかを確かめます。
Ground Truth [5 4 7 2 6 1 0 3]
Arg partition  0 [5 1 2 3 4 0 6 7]
Arg partition  1 [5 4 2 3 1 0 6 7]
Arg partition  2 [5 4 7 3 1 0 6 2]
Arg partition  3 [4 5 7 2 1 6 0 3]
Arg partition  4 [4 5 7 2 6 1 0 3]
Ground Truthとの関係に注目しましょう。第二引数のインデックスの場所に仕切りをおきます
* Arg partition 0 [5 | 1 2 3 4 0 6 7]
* Arg partition 1 [5 4 | 2 3 1 0 6 7]
* Arg partition 2 [5 4 7 | 3 1 0 6 2]
* Arg partition 3 [4 5 7 2 | 1 6 0 3]
* Arg partition 4 [4 5 7 2 6 | 1 0 3]
np.argpartitionのお気持ちは、「仕切りの左と右で、個々のインデックスの順序は違うけど、仕切りレベルでの大小は会っていますよね?」ということです。
例えば0のケースでは、「5」が最初にきているのはあっています。その他の順番はいい加減です。
例えば1のケースでは、仕切りの左側「5、4」がきているのはあっています。ここは「4、5」がくることもあります。仕切りの左側と右側での大小関係のみみていて、仕切り内の順序はみていません。
わかりやすいのが3のケースで、仕切りの左側は「4、5、7、2」となっていますが、Ground Truthは「5、4、7、2」です。順序は異なっていますが、仕切りの左側の構成要素は同じです。
「partition」というのは仕切りという意味なので、文字通りの意味ですね。
負のインデックスにすると降順にできる
第二引数に負のインデックスを入れると降順にできます。やってみましょう。
print("Ground Truth", np.argsort(x))
for i in range(6):
    print(f"Arg partition ", -i, np.argpartition(x, -i))
print("---")
print("Arg sort reverse", np.argsort(x)[::-1])
出力にこちらで仕切りを入れています。
Ground Truth [5 4 7 2 6 1 0 3]
Arg partition  0 [5 | 1 2 3 4 0 6 7]
Arg partition  -1 [6 7 2 1 4 5 0 | 3]
Arg partition  -2 [6 7 2 1 4 5 | 0 3]
Arg partition  -3 [4 5 7 2 6 | 1 0 3]
Arg partition  -4 [4 5 7 2 | 6 1 0 3]
Arg partition  -5 [4 5 7 | 2 1 6 0 3]
---
Arg sort reverse [3 0 1 6 2 7 4 5]
このような感じになります。同一値のときにインデックスがどうなるのか興味はありますが、ざっくりと求めたいときは便利でしょう。
これが何に使えるか
TopKをとったあと合計や平均をとりたいときに使えます。足し算は交換法則が成り立つので、
$$x_1 + x_2 + \cdots + x_n = x_5 + x_ 7 + \cdots + x_1$$
のように順序がランダムでも結果は一緒です。
2次元以上の場合
2次元以上の場合は返ってくるインデックスが多次元配列になります。
np.random.seed(1234)
x = np.random.randn(10, 3)
print(x)
print("Ground Truth")
print(np.argsort(x, axis=0))
print("Arg Parition axis=0, 3")
print(np.argpartition(x, 3, axis=0))
仕切り線をこちらで追加しました。
[[ 4.71435164e-01 -1.19097569e+00  1.43270697e+00]
 [-3.12651896e-01 -7.20588733e-01  8.87162940e-01]
 [ 8.59588414e-01 -6.36523504e-01  1.56963721e-02]
 [-2.24268495e+00  1.15003572e+00  9.91946022e-01]
 [ 9.53324128e-01 -2.02125482e+00 -3.34077366e-01]
 [ 2.11836468e-03  4.05453412e-01  2.89091941e-01]
 [ 1.32115819e+00 -1.54690555e+00 -2.02646325e-01]
 [-6.55969344e-01  1.93421376e-01  5.53438911e-01]
 [ 1.31815155e+00 -4.69305285e-01  6.75554085e-01]
 [-1.81702723e+00 -1.83108540e-01  1.05896919e+00]]
Ground Truth
[[3 4 4]
 [9 6 6]
 [7 0 2]
 [1 1 5]
 [5 2 7]
 [0 8 8]
 [2 9 1]
 [4 7 3]
 [8 5 9]
 [6 3 0]]
Arg Parition axis=0, 3
[[3 6 4]
 [7 4 6]
 [9 0 2]
----------
 [1 1 5]
 [5 2 7]
 [0 8 8]
 [6 9 1]
 [2 7 3]
 [8 5 9]
 [4 3 0]]
こういう状態です。Argsortの結果と見比べてください。
np.take_along_axis
多次元配列の場合のインデックスの扱いがまた曲者です。単にスライスするとShapeが想定していないものになります
indices = np.argpartition(x, 3, axis=0)
print(x[indices].shape) # (10, 3, 3)
3階テンソルになってしまいました。これはおかしいです。
「インデックスはaxis=0方向にとってほしい」ということなので、スライスではなく、np.take_along_axisという関数を使います。
y = np.take_along_axis(x, indices, axis=0)
print(y)
[[-2.24268495e+00 -1.54690555e+00 -3.34077366e-01]
 [-6.55969344e-01 -2.02125482e+00 -2.02646325e-01]
 [-1.81702723e+00 -1.19097569e+00  1.56963721e-02]
 [-3.12651896e-01 -7.20588733e-01  2.89091941e-01]
 [ 2.11836468e-03 -6.36523504e-01  5.53438911e-01]
 [ 4.71435164e-01 -4.69305285e-01  6.75554085e-01]
 [ 1.32115819e+00 -1.83108540e-01  8.87162940e-01]
 [ 8.59588414e-01  1.93421376e-01  9.91946022e-01]
 [ 1.31815155e+00  4.05453412e-01  1.05896919e+00]
 [ 9.53324128e-01  1.15003572e+00  1.43270697e+00]]
上の各3行分が、各行の最小値側の3要素です。argpartitionを使っているので順番は保証されません。
結局TopKってどうやるの?
K=2の場合
np.random.seed(1234)
x = np.random.randn(10, 3)
print(x)
indices = np.argpartition(x, -2, axis=0)[-2:, :]
topk = np.take_along_axis(x, indices, axis=0)
print(topk)
[[ 4.71435164e-01 -1.19097569e+00  1.43270697e+00]
 [-3.12651896e-01 -7.20588733e-01  8.87162940e-01]
 [ 8.59588414e-01 -6.36523504e-01  1.56963721e-02]
 [-2.24268495e+00  1.15003572e+00  9.91946022e-01]
 [ 9.53324128e-01 -2.02125482e+00 -3.34077366e-01]
 [ 2.11836468e-03  4.05453412e-01  2.89091941e-01]
 [ 1.32115819e+00 -1.54690555e+00 -2.02646325e-01]
 [-6.55969344e-01  1.93421376e-01  5.53438911e-01]
 [ 1.31815155e+00 -4.69305285e-01  6.75554085e-01]
 [-1.81702723e+00 -1.83108540e-01  1.05896919e+00]]
[[1.31815155 0.40545341 1.05896919]
 [1.32115819 1.15003572 1.43270697]]
K=5の場合、
[[ 0.47143516 -0.46930528  0.28909194]
 [ 0.85958841 -0.18310854 -0.20264632]
 [ 0.95332413  0.19342138  0.55343891]
 [ 1.31815155  0.40545341  1.05896919]
 [ 1.32115819  1.15003572  1.43270697]]
となります。
値の妥当性
そもそもこれうまくいっているのでしょうか? argsortした場合と、argpartitionの場合で、TopKとったあとのaxis=0の平均をみてみます。
result = []
for trial in range(100):
    row = np.random.randint(low=100, high=1000)
    col = np.random.randint(low=100, high=1000)
    x = np.random.randn(row, col)
    k = np.random.randint(low=row//20, high=row*2//3)
    indices = np.argpartition(x, -k, axis=0)[-k:, :]
    topk = np.take_along_axis(x, indices, axis=0)
    result_a = np.mean(topk, axis=0)
    indices = np.argsort(x, axis=0)[-k:, :]
    topk = np.take_along_axis(x, indices, axis=0)
    result_b = np.mean(topk, axis=0)
    flag = np.allclose(result_a, result_b)
    result.append(flag)
print(np.sum(result))
それぞれ「np.argpartition」でとったTopK、「np.argsort」でとったTopKです。この2つのmeanの結果がすべて同じなら出力が100(100回試行)になります。結果は、
100
となり、無事一致しました。
パフォーマンス比較
np.argpartitionはおそらく速いだろうということでしたが、実際に検証してみます。
環境:Colab CPU
コード
行数とTopKの割合を変えて、各パターン100回試行した平均値を計算します。
import time
print("argpartition")
for row in [100, 1000, 10000, 100000]:
    x = np.random.randn(row, 1000)
    print("n_row", row)
    for k_ratio in [100, 10, 2]:
        start_time = time.time()
        for trial in range(100):
            k = x.shape[0]//k_ratio
            y = np.argpartition(x, -k, axis=0)[-k:,:]
        elapsed = time.time() - start_time
        print(k_ratio, elapsed/100)
print("argsort")
for row in [100, 1000, 10000, 100000]:
    x = np.random.randn(row, 1000)
    print("n_row", row)
    for k_ratio in [100, 10, 2]:
        start_time = time.time()
        for trial in range(100):
            k = x.shape[0]//k_ratio
            y = np.argsort(x, axis=0)[-k:,:]
        elapsed = time.time() - start_time
        print(k_ratio, elapsed/100)
np.argpartitionの場合
| 行数/K_ratio | 100 | 10 | 2 | 
|---|---|---|---|
| 100 | 0.00062 | 0.00211 | 0.00246 | 
| 1000 | 0.03382 | 0.03386 | 0.03826 | 
| 10000 | 0.25718 | 0.27161 | 0.30098 | 
| 100000 | 5.22519 | 5.31681 | 5.61820 | 
np.argsortの場合
| 行数/K_ratio | 100 | 10 | 2 | 
|---|---|---|---|
| 100 | 0.00458 | 0.00444 | 0.00453 | 
| 1000 | 0.08250 | 0.08663 | 0.08135 | 
| 10000 | 1.24387 | 1.23313 | 1.24913 | 
| 100000 | 16.00673 | 15.58112 | 15.85248 | 
倍率
| 倍率 | 100 | 10 | 2 | 
|---|---|---|---|
| 100 | 7.36275 | 2.10258 | 1.84095 | 
| 1000 | 2.43963 | 2.55878 | 2.12604 | 
| 10000 | 4.83663 | 4.54015 | 4.15025 | 
| 100000 | 3.06338 | 2.93054 | 2.82163 | 
やはり計算量のオーダーが異なり、np.argpartitionのほうが3倍~7倍程度速いです。np.argpartitionのいいところは、TopKのKの割合が少なければ、高速化がかかってくれるということです。argsortの場合は全体をソートしてしまうため、Kの割合に関係なく計算量は一定となります。
1万個ぐらいの要素に対してTopKをとる操作は、少し混みいったモデルを作ると普通に出てくるので、使いこなすと便利だと思います。
Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内
技術書コーナー
北海道の駅巡りコーナー
 
       