DDPMで画像生成してみた
現在使われている拡散モデルの最も基本的なモデルであるDDPM(Denoising Diffusion Probabilistic Models)を使って画像生成を試します。スクラッチからの訓練なので、Stable Diffusionのようにはうまくはいかないですが、拡散モデルの基本的なパターンは体感できるでしょう。
目次
はじめに
現在使われている拡散モデルの最も基本的なモデルであるDDPM(Denoising Diffusion Probabilistic Models )を使って画像生成してみました。
HuggingFaceのHugging Face Diffusion Models Courseの実装のほぼコピーです。
対象データ
生成データは以下の3種類を試してみました
- DTD
- CIFAR-10
- KMNIST
コード
CIFAR-10の場合です
import torch
import torchvision
from torchvision import transforms
from diffusers import DDPMScheduler, UNet2DModel
import os
def main():
# データセットの読み込み
train_data = torchvision.datasets.CIFAR10(
root="./data_cifar", train=True, download=True,
transform=transforms.Compose([
transforms.CenterCrop(32),
transforms.ToTensor()
]))
train_loader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True, num_workers=4)
device = "cuda:1"
torch.random.manual_seed(1234)
torch.cuda.manual_seed(1234)
validation_seed = torch.randn(16, 3, 32, 32).to(device)
# https://github.com/huggingface/diffusion-models-class
# モデルの定義
model = UNet2DModel(
sample_size=32, # the target image resolution
in_channels=3, # the number of input channels, 3 for RGB images
out_channels=3, # the number of output channels
layers_per_block=2, # how many ResNet layers to use per UNet block
block_out_channels=(64, 128, 128, 256), # More channels -> more parameters
down_block_types=(
"DownBlock2D", # a regular ResNet downsampling block
"DownBlock2D",
"AttnDownBlock2D", # a ResNet downsampling block with spatial self-attention
"AttnDownBlock2D",
),
up_block_types=(
"AttnUpBlock2D",
"AttnUpBlock2D", # a ResNet upsampling block with spatial self-attention
"UpBlock2D",
"UpBlock2D", # a regular ResNet upsampling block
),
)
model.to(device)
# 損失関数
criterion = torch.nn.MSELoss()
# 最適化手法
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)
# ノイズスケジューラー
noise_scheduler = DDPMScheduler(
num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2")
# 学習ループ
for epoch in range(100):
running_loss = 0.0
for i, (clean_images, _) in enumerate(train_loader):
clean_images = clean_images.to(device)
noise = torch.randn(clean_images.shape).to(device)
# Random timestep for each image
timesteps = torch.randint(
0, noise_scheduler.num_train_timesteps,
(noise.shape[0],), device=device)
noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
# Get noise prediction
noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
# Loss
loss = criterion(noise_pred, noise)
optimizer.zero_grad()
loss.backward()
running_loss += loss.item()
optimizer.step()
epoch_train_loss = running_loss / len(train_loader)
print(f"Epoch = {epoch} | loss = {epoch_train_loss:.4f}")
# サンプリング
with torch.no_grad():
sample = validation_seed.clone()
for i, t in enumerate(noise_scheduler.timesteps):
residual = model(sample, t).sample
sample = noise_scheduler.step(residual, t, sample).prev_sample
os.makedirs("sampling_cifar", exist_ok=True)
torchvision.utils.save_image(sample, f"sampling_cifar/{epoch:03}.jpg", quality=90, nrow=4)
if __name__ == "__main__":
main()
学習ループ内では、ランダムなタイムステップに対して、元画像(clean_images)からノイズへの拡散過程を行い、サンプルの中ではその逆をやって画像を生成しています。
U-Net1個入れるだけ、EncoderもDecoderも何もない非常に単純なモデルです。Diffuserだとよしなにやってくれるので簡単ですね。
結果
すべて100エポック回した例です
DTD
46エポック目
99エポック目
もともとこれはテクスチャのデータセットです。最初はテクスチャっぽいのが出てきたのですが、学習が進むについれてただのベタ塗りになってしまいました。何も条件付けしていないこんな単純な設定で、スクラッチ訓練はさすがに無理があったのでしょう。
CIFAR-10
100エポック目
Unconditionalにしてはなんかそれっぽい感じになってます。GANも割りとこんな感じでした。
KMNIST
52エポック目
100エポック目
さすがにMNISTシリーズぐらい簡単なデータになると、明らかに字っぽいのができています。
うまくいかなかったこと
- Trainの前処理にRandomResizeCropを入れる
- RandomHorizontalFlipを入れる
本来これはうまくいくのでしょうが、スクラッチから訓練する場合は収束が遅くなって外しました。これを外したらそれっぽい出力になりました。
終わりに
DDPMもDiffusersがサクッとできて便利!
Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内
技術書コーナー
北海道の駅巡りコーナー