こしあん
2019-08-14

PyTorchでweight clipping

Pocket
LINEで送る


WGANの論文見てたらWeight Clippingしていたので、簡単な例を実装して実験してみました。かなり簡単にできます。それを見ていきましょう。

Weight Clippingとは

レイヤーの係数の値を一定範囲以内に収める手法。例えば、あるレイヤーが「-2, -1, 0, 1, 2」という5つの係数を持っていたとして、-1~1でWeight Clippingする場合、出力は「-1, -1, 0, 1, 1」になります。

WGANの論文では、このWeight ClippingがGANの損失関数のリプシッツ連続性を満たすために必要(ただし賢い方法ではない)、だと主張しています。

この記事はこちらのリポジトリをもとに実装しています。

Weight Clippingの実装

簡単なモデルを作ってみました。

import torch
from torch import nn

class SomeModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.weights = nn.Sequential(
            nn.Linear(3, 10),
            nn.BatchNorm1d(10),
            nn.Sigmoid()
        )

    def forward(self, inputs):
        return self.weights(inputs)

これでWeight Clipping前後で係数の値を比較します。

if __name__ == "__main__":
    model = SomeModel()
    print("Initial paramters")
    for p in model.parameters():
        print(p)
    # Weight clippng
    for p in model.parameters():
        p.data.clamp_(-0.1, 0.1)
    print("Weight clipped")
    for p in model.parameters():
        print(p)

「p.data.clamp_」でやっていることがWeight Clippingです。

結果

±0.1でクリップしてみました。

Initial paramters
Parameter containing:
tensor([[-0.0665,  0.3825,  0.4054],
        [-0.3007, -0.4494, -0.4472],
        [ 0.2556, -0.0655,  0.0282],
        [ 0.1952, -0.2858, -0.5615],
        [-0.5459,  0.5111, -0.0154],
        [-0.3570, -0.3204, -0.5725],
        [ 0.2857, -0.2683, -0.0341],
        [-0.4374, -0.3827,  0.1591],
        [-0.4890,  0.5398,  0.4199],
        [-0.2990, -0.4984,  0.5026]], requires_grad=True)
Parameter containing:
tensor([ 0.3708,  0.5396, -0.0534,  0.3967,  0.2119, -0.2312, -0.4485,  0.2014,
         0.5071, -0.1236], requires_grad=True)
Parameter containing:
tensor([0.3622, 0.8353, 0.6523, 0.5045, 0.7205, 0.0731, 0.2502, 0.2064, 0.7109,
        0.4635], requires_grad=True)
Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)

Weight clipped
Parameter containing:
tensor([[-0.0665,  0.1000,  0.1000],
        [-0.1000, -0.1000, -0.1000],
        [ 0.1000, -0.0655,  0.0282],
        [ 0.1000, -0.1000, -0.1000],
        [-0.1000,  0.1000, -0.0154],
        [-0.1000, -0.1000, -0.1000],
        [ 0.1000, -0.1000, -0.0341],
        [-0.1000, -0.1000,  0.1000],
        [-0.1000,  0.1000,  0.1000],
        [-0.1000, -0.1000,  0.1000]], requires_grad=True)
Parameter containing:
tensor([ 0.1000,  0.1000, -0.0534,  0.1000,  0.1000, -0.1000, -0.1000,  0.1000,
         0.1000, -0.1000], requires_grad=True)
Parameter containing:
tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.0731, 0.1000, 0.1000, 0.1000,
        0.1000], requires_grad=True)
Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)

確かに±0.1の範囲外の値は頭打ちされているのがわかります。これで想定された実装になりました。

Related Posts

Numpyの配列をN個飛ばしで列挙する簡単な方法... Numpyの配列から奇数番目、偶数番目の要素を取り出したいときが稀によくあります。インデックスの配列を定義する必要があるのかなと思いますが、とても簡単な方法があります。それを見ていきましょう。 基本は「::スキップしたい間隔」 例として、0~9までの配列をとります。 >>>...
argparseに直接dictを読み込ませる怪しいやり方... argparseにコマンドライン引数ではなく、ファイルから読み込んだdictをオーバーラップさせる方法を試してみました。本来のargparseの使い方ではない怪しいやり方ですが、JSONやyamlファイルとの連携が可能なので便利ではないかなと思います。 注意 これは本来のargparseの使い...
PyTorchで複数のGPUで訓練するときのSync Batch Normalizationの必要性... PyTorchにはSync Batch Normalizationというレイヤーがありますが、これが通常のBatch Normzalitionと何が違うのか具体例を通じて見ていきます。また、通常のBatch Normは複数GPUでData Parallelするときにデメリットがあるのでそれも確認し...
TensorFlow/Kerasでネットワーク内でData Augmentationする方法... NumpyでData Augmentationするのが遅かったり、書くの面倒だったりすることありますよね。今回はNumpy(CPU)ではなく、ニューラルネットワーク側(GPU、TPU)でAugmetationをする方法を見ていきます。 こんなイメージ Numpy(CPU)でやる場合 Num...
OpenCVで画像を歪ませる方法 PythonでOpenCVを使い画像を歪ませる方法を考えます。アフィン変換というちょっと直感的に理解しにくいことをしますが、慣れればそこまで難しくはありません。ディープラーニングのData Augmentationにも使えます。 OpenCVでのアフィン変換のイメージ アフィン変換というと、ま...
Pocket
Delicious にシェア

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です