PyTorchでSliced Wasserstein Distance (SWD)を実装してみました。オリジナルの実装はNumpyですが、これはPyTorchで実装しているので、GPU上で計算することができます。本来はGANの生成画像を評価するためのものですが、画像の分布不一致を見るためにも使うことができます。
目次
こちらのリポジトリにあります。
https://github.com/koshian2/swd-pytorch
PGGANの論文で使われている画像類似度の評価指標です。GANの評価指標の多く(Inception scoreやFID)が訓練済みInceptionモデルベースであるのに対し、SWDはInception依存ではありません。訓練済みInceptionモデルは大抵ImageNetベースなので、特徴量抽出がドメインによって得意だったり不得意だったりします。したがって、ImageNetによるInceptionが得意ではないドメイン(医用画像やアニメーション画像など)での類似度指標としては、Inception依存ではないほうが便利だったりします。
そして、大きな特徴として、SWDはラプラシアンピラミッドからパッチを切り出して距離を計算しているため、追加のラベルを参照しません。すべて教師なしでできます。
最も簡単なSWDの計算方法です。GPU上の計算となります。
import torch
from swd import swd
torch.manual_seed(123) # fix seed
x1 = torch.rand(1024, 3, 128, 128) # 1024 images, 3 chs, 128x128 resolution
x2 = torch.rand(1024, 3, 128, 128)
out = swd(x1, x2, device="cuda") # Fast estimation if device="cuda"
print(out) # tensor(53.6950)
swd
の引数の詳細情報です。
image1, image2
: 必須 4ランクのPyTorchテンソル。各テンソルは[N, ch, H, W]のshapeであること。正方形の画像(H=W)推奨。n_pyramid
: (任意) ラプラシアンピラミッドの層の数。もし None
(デフォルト : 論文と同じ)が指定されると、16×16の解像度に向かって再帰的にダウンサンプリングします。最小解像度のガウシアンピラミッドの層が、ラプラシアンピラミッドの配列に追加されるため、出力の層の数はn_pyramid + 1
になります。slice_size
: (任意) ラプラシアンピラミッドの各層から切り出すときのパッチサイズ。デフォルトは7(論文と同じ)。n_descriptors
: (任意) 画像あたりのDescriptor(パッチ)の数。デフォルトは128(論文と同じ)。n_repeat_projection
: (任意) ピクセル値を乱数で投射するときの反復回数。この値はGPUメモリに応じて指定してください デフォルトは128。「n_repeat_projection * proj_per_repeat = 512
」となることが推奨されます。この積の値が512であることは論文と同じですが、公式実装ではn_repeat_projection=4、proj_per_repeat=128としています。(この方法だと非常に多くのメモリが必要です…)。proj_per_repeat
: (任意) ピクセル値を乱数で投射するときの、ループ1回あたりの投射次元数。デフォルトは4。この値を大きくするとCUDA OOMを引き起こす恐れがあります。「n_repeat_projection * proj_per_repeat = 512
」となることが推奨されます。device
: (任意) "cpu"
or "cuda"
. GPUで計算したいときはcuda
を指定してください。デフォルトは"cpu"
。return_by_resolution
: (任意) Trueの場合、ラプラシアンピラミッドの各解像度ごとにSWDを返します。Falseの場合、解像度ごとに平均を取ったSWD値を返します。デフォルトはFalse。pyramid_batchsize
: (任意) ラプラシアンピラミッドを計算する際のバッチサイズ。この値を大きくするとCUDA OOMを引き起こす恐れがあります。この値を変えることでSWDの計算結果が変わるということはありません。デフォルトは128。いくつかSWDの特徴を確認するために実験をします。2つの実験を行います。
n_repeat_projection
とproj_per_repeat
値を変えてみる。ただし2つの値の積は512で固定。これをランダムノイズによるデータと、CIFAR-10で確かめます。結論から言うと、
となります。したがって、SWDを使う際は事前に計測するデータ数を固定しておくのが重要となります。n_repeat_projection
とproj_per_repeat
をメモリ使用量に応じて変えるのは差し支えありません。
各プロットは、ラプラシアンピラミッドの解像度ごとのSWDを示し、横軸はproj_per_repeat
の値、縦軸はSWDの値を示します。各ケース10回ずつ試行しました。
ランダムノイズの場合は、16384個(論文と同じ)の乱数データを作りました。解像度は128×128です。
訓練データ1万枚とテストデータ1万枚で比較しています。
いずれの場合でも、proj_per_repeat
を変えてもほとんど値が変わらないのがわかります(→メモリの応じて値を調整して良い)。
今度はSWDを計測する際のサンプル数を変えて実験してみます。サンプル数が1万なら、訓練1万枚・テスト1万枚の間のSWDを計測し、サンプル数が5000だったら訓練5000枚・テスト5000枚間のSWDを計測する…という具合に変えてみます。
各プロットは、ラプラシアンピラミッドの解像度ごとのSWDを示し、横軸はサンプル数、縦軸はSWDの値を示します。各ケース10回ずつ試行しました。
いずれの場合でも、サンプル数に応じてSWDが指数関数的に変わってしまうのが確認できます(重要)。これはSWDがWasserstein(Earth Mover)距離だからです。したがって、あるケースAではサンプル数1000で計測し、ケースBではサンプル数1万で計測するというのは適切ではありません。PGGANの論文ではサンプル数16384としていました。
データ数に応じて変えても良いかなとは思いますが、初めに値を決めたら統一して計測すべきかと思われます。
SWDは画像データ間の分布不一致を追加のラベルなしで計測できる指標、としても使うことができます。
2つの実験を行います。CIFAR-10の訓練・テストデータのSWDを以下の2条件で計測します。どちらのケースでもサンプル数は1万で統一します。
実験1,2ともに訓練データは不均衡だが、実験1より実験2のほうがより弱い不均衡/ミスマッチとなります。このような弱いミスマッチをSWDは検出できるでしょうか?
各プロットは、ラプラシアンピラミッドの解像度ごとのSWDを示し、横軸は削除したクラス数、縦軸はSWDの値を示します。各ケース10回ずつ試行しました。
クラスの削除数が多くなるほど(ミスマッチが強くなるほど)、SWDが大きくなることがわかります。
各プロットは、不均衡となっているクラスのインデックス別のSWDを示します。横軸は不均衡データ(訓練データA)の連結数、縦軸はSWDの値を示します。
先程のクラスを削除する実験よりかは弱いミスマッチですが、このような弱いミスマッチにおいてもSWDは有効に機能することが確認できます。
一つ懸念されることは、同様のミスマッチ検出が他の評価指標(SSIMなど)でも検出できないかということです。先程のクラスを削除する実験をSSIMでも行ってみます。
SSIMでは強いミスマッチに対してもうまく検出することができませんでした。
したがって、SWDはミスマッチを検出することに有効であるということが確認できたかと思います。