こしあん
2019-06-27

pix2pixを1から実装して白黒画像をカラー化してみた(PyTorch)


20.4k{icon} {views}


pix2pixによる白黒画像のカラー化を1から実装します。PyTorchで行います。かなり自然な色付けができました。pix2pixはGANの中でも理論が単純なのにくわえ、学習も比較的安定しているので結構おすすめです。

はじめに

PyTorchでDCGANができたので、今回はpix2pixをやります。今回は白黒画像のカラー化というよくありがちな例をやってみます。

あとで理論的な解説をしますが、やっていることは上図のとおりです。pix2pixはGANの一種です。Generatorに白黒画像を入れ偽のカラー画像を作り、Discriminatorに本物をカラー画像を入れ、min-maxゲームで互いを訓練していきます。

GANというとモード崩壊や勾配消失が怖いというイメージがありますが、pix2pixは訓練がかなり安定しているので応用が十分期待できるでしょう。

もとの論文:
P. Isola, J. Zhu, T. Zhou, A. A. Efros. Image-to-Image Translation with Conditional Adversarial Networks. CVPR, 2017.
https://arxiv.org/abs/1611.07004

Conditional-GAN

論文読んでいると、pix2pixはConditional GAN(cGAN)の一種として書かれています。DCGANは完全にノイズから画像を生成するのに対し、conditional GANは例えばMNISTでは、どの数字かというラベルの情報+ノイズで生成します。「pix2pixの場合はノイズもラベルもなかったはずだよな。なんでConditional GANなんだろう?」と思いましたが、白黒画像が「条件」となっているわけなんですね。DCGANのノイズも条件なんですが、ノイズが条件の条件付き確率ってノイズ項が消せちゃいますので。

pix2pixの損失関数

pix2pixはGANの一種なので、DCGANと損失関数が似ています。DCGANと異なる点は、Gにピクセル単位のL1損失を入れているということです。

こちらは普通のGANの損失関数です。

こちらがpix2pixの損失関数です。

Generatorの部分だけ変わっているので、Dは共通です。

λはL1損失と交差エントロピーの比率を決めるハイパーパラメータで、論文はλ=100で実験しています。高めのλを使うと元画像に近くなるので、これは直感的にはわかりやすいです。

実装上はGの損失関数は

  • DCGAN : BCEWithLogits(d_out_fake, ones)
  • pix2pix : BCEWithLogits(d_out_fake, ones) + lambda * L1Loss(fake, real)

のような形になります。ただ足すだけですね。

PatchGAN

これがPix2pixの面白いところですが、Dの判定が画像全体を本物・偽物で識別するのではなく、画像全体をパッチに分割してパッチ単位で本物・偽物を識別します。論文ではこれをPatchGANと読んでいます。

実装上は全く難しくなくて、YOLOを思い出せばいいのです。ある解像度の入力に対して、CNNの最後の層の解像度が(batch, 512, 3, 3)なら(NCHWフォーマット)とします。YOLOではCNNの位置不変性により、3×3の分割エリア/Sliding Windowに対して物があるかどうかを判定しているにすぎません。物があるかどうかを、「各Sliding Windowに対して本物かどうか」を判定すればPatch GANになります。

こういうイメージです。YOLOと一緒です。

実装上は、DCGANで本物/偽物と学習させたい場合のy_trueのテンソルが、

    ones = torch.ones(512)
    zeros = torch.zeros(512)
Python

なら、PatchGANの場合は、

    ones = torch.ones(512, 1, 3, 3)
    zeros = torch.zeros(512, 1, 3, 3)
Python

とすればいいだけです。これは最終層が3×3解像度で、バッチサイズが512のケースです。

色空間変換

この問題では、白黒画像の自動着色を行っているので、一般的なRGB色空間よりも、明るさと色差を分離したYCrCb色空間を経由したほうがより直接的な訓練ができるようになります。PyTorchでの色空間変換はこちらを参照してください。

PyTorchで行列(テンソル)積としてConv2dを使う
https://blog.shikoan.com/conv2d-matmul/

また、GやDで「YCrCb色空間」といったときは特に断りがなければYもCrCbも-1~1スケールです。本来のYCrCb色空間ではYは0~1、CrCbは-0.5~0.5のスケールです。詳細は末尾のコードを見てください。

G:YCrCb色空間, D:RGB色空間(成功例)

一番うまく行った(と思われる)例です。

序盤からかなり早く着色が進みます。これは色空間変換の効果です(特に、NonGANのただのU-Netだと顕著な傾向があります)。

画像データはもともRGB色空間ですが、Gに食わせる前にYCrCbへの色空間の変換を行います。そして、Gの入力をYCrCbのY・出力をCrCbとします。その後、Gの入力のYと出力のCrCbを足し合わせた画像を偽画像として、RGB色空間に戻します。Dの本物/偽物の判別はRGB色空間で行います。

DとGの損失推移です。Dの損失がゆるやかに減少し、Gの損失がゆるやかに増加しているのがわかります。Gの内訳を見ると、L1損失はほぼ変わらないのに対し、Cross Entropyだけが増えているのがわかります。これは通常のDCGANでも見られる現象です。

G:YCrCb色空間, D:YCrCb色空間(着色がおかしい?)

次にDの色空間をRGB色空間ではなく、YCrCb色空間とします。

エラー推移は先程と変わりませんが、着色がどうも違和感があります。CrCb(赤と青の色差)で見ているので、赤と青のモヤっとした部分が目立つようになります。Dの色空間はRGBのほうがよさそうです。

G:RGB色空間, D:RGB色空間+グレースケール画像(Dが強くなりすぎてうまくいない)

これはGもDもRGB色空間で行う例です。Gの入力はRGB色空間から生成したグレースケール画像、出力はRGB色空間のカラー画像としています。Dの入力はRGB色空間のカラー画像です。

ただし、Pix2pixの論文で読むと「Gの入力をDの入力に再度入れなさい」と書いているので、Gで使ったグレースケールの画像をDの入力に入れ、Dの入力は4チャンネルとなっています。ここは議論の余地があると思います。なぜなら、

  1. もとのRGB色空間から生成したグレースケールの画像は、RGB画像についての関数だから、DにRGB画像を入れた時点で既にグレースケールの成分は入っている
  2. Pix2pixの論文の定義にしたがってDにもRGB画像とグレースケールの両方を入れるべきだ

という2つの考え方があるからです。このケースでは(2)を選択しました。(1)も後ほど試してみます。

ただし、このケースでは訓練が安定しませんでした。Dのロスが低くなりすぎる(強くなりすぎる)と、以降Dのロスが振動して暴れだします。途中まで上手く行っていたが、どれも似たようなセピア調の着色になってしまう、着色が失われてしまいます。

G:RGB色空間, D:RGB色空間(まあまあうまくいく)

順番は前後してしまいましたが、これが最初に試したケースです。Dにグレースケールの画像を入れません(チャンネル数は3です)。先程の(1)の考えた方です。

結果はまあまあうまくいきました。GをYCrCbにしたケースとどっちが着色がいいかは好みがあるでしょう。

このケースでは、Gの入力であるグレースケールの画像を、Dに入れないほうがうまくいきました。

Gの色空間にともなう損失関数の変更

補足事項ですが、Gの色空間を変えた場合はGの損失関数を以下のように変えています。

  • Gの色空間がRGB : 交差エントロピー + 100 × RGB色空間でのL1損失
  • Gの色空間がYCrCb : 交差エントロピー + 75 × CrCb(YCrCb色空間からYを除外)でのL1損失 + 25 × RGB色空間でのL1損失

YCrCb色空間でRGBのL1損失を入れている理由は、Non-GANの場合に、YCrCbだけの損失関数だと、白や黒に近いような場所で極端な着色になってしまうからです。それを打ち消すためにRGB色空間でのL1損失を入れています。

まとめ

pix2pixを実装できました。結構手軽にできて学習も安定してるし、結果もそれっぽいのでコスパはかなり高そうです。

細かい話だと、白黒画像のカラー化の場合はDもGもRGB色空間か、GがYCrCbでDがRGB色空間がよさそうです。

コード

全てGistに上がってます。訓練時間は2080Tiが2枚で5時間程度でした。エポック間に謎の間が20~30秒あってここが短縮できればもっと高速になるはずです。

  • rgb_rgb.py : GもDも色空間がRGB。Dがグレースケールありの例。なしの場合はtorch.cat等を外せばいいだけなので、やっていることが理解できていればいじれるはずです。
  • ycrcb_rgb.py : GがYCrCbで、DがRGBのケースです。成功例として紹介したのがこちら。
  • ycrcb_ycrcb.py : GもDもYCrCbのケースです

import torch
from torch import nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
from tqdm import tqdm
import os
import pickle
import statistics
class ColorAndGray(object):
def __call__(self, img):
# ToTensor()の前に呼ぶ場合はimgはPILのインスタンス
gray = img.convert("L")
return img, gray
# 複数の入力をtransformsに展開するラッパークラスを作る
class MultiInputWrapper(object):
def __init__(self, base_func):
self.base_func = base_func
def __call__(self, xs):
if isinstance(self.base_func, list):
return [f(x) for f, x in zip(self.base_func, xs)]
else:
return [self.base_func(x) for x in xs]
def load_datasets():
transform = transforms.Compose([
ColorAndGray(),
MultiInputWrapper(transforms.ToTensor()),
MultiInputWrapper([
transforms.Normalize(mean=(0.5,0.5,0.5,), std=(0.5,0.5,0.5,)),
transforms.Normalize(mean=(0.5,), std=(0.5,))
])
])
trainset = torchvision.datasets.STL10(root="./data",
split="unlabeled",
download=True,
transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=512,
shuffle=True, num_workers=4, pin_memory=True)
return train_loader
class Generator(nn.Module):
def __init__(self):
super().__init__()
self.enc1 = self.conv_bn_relu(1, 32, kernel_size=5) # 32x96x96
self.enc2 = self.conv_bn_relu(32, 64, kernel_size=3, pool_kernel=4) # 64x24x24
self.enc3 = self.conv_bn_relu(64, 128, kernel_size=3, pool_kernel=2) # 128x12x12
self.enc4 = self.conv_bn_relu(128, 256, kernel_size=3, pool_kernel=2) # 256x6x6
self.dec1 = self.conv_bn_relu(256, 128, kernel_size=3, pool_kernel=-2) # 128x12x12
self.dec2 = self.conv_bn_relu(128 + 128, 64, kernel_size=3, pool_kernel=-2) # 64x24x24
self.dec3 = self.conv_bn_relu(64 + 64, 32, kernel_size=3, pool_kernel=-4) # 32x96x96
self.dec4 = nn.Sequential(
nn.Conv2d(32 + 32, 3, kernel_size=5, padding=2),
nn.Tanh()
)
def conv_bn_relu(self, in_ch, out_ch, kernel_size=3, pool_kernel=None):
layers = []
if pool_kernel is not None:
if pool_kernel > 0:
layers.append(nn.AvgPool2d(pool_kernel))
elif pool_kernel < 0:
layers.append(nn.UpsamplingNearest2d(scale_factor=-pool_kernel))
layers.append(nn.Conv2d(in_ch, out_ch, kernel_size, padding=(kernel_size - 1) // 2))
layers.append(nn.BatchNorm2d(out_ch))
layers.append(nn.ReLU(inplace=True))
return nn.Sequential(*layers)
def forward(self, x):
x1 = self.enc1(x)
x2 = self.enc2(x1)
x3 = self.enc3(x2)
x4 = self.enc4(x3)
out = self.dec1(x4)
out = self.dec2(torch.cat([out, x3], dim=1))
out = self.dec3(torch.cat([out, x2], dim=1))
out = self.dec4(torch.cat([out, x1], dim=1))
return out
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = self.conv_bn_relu(4, 16, kernel_size=5, reps=1) # fake/true color + gray
self.conv2 = self.conv_bn_relu(16, 32, pool_kernel=4)
self.conv3 = self.conv_bn_relu(32, 64, pool_kernel=2)
self.conv4 = self.conv_bn_relu(64, 128, pool_kernel=2)
self.conv5 = self.conv_bn_relu(128, 256, pool_kernel=2)
self.out_patch = nn.Conv2d(256, 1, kernel_size=1) #1x3x3
def conv_bn_relu(self, in_ch, out_ch, kernel_size=3, pool_kernel=None, reps=2):
layers = []
for i in range(reps):
if i == 0 and pool_kernel is not None:
layers.append(nn.AvgPool2d(pool_kernel))
layers.append(nn.Conv2d(in_ch if i == 0 else out_ch,
out_ch, kernel_size, padding=(kernel_size - 1) // 2))
layers.append(nn.BatchNorm2d(out_ch))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return nn.Sequential(*layers)
def forward(self, x):
out = self.conv5(self.conv4(self.conv3(self.conv2(self.conv1(x)))))
return self.out_patch(out)
def train():
# モデル
device = "cuda"
torch.backends.cudnn.benchmark = True
model_G, model_D = Generator(), Discriminator()
model_G, model_D = nn.DataParallel(model_G), nn.DataParallel(model_D)
model_G, model_D = model_G.to(device), model_D.to(device)
params_G = torch.optim.Adam(model_G.parameters(),
lr=0.0002, betas=(0.5, 0.999))
params_D = torch.optim.Adam(model_D.parameters(),
lr=0.0002, betas=(0.5, 0.999))
# ロスを計算するためのラベル変数 (PatchGAN)
ones = torch.ones(512, 1, 3, 3).to(device)
zeros = torch.zeros(512, 1, 3, 3).to(device)
# 損失関数
bce_loss = nn.BCEWithLogitsLoss()
mae_loss = nn.L1Loss()
# エラー推移
result = {}
result["log_loss_G_sum"] = []
result["log_loss_G_bce"] = []
result["log_loss_G_mae"] = []
result["log_loss_D"] = []
# 訓練
dataset = load_datasets()
for i in range(200):
log_loss_G_sum, log_loss_G_bce, log_loss_G_mae, log_loss_D = [], [], [], []
for (real_color, input_gray), _ in tqdm(dataset):
batch_len = len(real_color)
real_color, input_gray = real_color.to(device), input_gray.to(device)
# Gの訓練
# 偽のカラー画像を作成
fake_color = model_G(input_gray)
# 偽画像を一時保存
fake_color_tensor = fake_color.detach()
# 偽画像を本物と騙せるようにロスを計算
LAMBD = 100.0 # BCEとMAEの係数
out = model_D(torch.cat([fake_color, input_gray], dim=1))
loss_G_bce = bce_loss(out, ones[:batch_len])
loss_G_mae = LAMBD * mae_loss(fake_color, real_color)
loss_G_sum = loss_G_bce + loss_G_mae
log_loss_G_bce.append(loss_G_bce.item())
log_loss_G_mae.append(loss_G_mae.item())
log_loss_G_sum.append(loss_G_sum.item())
# 微分計算・重み更新
params_D.zero_grad()
params_G.zero_grad()
loss_G_sum.backward()
params_G.step()
# Discriminatoの訓練
# 本物のカラー画像を本物と識別できるようにロスを計算
real_out = model_D(torch.cat([real_color, input_gray], dim=1))
loss_D_real = bce_loss(real_out, ones[:batch_len])
# 偽の画像の偽と識別できるようにロスを計算
fake_out = model_D(torch.cat([fake_color_tensor, input_gray], dim=1))
loss_D_fake = bce_loss(fake_out, zeros[:batch_len])
# 実画像と偽画像のロスを合計
loss_D = loss_D_real + loss_D_fake
log_loss_D.append(loss_D.item())
# 微分計算・重み更新
params_D.zero_grad()
params_G.zero_grad()
loss_D.backward()
params_D.step()
result["log_loss_G_sum"].append(statistics.mean(log_loss_G_sum))
result["log_loss_G_bce"].append(statistics.mean(log_loss_G_bce))
result["log_loss_G_mae"].append(statistics.mean(log_loss_G_mae))
result["log_loss_D"].append(statistics.mean(log_loss_D))
print(f"log_loss_G_sum = {result['log_loss_G_sum'][-1]} " +
f"({result['log_loss_G_bce'][-1]}, {result['log_loss_G_mae'][-1]}) " +
f"log_loss_D = {result['log_loss_D'][-1]}")
# 画像を保存
if not os.path.exists("stl_color"):
os.mkdir("stl_color")
# 生成画像を保存
torchvision.utils.save_image(fake_color_tensor[:min(batch_len, 100)],
f"stl_color/fake_epoch_{i:03}.png",
range=(-1.0,1.0), normalize=True)
torchvision.utils.save_image(real_color[:min(batch_len, 100)],
f"stl_color/real_epoch_{i:03}.png",
range=(-1.0, 1.0), normalize=True)
# モデルの保存
if not os.path.exists("stl_color/models"):
os.mkdir("stl_color/models")
if i % 10 == 0 or i == 199:
torch.save(model_G.state_dict(), f"stl_color/models/gen_{i:03}.pytorch")
torch.save(model_D.state_dict(), f"stl_color/models/dis_{i:03}.pytorch")
# ログの保存
with open("stl_color/logs.pkl", "wb") as fp:
pickle.dump(result, fp)
if __name__ == "__main__":
train()
import torch
from torch import nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
from tqdm import tqdm
import os
import pickle
import statistics
def load_datasets():
transform = transforms.Compose([
transforms.ToTensor(),
])
trainset = torchvision.datasets.STL10(root="./data",
split="unlabeled",
download=True,
transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=512,
shuffle=True, num_workers=4, pin_memory=True)
return train_loader
# 色空間変換用の定数
RGB2YCrCb = np.array([[0.299, 0.587, 0.114],
[0.5, -0.418688, -0.081312],
[-0.168736, -0.331264, 0.5]], np.float32)
YCrCb2RGB = np.array([[1, 1.402, 0],
[1, -0.714136, -0.344136],
[1, 0, 1.772]], np.float32)
RGB2YCrCb = torch.as_tensor(RGB2YCrCb.reshape(3, 3, 1, 1)).to("cuda")
YCrCb2RGB = torch.as_tensor(YCrCb2RGB.reshape(3, 3, 1, 1)).to("cuda")
def preprocess_generator(rgb_tensor):
x = nn.functional.conv2d(rgb_tensor, RGB2YCrCb) # Yが0 - 1, CbCrが-0.5 - 0.5
x *= 2.0 # CbCrは-1 - 1になったのでOK、Yが0-2
x[:, 0,:,:] -= 1.0 # Yを-1 - 1にする
return x
def deprocess_generator(ycrcb_tensor):
# inputが全て-1 - 1のスケールなので本来のスケールに直す
x = ycrcb_tensor / 2.0 # 全て-0.5-0.5, CbCrはOK
x[:, 0,:,:] += 0.5 # Yのスケールを0-1にする
# RGBに変換 (0-1)
return nn.functional.conv2d(x, YCrCb2RGB)
class Generator(nn.Module):
# input : 1x96x96 の Y [-1, 1] 本当は[0, 1]
# output : 2x96x96 の CrCb [-1, 1] 本当は[-0.5, 0.5]
def __init__(self):
super().__init__()
self.enc1 = self.conv_bn_relu(1, 32, kernel_size=5) # 32x96x96
self.enc2 = self.conv_bn_relu(32, 64, kernel_size=3, pool_kernel=4) # 64x24x24
self.enc3 = self.conv_bn_relu(64, 128, kernel_size=3, pool_kernel=2) # 128x12x12
self.enc4 = self.conv_bn_relu(128, 256, kernel_size=3, pool_kernel=2) # 256x6x6
self.dec1 = self.conv_bn_relu(256, 128, kernel_size=3, pool_kernel=-2) # 128x12x12
self.dec2 = self.conv_bn_relu(128 + 128, 64, kernel_size=3, pool_kernel=-2) # 64x24x24
self.dec3 = self.conv_bn_relu(64 + 64, 32, kernel_size=3, pool_kernel=-4) # 32x96x96
self.dec4 = nn.Sequential(
nn.Conv2d(32 + 32, 2, kernel_size=5, padding=2),
nn.Tanh()
)
def conv_bn_relu(self, in_ch, out_ch, kernel_size=3, pool_kernel=None):
layers = []
if pool_kernel is not None:
if pool_kernel > 0:
layers.append(nn.AvgPool2d(pool_kernel))
elif pool_kernel < 0:
layers.append(nn.UpsamplingNearest2d(scale_factor=-pool_kernel))
layers.append(nn.Conv2d(in_ch, out_ch, kernel_size, padding=(kernel_size - 1) // 2))
layers.append(nn.BatchNorm2d(out_ch))
layers.append(nn.ReLU(inplace=True))
return nn.Sequential(*layers)
def forward(self, x):
x1 = self.enc1(x)
x2 = self.enc2(x1)
x3 = self.enc3(x2)
x4 = self.enc4(x3)
out = self.dec1(x4)
out = self.dec2(torch.cat([out, x3], dim=1))
out = self.dec3(torch.cat([out, x2], dim=1))
out = self.dec4(torch.cat([out, x1], dim=1))
return out
class Discriminator(nn.Module):
# Inputの色空間はYCrCb→RGBにする
def __init__(self):
super().__init__()
self.conv1 = self.conv_bn_relu(3, 16, kernel_size=5, reps=1) # RGB
self.conv2 = self.conv_bn_relu(16, 32, pool_kernel=4)
self.conv3 = self.conv_bn_relu(32, 64, pool_kernel=2)
self.conv4 = self.conv_bn_relu(64, 128, pool_kernel=2)
self.conv5 = self.conv_bn_relu(128, 256, pool_kernel=2)
self.out_patch = nn.Conv2d(256, 1, kernel_size=1) #1x3x3
def conv_bn_relu(self, in_ch, out_ch, kernel_size=3, pool_kernel=None, reps=2):
layers = []
for i in range(reps):
if i == 0 and pool_kernel is not None:
layers.append(nn.AvgPool2d(pool_kernel))
layers.append(nn.Conv2d(in_ch if i == 0 else out_ch,
out_ch, kernel_size, padding=(kernel_size - 1) // 2))
layers.append(nn.BatchNorm2d(out_ch))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return nn.Sequential(*layers)
def forward(self, x):
out = self.conv5(self.conv4(self.conv3(self.conv2(self.conv1(x)))))
return self.out_patch(out)
def train():
# モデル
device = "cuda"
torch.backends.cudnn.benchmark = True
model_G, model_D = Generator(), Discriminator()
model_G, model_D = nn.DataParallel(model_G), nn.DataParallel(model_D)
model_G, model_D = model_G.to(device), model_D.to(device)
params_G = torch.optim.Adam(model_G.parameters(),
lr=0.0002, betas=(0.5, 0.999))
params_D = torch.optim.Adam(model_D.parameters(),
lr=0.0002, betas=(0.5, 0.999))
# ロスを計算するためのラベル変数 (PatchGAN)
ones = torch.ones(512, 1, 3, 3).to(device)
zeros = torch.zeros(512, 1, 3, 3).to(device)
# 損失関数
bce_loss = nn.BCEWithLogitsLoss()
mae_loss = nn.L1Loss()
# エラー推移
result = {}
result["log_loss_G_sum"] = []
result["log_loss_G_bce"] = []
result["log_loss_G_mae"] = []
result["log_loss_D"] = []
# 訓練
dataset = load_datasets()
for i in range(200):
log_loss_G_sum, log_loss_G_bce, log_loss_G_mae, log_loss_D = [], [], [], []
for real_rgb, _ in tqdm(dataset):
batch_len = len(real_rgb)
real_rgb = real_rgb.to(device)
real_ycrcb = preprocess_generator(real_rgb)
# Gの訓練
# 偽のカラー画像を作成
fake_crcb = model_G(real_ycrcb[:,:1,:,:])
fake_ycrcb = torch.cat([real_ycrcb[:,:1,:,:], fake_crcb], dim=1)
fake_rgb = deprocess_generator(fake_ycrcb)
# 偽画像を一時保存
fake_rgb_tensor = fake_rgb.detach()
# 偽画像を本物と騙せるようにロスを計算
out = model_D(fake_rgb)
loss_G_bce = bce_loss(out, ones[:batch_len])
loss_G_mae = 75 * mae_loss(fake_crcb, real_ycrcb[:, 1:,:,:]) + 25 * mae_loss(fake_rgb, real_rgb)
loss_G_sum = loss_G_bce + loss_G_mae
log_loss_G_bce.append(loss_G_bce.item())
log_loss_G_mae.append(loss_G_mae.item())
log_loss_G_sum.append(loss_G_sum.item())
# 微分計算・重み更新
params_D.zero_grad()
params_G.zero_grad()
loss_G_sum.backward()
params_G.step()
# Discriminatoの訓練
# 本物のカラー画像を本物と識別できるようにロスを計算
real_out = model_D(real_rgb)
loss_D_real = bce_loss(real_out, ones[:batch_len])
# 偽の画像の偽と識別できるようにロスを計算
fake_out = model_D(fake_rgb_tensor)
loss_D_fake = bce_loss(fake_out, zeros[:batch_len])
# 実画像と偽画像のロスを合計
loss_D = loss_D_real + loss_D_fake
log_loss_D.append(loss_D.item())
# 微分計算・重み更新
params_D.zero_grad()
params_G.zero_grad()
loss_D.backward()
params_D.step()
result["log_loss_G_sum"].append(statistics.mean(log_loss_G_sum))
result["log_loss_G_bce"].append(statistics.mean(log_loss_G_bce))
result["log_loss_G_mae"].append(statistics.mean(log_loss_G_mae))
result["log_loss_D"].append(statistics.mean(log_loss_D))
print(f"log_loss_G_sum = {result['log_loss_G_sum'][-1]} " +
f"({result['log_loss_G_bce'][-1]}, {result['log_loss_G_mae'][-1]}) " +
f"log_loss_D = {result['log_loss_D'][-1]}")
# 画像を保存
if not os.path.exists("stl_color"):
os.mkdir("stl_color")
# 生成画像を保存
torchvision.utils.save_image(fake_rgb_tensor[:min(batch_len, 100)],
f"stl_color/fake_epoch_{i:03}.png")
torchvision.utils.save_image(real_rgb[:min(batch_len, 100)],
f"stl_color/real_epoch_{i:03}.png")
# モデルの保存
if not os.path.exists("stl_color/models"):
os.mkdir("stl_color/models")
if i % 10 == 0 or i == 199:
torch.save(model_G.state_dict(), f"stl_color/models/gen_{i:03}.pytorch")
torch.save(model_D.state_dict(), f"stl_color/models/dis_{i:03}.pytorch")
# ログの保存
with open("stl_color/logs.pkl", "wb") as fp:
pickle.dump(result, fp)
if __name__ == "__main__":
train()
import torch
from torch import nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
from tqdm import tqdm
import os
import pickle
import statistics
def load_datasets():
transform = transforms.Compose([
transforms.ToTensor(),
])
trainset = torchvision.datasets.STL10(root="./data",
split="unlabeled",
download=True,
transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=512,
shuffle=True, num_workers=4, pin_memory=True)
return train_loader
# 色空間変換用の定数
RGB2YCrCb = np.array([[0.299, 0.587, 0.114],
[0.5, -0.418688, -0.081312],
[-0.168736, -0.331264, 0.5]], np.float32)
YCrCb2RGB = np.array([[1, 1.402, 0],
[1, -0.714136, -0.344136],
[1, 0, 1.772]], np.float32)
RGB2YCrCb = torch.as_tensor(RGB2YCrCb.reshape(3, 3, 1, 1)).to("cuda")
YCrCb2RGB = torch.as_tensor(YCrCb2RGB.reshape(3, 3, 1, 1)).to("cuda")
def preprocess_generator(rgb_tensor):
x = nn.functional.conv2d(rgb_tensor, RGB2YCrCb) # Yが0 - 1, CbCrが-0.5 - 0.5
x *= 2.0 # CbCrは-1 - 1になったのでOK、Yが0-2
x[:, 0,:,:] -= 1.0 # Yを-1 - 1にする
return x
def deprocess_generator(ycrcb_tensor):
# inputが全て-1 - 1のスケールなので本来のスケールに直す
x = ycrcb_tensor / 2.0 # 全て-0.5-0.5, CbCrはOK
x[:, 0,:,:] += 0.5 # Yのスケールを0-1にする
# RGBに変換 (0-1)
return nn.functional.conv2d(x, YCrCb2RGB)
class Generator(nn.Module):
# input : 1x96x96 の Y [-1, 1] 本当は[0, 1]
# output : 2x96x96 の CrCb [-1, 1] 本当は[-0.5, 0.5]
def __init__(self):
super().__init__()
self.enc1 = self.conv_bn_relu(1, 32, kernel_size=5) # 32x96x96
self.enc2 = self.conv_bn_relu(32, 64, kernel_size=3, pool_kernel=4) # 64x24x24
self.enc3 = self.conv_bn_relu(64, 128, kernel_size=3, pool_kernel=2) # 128x12x12
self.enc4 = self.conv_bn_relu(128, 256, kernel_size=3, pool_kernel=2) # 256x6x6
self.dec1 = self.conv_bn_relu(256, 128, kernel_size=3, pool_kernel=-2) # 128x12x12
self.dec2 = self.conv_bn_relu(128 + 128, 64, kernel_size=3, pool_kernel=-2) # 64x24x24
self.dec3 = self.conv_bn_relu(64 + 64, 32, kernel_size=3, pool_kernel=-4) # 32x96x96
self.dec4 = nn.Sequential(
nn.Conv2d(32 + 32, 2, kernel_size=5, padding=2),
nn.Tanh()
)
def conv_bn_relu(self, in_ch, out_ch, kernel_size=3, pool_kernel=None):
layers = []
if pool_kernel is not None:
if pool_kernel > 0:
layers.append(nn.AvgPool2d(pool_kernel))
elif pool_kernel < 0:
layers.append(nn.UpsamplingNearest2d(scale_factor=-pool_kernel))
layers.append(nn.Conv2d(in_ch, out_ch, kernel_size, padding=(kernel_size - 1) // 2))
layers.append(nn.BatchNorm2d(out_ch))
layers.append(nn.ReLU(inplace=True))
return nn.Sequential(*layers)
def forward(self, x):
x1 = self.enc1(x)
x2 = self.enc2(x1)
x3 = self.enc3(x2)
x4 = self.enc4(x3)
out = self.dec1(x4)
out = self.dec2(torch.cat([out, x3], dim=1))
out = self.dec3(torch.cat([out, x2], dim=1))
out = self.dec4(torch.cat([out, x1], dim=1))
return out
class Discriminator(nn.Module):
# Inputの色空間はYCrCb
def __init__(self):
super().__init__()
self.conv1 = self.conv_bn_relu(3, 16, kernel_size=5, reps=1) # YCrCb
self.conv2 = self.conv_bn_relu(16, 32, pool_kernel=4)
self.conv3 = self.conv_bn_relu(32, 64, pool_kernel=2)
self.conv4 = self.conv_bn_relu(64, 128, pool_kernel=2)
self.conv5 = self.conv_bn_relu(128, 256, pool_kernel=2)
self.out_patch = nn.Conv2d(256, 1, kernel_size=1) #1x3x3
def conv_bn_relu(self, in_ch, out_ch, kernel_size=3, pool_kernel=None, reps=2):
layers = []
for i in range(reps):
if i == 0 and pool_kernel is not None:
layers.append(nn.AvgPool2d(pool_kernel))
layers.append(nn.Conv2d(in_ch if i == 0 else out_ch,
out_ch, kernel_size, padding=(kernel_size - 1) // 2))
layers.append(nn.BatchNorm2d(out_ch))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return nn.Sequential(*layers)
def forward(self, x):
out = self.conv5(self.conv4(self.conv3(self.conv2(self.conv1(x)))))
return self.out_patch(out)
def train():
# モデル
device = "cuda"
torch.backends.cudnn.benchmark = True
model_G, model_D = Generator(), Discriminator()
model_G, model_D = nn.DataParallel(model_G), nn.DataParallel(model_D)
model_G, model_D = model_G.to(device), model_D.to(device)
params_G = torch.optim.Adam(model_G.parameters(),
lr=0.0002, betas=(0.5, 0.999))
params_D = torch.optim.Adam(model_D.parameters(),
lr=0.0002, betas=(0.5, 0.999))
# ロスを計算するためのラベル変数 (PatchGAN)
ones = torch.ones(512, 1, 3, 3).to(device)
zeros = torch.zeros(512, 1, 3, 3).to(device)
# 損失関数
bce_loss = nn.BCEWithLogitsLoss()
mae_loss = nn.L1Loss()
# エラー推移
result = {}
result["log_loss_G_sum"] = []
result["log_loss_G_bce"] = []
result["log_loss_G_mae"] = []
result["log_loss_D"] = []
# 訓練
dataset = load_datasets()
for i in range(200):
log_loss_G_sum, log_loss_G_bce, log_loss_G_mae, log_loss_D = [], [], [], []
for real_rgb, _ in tqdm(dataset):
batch_len = len(real_rgb)
real_rgb = real_rgb.to(device)
real_ycrcb = preprocess_generator(real_rgb)
# Gの訓練
# 偽のカラー画像を作成
fake_crcb = model_G(real_ycrcb[:,:1,:,:])
fake_ycrcb = torch.cat([real_ycrcb[:,:1,:,:], fake_crcb], dim=1)
fake_rgb = deprocess_generator(fake_ycrcb)
# 偽画像を一時保存
fake_ycrcb_tensor = fake_ycrcb.detach()
# 偽画像を本物と騙せるようにロスを計算
out = model_D(fake_ycrcb)
loss_G_bce = bce_loss(out, ones[:batch_len])
loss_G_mae = 75 * mae_loss(fake_crcb, real_ycrcb[:, 1:,:,:]) + 25 * mae_loss(fake_rgb, real_rgb)
loss_G_sum = loss_G_bce + loss_G_mae
log_loss_G_bce.append(loss_G_bce.item())
log_loss_G_mae.append(loss_G_mae.item())
log_loss_G_sum.append(loss_G_sum.item())
# 微分計算・重み更新
params_D.zero_grad()
params_G.zero_grad()
loss_G_sum.backward()
params_G.step()
# Discriminatoの訓練
# 本物のカラー画像を本物と識別できるようにロスを計算
real_out = model_D(real_ycrcb)
loss_D_real = bce_loss(real_out, ones[:batch_len])
# 偽の画像の偽と識別できるようにロスを計算
fake_out = model_D(fake_ycrcb_tensor)
loss_D_fake = bce_loss(fake_out, zeros[:batch_len])
# 実画像と偽画像のロスを合計
loss_D = loss_D_real + loss_D_fake
log_loss_D.append(loss_D.item())
# 微分計算・重み更新
params_D.zero_grad()
params_G.zero_grad()
loss_D.backward()
params_D.step()
result["log_loss_G_sum"].append(statistics.mean(log_loss_G_sum))
result["log_loss_G_bce"].append(statistics.mean(log_loss_G_bce))
result["log_loss_G_mae"].append(statistics.mean(log_loss_G_mae))
result["log_loss_D"].append(statistics.mean(log_loss_D))
print(f"log_loss_G_sum = {result['log_loss_G_sum'][-1]} " +
f"({result['log_loss_G_bce'][-1]}, {result['log_loss_G_mae'][-1]}) " +
f"log_loss_D = {result['log_loss_D'][-1]}")
# 画像を保存
if not os.path.exists("stl_color"):
os.mkdir("stl_color")
# 生成画像を保存
fake_rgb_tensor = deprocess_generator(fake_ycrcb_tensor)
torchvision.utils.save_image(fake_rgb_tensor[:min(batch_len, 100)],
f"stl_color/fake_epoch_{i:03}.png")
torchvision.utils.save_image(real_rgb[:min(batch_len, 100)],
f"stl_color/real_epoch_{i:03}.png")
# モデルの保存
if not os.path.exists("stl_color/models"):
os.mkdir("stl_color/models")
if i % 10 == 0 or i == 199:
torch.save(model_G.state_dict(), f"stl_color/models/gen_{i:03}.pytorch")
torch.save(model_D.state_dict(), f"stl_color/models/dis_{i:03}.pytorch")
# ログの保存
with open("stl_color/logs.pkl", "wb") as fp:
pickle.dump(result, fp)
if __name__ == "__main__":
train()



Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内

技術書コーナー

  
Terraformで学ぶAWS(1):サーバーレスから始める再利用可能なインフラストラクチャ
  
AIアートの新時代2:Stable Diffusionの課題と動画生成の新潮流
  
コーディング侍:Pythonで学ぶ機械学習ソフトウェア開発の極意
  
AIアートの新時代:CLIPとStable Diffusionを活用した画像生成技術とその応用

One Comment

Add a Comment

メールアドレスが公開されることはありません。 が付いている欄は必須項目です