こしあん
2019-08-14

PyTorchでweight clipping


Pocket
LINEで送る
Delicious にシェア

3.4k{icon} {views}


WGANの論文見てたらWeight Clippingしていたので、簡単な例を実装して実験してみました。かなり簡単にできます。それを見ていきましょう。

Weight Clippingとは

レイヤーの係数の値を一定範囲以内に収める手法。例えば、あるレイヤーが「-2, -1, 0, 1, 2」という5つの係数を持っていたとして、-1~1でWeight Clippingする場合、出力は「-1, -1, 0, 1, 1」になります。

WGANの論文では、このWeight ClippingがGANの損失関数のリプシッツ連続性を満たすために必要(ただし賢い方法ではない)、だと主張しています。

この記事はこちらのリポジトリをもとに実装しています。

Weight Clippingの実装

簡単なモデルを作ってみました。

import torch
from torch import nn

class SomeModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.weights = nn.Sequential(
            nn.Linear(3, 10),
            nn.BatchNorm1d(10),
            nn.Sigmoid()
        )

    def forward(self, inputs):
        return self.weights(inputs)

これでWeight Clipping前後で係数の値を比較します。

if __name__ == "__main__":
    model = SomeModel()
    print("Initial paramters")
    for p in model.parameters():
        print(p)
    # Weight clippng
    for p in model.parameters():
        p.data.clamp_(-0.1, 0.1)
    print("Weight clipped")
    for p in model.parameters():
        print(p)

「p.data.clamp_」でやっていることがWeight Clippingです。

結果

±0.1でクリップしてみました。

Initial paramters
Parameter containing:
tensor([[-0.0665,  0.3825,  0.4054],
        [-0.3007, -0.4494, -0.4472],
        [ 0.2556, -0.0655,  0.0282],
        [ 0.1952, -0.2858, -0.5615],
        [-0.5459,  0.5111, -0.0154],
        [-0.3570, -0.3204, -0.5725],
        [ 0.2857, -0.2683, -0.0341],
        [-0.4374, -0.3827,  0.1591],
        [-0.4890,  0.5398,  0.4199],
        [-0.2990, -0.4984,  0.5026]], requires_grad=True)
Parameter containing:
tensor([ 0.3708,  0.5396, -0.0534,  0.3967,  0.2119, -0.2312, -0.4485,  0.2014,
         0.5071, -0.1236], requires_grad=True)
Parameter containing:
tensor([0.3622, 0.8353, 0.6523, 0.5045, 0.7205, 0.0731, 0.2502, 0.2064, 0.7109,
        0.4635], requires_grad=True)
Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)

Weight clipped
Parameter containing:
tensor([[-0.0665,  0.1000,  0.1000],
        [-0.1000, -0.1000, -0.1000],
        [ 0.1000, -0.0655,  0.0282],
        [ 0.1000, -0.1000, -0.1000],
        [-0.1000,  0.1000, -0.0154],
        [-0.1000, -0.1000, -0.1000],
        [ 0.1000, -0.1000, -0.0341],
        [-0.1000, -0.1000,  0.1000],
        [-0.1000,  0.1000,  0.1000],
        [-0.1000, -0.1000,  0.1000]], requires_grad=True)
Parameter containing:
tensor([ 0.1000,  0.1000, -0.0534,  0.1000,  0.1000, -0.1000, -0.1000,  0.1000,
         0.1000, -0.1000], requires_grad=True)
Parameter containing:
tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.0731, 0.1000, 0.1000, 0.1000,
        0.1000], requires_grad=True)
Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)

確かに±0.1の範囲外の値は頭打ちされているのがわかります。これで想定された実装になりました。

Pocket
LINEで送る
Delicious にシェア



Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内

技術書コーナー

北海道の駅巡りコーナー


Add a Comment

メールアドレスが公開されることはありません。 が付いている欄は必須項目です