こしあん
2018-09-11

PyTorchでのConvTranspose2dのパラメーター設定について

Pocket
LINEで送る

VAE(Variational Auto Encoder)やGAN(Generative Adversarial Network)などで用いられるデコーダーで畳み込みの逆処理(Convtranspose2d)を使うことがあります。このパラメーター設定についてハマったので解説します。

エンコーダーのConv2Dでダウンサンプリング

まず前提として、MaxPoolingなどのPoolingを使わなくても畳み込み(Conv2D)だけでダウンサンプリングはできます。GANで使われる手法ですが、CNNでも使えます。例えばMNISTで考えましょう。

入力:(-1, 1, 28, 28)+kernel=3の畳み込み
出力:(-1, 32, 14, 14)

どういう処理かというと、モノクロ(チャンネル数1)の28×28の画像を、32チャンネルの3×3の畳み込みを通して、14×14にダウンサンプリングするというものです。これを1つのConv2Dで定義するにはどうしたらよいでしょうか?答えはこうです。

import torch.nn as nn
nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1)

これはわかりやすいと思います。PyTorchのpaddingは両側に付与するピクセル数、つまりpadding=1なら左右に1ピクセルずつ入れるということに注意してください。公式ドキュメントによると、出力の解像度の計算式は、

$$H_{out}=\Bigl\lfloor\frac{H_{in}+2\times\rm{padding}-\rm{kernel}}{\rm{stride}}+1 \Bigr\rfloor $$

で表されます(ドキュメントにはdilationがありますが、自分は使ったことないので省略しています)。上の式の場合、分母は「28+2×1-3 = 27」、strideの2で割り1を足すと14.5、小数点以下を切り捨てて「14」、ちゃんと正しく計算できていますね。

計算中に端数ができていることについてですが、(エンコーダー側の)CNNの場合はkernelを奇数とすることが普通です。奇数にすると中心のピクセルが生まれ、畳込みの計算が有効に機能しやすいと言われています。

ConvTranspose2dのパラメーター

ここからが本題です。上記はエンコーダー側の設定ですが、デコーダーでこの逆の処理をやりたい場合はどうすればよいでしょうか?つまり、

入力:(-1, 32, 14, 14)
出力:(-1, 1, 28, 28)

これをConvTrasnpose2Dで表現したい場合どうすればよいでしょうか?ConvTrasnpose2Dにはin_channels, out_channelsのほかに次のようなパラメーターがあります(dilationや解像度に関係なさそうなパラメーターは端折っています)。

kernel_size (int or tuple) – Size of the convolving kernel
stride (int or tuple, optional) – Stride of the convolution. Default: 1
padding (int or tuple, optional) – kernel_size – 1 – padding zero-padding will be added to both sides of each dimension in the input. Default: 0
output_padding (int or tuple, optional) – Additional size added to one side of each dimension in the output shape. Default: 0

公式ドキュメントより

paddingとoutput_paddingって何が違うんやという感じはしますが、Conv2Dよりパラメーターが多いです。ちなみに解像度は次の式で表されます。これを見るとだいぶ理解できると思います。

$$H_{out}=(H_{in}-1)\times\rm{stride}-2\times\rm{padding}+\rm{kernel}+\rm{output padding}$$

パラメーターが多すぎるので簡略化しましょう。一番わかりやすいのはstrideです。エンコーダーの場合は、stride=2つまりPoolingでkernel=2と同じようにダウンサンプリングしたので、デコーダーの場合もstride=2でアップサンプリングするのが妥当でしょう。つまり、

$$H_{out}=H_{in}\times\rm{stride} $$

という関係式が追加されます。これを上の式に代入すると、いくつか解はありますが、一番わかりやすい場合は「kernel=strideとする」ことだと思います。コードで書くとこうなります。

nn.ConvTranspose2d(32, 1, kernel_size=2, stride=2)

本当にこれで合っているか検算してみましょう。「(in – 1)×stride」の部分は「(14-1)×2=26」、paddingやoutput_paddingは0、これにkernel_size=2を足して28。合ってますね。

ちなみに「元のkernel=3の畳み込みはどこに消えたの?」と思われるかもしれませんが、デコーダーの場合はstrideと異なる奇数のカーネルを代入すると、paddingに2の係数がかかっている以上、どうしても端数が出てしまいます。つまり、strideと異なるカーネルサイズを入れたいときは、デコーダーでは偶数のカーネルを使ったほうが良さそうです。なので、エンコーダーのkernel=3の畳み込みはデコーダー側では再現しなくてもいいかなと自分は思います。GANの実装を見ていると、DiscriminatorのカーネルサイズとGeneratorのカーネルサイズが違うなんてことがよくあるので大丈夫だと思います。

強引に奇数のカーネルをデコーダーで実装したい場合、output_paddingを使えば再現できますが、自分がやった限りでは変なグリッドラインが入ってしまって出力画像がおかしくなりました。

画像を拡大して見ると、網目模様の灰色の線がわかると思います。output_paddingを外したら消えたのでこれが悪さをしています。

なので、自分はoutput_paddingは使わないほうがいいのではないかなと思います。paddingやoutput_paddingを使わずにstride=kernel_sizeとすることが一番わかりやすいですし、この場合は奇数の制約はありません(stride=kernel_sizeの条件があれば、奇数だろうが偶数だろうが式は成立する)。

以上です。ConvTrasnpose2DはConv2Dに比べて例が少ないので苦労しましたが、参考になれば幸いです。

Related Posts

PyTorchでweight clipping WGANの論文見てたらWeight Clippingしていたので、簡単な例を実装して実験してみました。かなり簡単にできます。それを見ていきましょう。 Weight Clippingとは レイヤーの係数の値を一定範囲以内に収める手法。例えば、あるレイヤーが「-2, -1, 0, 1, 2」という...
PyTorchでGANの訓練をするときにrequires_grad(trainable)の変更はいる... PyTorchでGANのある実装を見ていたときに、requires_gradの変更している実装を見たことがあります。Kerasだとtrainableの明示的な変更はいるんで、もしかしてPyTorchでもいるんじゃないかな?と疑問になったので、確かめてみました。 requires_gradの変更とは...
GANでGeneratorの損失関数をmin(log(1-D))からmaxlog Dにした場合の実験... GANの訓練をうまくいくためのTipとしてよく引用される、How to train GANの中から、Generatorの損失関数をmin(log(1-D))からmaxlog Dにした場合を実験してみました。その結果、損失結果を変更しても出力画像のクォリティーには大して差が出ないことがわかりました。...
pix2pixを1から実装して白黒画像をカラー化してみた(PyTorch)... pix2pixによる白黒画像のカラー化を1から実装します。PyTorchで行います。かなり自然な色付けができました。pix2pixはGANの中でも理論が単純なのにくわえ、学習も比較的安定しているので結構おすすめです。 はじめに PyTorchでDCGANができたので、今回はpix2pixをやり...
PyTorchでDCGANやってみた PyTorchでDCGANをやってみました。MNISTとCIFAR-10、STL-10を動かしてみましたがかなり簡単にできました。訓練時間もそこまで長くはないので結構手軽に遊べます。 はじめに PyTorchでDCGANやってみました。コードはほとんどこの記事のコピペです。MNISTとCIFA...
Pocket
LINEで送る

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です