PyTorchでweight clipping
Posted On 2019-08-14
WGANの論文見てたらWeight Clippingしていたので、簡単な例を実装して実験してみました。かなり簡単にできます。それを見ていきましょう。
目次
Weight Clippingとは
レイヤーの係数の値を一定範囲以内に収める手法。例えば、あるレイヤーが「-2, -1, 0, 1, 2」という5つの係数を持っていたとして、-1~1でWeight Clippingする場合、出力は「-1, -1, 0, 1, 1」になります。
WGANの論文では、このWeight ClippingがGANの損失関数のリプシッツ連続性を満たすために必要(ただし賢い方法ではない)、だと主張しています。
この記事はこちらのリポジトリをもとに実装しています。
Weight Clippingの実装
簡単なモデルを作ってみました。
import torch
from torch import nn
class SomeModel(nn.Module):
def __init__(self):
super().__init__()
self.weights = nn.Sequential(
nn.Linear(3, 10),
nn.BatchNorm1d(10),
nn.Sigmoid()
)
def forward(self, inputs):
return self.weights(inputs)
これでWeight Clipping前後で係数の値を比較します。
if __name__ == "__main__":
model = SomeModel()
print("Initial paramters")
for p in model.parameters():
print(p)
# Weight clippng
for p in model.parameters():
p.data.clamp_(-0.1, 0.1)
print("Weight clipped")
for p in model.parameters():
print(p)
「p.data.clamp_」でやっていることがWeight Clippingです。
結果
±0.1でクリップしてみました。
Initial paramters
Parameter containing:
tensor([[-0.0665, 0.3825, 0.4054],
[-0.3007, -0.4494, -0.4472],
[ 0.2556, -0.0655, 0.0282],
[ 0.1952, -0.2858, -0.5615],
[-0.5459, 0.5111, -0.0154],
[-0.3570, -0.3204, -0.5725],
[ 0.2857, -0.2683, -0.0341],
[-0.4374, -0.3827, 0.1591],
[-0.4890, 0.5398, 0.4199],
[-0.2990, -0.4984, 0.5026]], requires_grad=True)
Parameter containing:
tensor([ 0.3708, 0.5396, -0.0534, 0.3967, 0.2119, -0.2312, -0.4485, 0.2014,
0.5071, -0.1236], requires_grad=True)
Parameter containing:
tensor([0.3622, 0.8353, 0.6523, 0.5045, 0.7205, 0.0731, 0.2502, 0.2064, 0.7109,
0.4635], requires_grad=True)
Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)
Weight clipped
Parameter containing:
tensor([[-0.0665, 0.1000, 0.1000],
[-0.1000, -0.1000, -0.1000],
[ 0.1000, -0.0655, 0.0282],
[ 0.1000, -0.1000, -0.1000],
[-0.1000, 0.1000, -0.0154],
[-0.1000, -0.1000, -0.1000],
[ 0.1000, -0.1000, -0.0341],
[-0.1000, -0.1000, 0.1000],
[-0.1000, 0.1000, 0.1000],
[-0.1000, -0.1000, 0.1000]], requires_grad=True)
Parameter containing:
tensor([ 0.1000, 0.1000, -0.0534, 0.1000, 0.1000, -0.1000, -0.1000, 0.1000,
0.1000, -0.1000], requires_grad=True)
Parameter containing:
tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.0731, 0.1000, 0.1000, 0.1000,
0.1000], requires_grad=True)
Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)
確かに±0.1の範囲外の値は頭打ちされているのがわかります。これで想定された実装になりました。
Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内
技術書コーナー
北海道の駅巡りコーナー