こしあん
2019-08-14

PyTorchでweight clipping

Pocket
LINEで送る


WGANの論文見てたらWeight Clippingしていたので、簡単な例を実装して実験してみました。かなり簡単にできます。それを見ていきましょう。

Weight Clippingとは

レイヤーの係数の値を一定範囲以内に収める手法。例えば、あるレイヤーが「-2, -1, 0, 1, 2」という5つの係数を持っていたとして、-1~1でWeight Clippingする場合、出力は「-1, -1, 0, 1, 1」になります。

WGANの論文では、このWeight ClippingがGANの損失関数のリプシッツ連続性を満たすために必要(ただし賢い方法ではない)、だと主張しています。

この記事はこちらのリポジトリをもとに実装しています。

Weight Clippingの実装

簡単なモデルを作ってみました。

import torch
from torch import nn

class SomeModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.weights = nn.Sequential(
            nn.Linear(3, 10),
            nn.BatchNorm1d(10),
            nn.Sigmoid()
        )

    def forward(self, inputs):
        return self.weights(inputs)

これでWeight Clipping前後で係数の値を比較します。

if __name__ == "__main__":
    model = SomeModel()
    print("Initial paramters")
    for p in model.parameters():
        print(p)
    # Weight clippng
    for p in model.parameters():
        p.data.clamp_(-0.1, 0.1)
    print("Weight clipped")
    for p in model.parameters():
        print(p)

「p.data.clamp_」でやっていることがWeight Clippingです。

結果

±0.1でクリップしてみました。

Initial paramters
Parameter containing:
tensor([[-0.0665,  0.3825,  0.4054],
        [-0.3007, -0.4494, -0.4472],
        [ 0.2556, -0.0655,  0.0282],
        [ 0.1952, -0.2858, -0.5615],
        [-0.5459,  0.5111, -0.0154],
        [-0.3570, -0.3204, -0.5725],
        [ 0.2857, -0.2683, -0.0341],
        [-0.4374, -0.3827,  0.1591],
        [-0.4890,  0.5398,  0.4199],
        [-0.2990, -0.4984,  0.5026]], requires_grad=True)
Parameter containing:
tensor([ 0.3708,  0.5396, -0.0534,  0.3967,  0.2119, -0.2312, -0.4485,  0.2014,
         0.5071, -0.1236], requires_grad=True)
Parameter containing:
tensor([0.3622, 0.8353, 0.6523, 0.5045, 0.7205, 0.0731, 0.2502, 0.2064, 0.7109,
        0.4635], requires_grad=True)
Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)

Weight clipped
Parameter containing:
tensor([[-0.0665,  0.1000,  0.1000],
        [-0.1000, -0.1000, -0.1000],
        [ 0.1000, -0.0655,  0.0282],
        [ 0.1000, -0.1000, -0.1000],
        [-0.1000,  0.1000, -0.0154],
        [-0.1000, -0.1000, -0.1000],
        [ 0.1000, -0.1000, -0.0341],
        [-0.1000, -0.1000,  0.1000],
        [-0.1000,  0.1000,  0.1000],
        [-0.1000, -0.1000,  0.1000]], requires_grad=True)
Parameter containing:
tensor([ 0.1000,  0.1000, -0.0534,  0.1000,  0.1000, -0.1000, -0.1000,  0.1000,
         0.1000, -0.1000], requires_grad=True)
Parameter containing:
tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.0731, 0.1000, 0.1000, 0.1000,
        0.1000], requires_grad=True)
Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)

確かに±0.1の範囲外の値は頭打ちされているのがわかります。これで想定された実装になりました。

Related Posts

PyTorch/TorchVisionで複数の入力をモデルに渡したいケース... PyTorch/TorchVisionで入力が複数あり、それぞれの入力に対して同じ前処理(transforms)をかけるケースを考えます。デフォルトのtransformsは複数対応していないのでうまくいきません。しかし、ラッパークラスを作り、それで前処理をラップするといい感じにできたのでその方法を...
Kerasのジェネレーターでサンプルが列挙される順番について... Kerasの(カスタム)ジェネレーターでサンプルがどの順番で呼び出されるか、1ループ終わったあとにどういう処理がなされるのか調べてみました。ジェネレーターを自分で定義するとモデルの表現の幅は広がるものの、バグが起きやすくなるので「本当に順番が保証されるのか」や「ハマりどころ」を確認します。 0~...
pipからインストールしたTorchVisionにImageNetがないときの対応... TorchVisionの公式ドキュメントにはImageNetが利用できるとの記述がありますが、pipからインストールするとImageNetのモジュール自体がないことがあります。TorchVisionにImageNetのモジュールを手動でインストールする方法を解説します。 発生状況 Pytho...
OpenCVで作成した動画がブラウザで正常に表示できない場合の解決法... OpenCVで作成した動画をサイトで表示する場合、ローカルで再生できていても、ブラウザ上では突然プレビューがでなり、ハマることがあります。原因の特定が難しい現象ですが、動画を作成する際にH.264形式でエンコードするとうまくいきました。その方法を解説します。 MPV4は手軽だが… OpenCV...
Chainerで画像の前処理やDataAugmentationをしたいときはDatasetMixin... Chainerにはデフォルトでランダムクロップや標準化といった、画像の前処理やDataAugmentation用の関数が用意されていません。別途のChainer CVというライブラリを使う方法もありますが、chainer.dataset.DatasetMixinを継承させて独自のデータ・セットを定...
Pocket
Delicious にシェア

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です