PyTorchでのConvTranspose2dのパラメーター設定について
VAE(Variational Auto Encoder)やGAN(Generative Adversarial Network)などで用いられるデコーダーで畳み込みの逆処理(Convtranspose2d)を使うことがあります。このパラメーター設定についてハマったので解説します。
目次
エンコーダーのConv2Dでダウンサンプリング
まず前提として、MaxPoolingなどのPoolingを使わなくても畳み込み(Conv2D)だけでダウンサンプリングはできます。GANで使われる手法ですが、CNNでも使えます。例えばMNISTで考えましょう。
入力:(-1, 1, 28, 28)+kernel=3の畳み込み
出力:(-1, 32, 14, 14)
どういう処理かというと、モノクロ(チャンネル数1)の28×28の画像を、32チャンネルの3×3の畳み込みを通して、14×14にダウンサンプリングするというものです。これを1つのConv2Dで定義するにはどうしたらよいでしょうか?答えはこうです。
import torch.nn as nn
nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1)
これはわかりやすいと思います。PyTorchのpaddingは両側に付与するピクセル数、つまりpadding=1なら左右に1ピクセルずつ入れるということに注意してください。公式ドキュメントによると、出力の解像度の計算式は、
$$H_{out}=\Bigl\lfloor\frac{H_{in}+2\times\rm{padding}-\rm{kernel}}{\rm{stride}}+1 \Bigr\rfloor $$
で表されます(ドキュメントにはdilationがありますが、自分は使ったことないので省略しています)。上の式の場合、分母は「28+2×1-3 = 27」、strideの2で割り1を足すと14.5、小数点以下を切り捨てて「14」、ちゃんと正しく計算できていますね。
計算中に端数ができていることについてですが、(エンコーダー側の)CNNの場合はkernelを奇数とすることが普通です。奇数にすると中心のピクセルが生まれ、畳込みの計算が有効に機能しやすいと言われています。
ConvTranspose2dのパラメーター
ここからが本題です。上記はエンコーダー側の設定ですが、デコーダーでこの逆の処理をやりたい場合はどうすればよいでしょうか?つまり、
入力:(-1, 32, 14, 14)
出力:(-1, 1, 28, 28)
これをConvTrasnpose2Dで表現したい場合どうすればよいでしょうか?ConvTrasnpose2Dにはin_channels, out_channelsのほかに次のようなパラメーターがあります(dilationや解像度に関係なさそうなパラメーターは端折っています)。
kernel_size (int or tuple) – Size of the convolving kernel
stride (int or tuple, optional) – Stride of the convolution. Default: 1
padding (int or tuple, optional) – kernel_size – 1 – padding zero-padding will be added to both sides of each dimension in the input. Default: 0
output_padding (int or tuple, optional) – Additional size added to one side of each dimension in the output shape. Default: 0
paddingとoutput_paddingって何が違うんやという感じはしますが、Conv2Dよりパラメーターが多いです。ちなみに解像度は次の式で表されます。これを見るとだいぶ理解できると思います。
$$H_{out}=(H_{in}-1)\times\rm{stride}-2\times\rm{padding}+\rm{kernel}+\rm{output padding}$$
パラメーターが多すぎるので簡略化しましょう。一番わかりやすいのはstrideです。エンコーダーの場合は、stride=2つまりPoolingでkernel=2と同じようにダウンサンプリングしたので、デコーダーの場合もstride=2でアップサンプリングするのが妥当でしょう。つまり、
$$H_{out}=H_{in}\times\rm{stride} $$
という関係式が追加されます。これを上の式に代入すると、いくつか解はありますが、一番わかりやすい場合は「kernel=strideとする」ことだと思います。コードで書くとこうなります。
nn.ConvTranspose2d(32, 1, kernel_size=2, stride=2)
本当にこれで合っているか検算してみましょう。「(in – 1)×stride」の部分は「(14-1)×2=26」、paddingやoutput_paddingは0、これにkernel_size=2を足して28。合ってますね。
ちなみに「元のkernel=3の畳み込みはどこに消えたの?」と思われるかもしれませんが、デコーダーの場合はstrideと異なる奇数のカーネルを代入すると、paddingに2の係数がかかっている以上、どうしても端数が出てしまいます。つまり、strideと異なるカーネルサイズを入れたいときは、デコーダーでは偶数のカーネルを使ったほうが良さそうです。なので、エンコーダーのkernel=3の畳み込みはデコーダー側では再現しなくてもいいかなと自分は思います。GANの実装を見ていると、DiscriminatorのカーネルサイズとGeneratorのカーネルサイズが違うなんてことがよくあるので大丈夫だと思います。
強引に奇数のカーネルをデコーダーで実装したい場合、output_paddingを使えば再現できますが、自分がやった限りでは変なグリッドラインが入ってしまって出力画像がおかしくなりました。
画像を拡大して見ると、網目模様の灰色の線がわかると思います。output_paddingを外したら消えたのでこれが悪さをしています。
なので、自分はoutput_paddingは使わないほうがいいのではないかなと思います。paddingやoutput_paddingを使わずにstride=kernel_sizeとすることが一番わかりやすいですし、この場合は奇数の制約はありません(stride=kernel_sizeの条件があれば、奇数だろうが偶数だろうが式は成立する)。
以上です。ConvTrasnpose2DはConv2Dに比べて例が少ないので苦労しましたが、参考になれば幸いです。
Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内
技術書コーナー
北海道の駅巡りコーナー