こしあん
2023-06-03

DDPMで画像生成してみた


3.5k{icon} {views}


現在使われている拡散モデルの最も基本的なモデルであるDDPM(Denoising Diffusion Probabilistic Models)を使って画像生成を試します。スクラッチからの訓練なので、Stable Diffusionのようにはうまくはいかないですが、拡散モデルの基本的なパターンは体感できるでしょう。

はじめに

現在使われている拡散モデルの最も基本的なモデルであるDDPM(Denoising Diffusion Probabilistic Models )を使って画像生成してみました。

HuggingFaceのHugging Face Diffusion Models Courseの実装のほぼコピーです。

対象データ

生成データは以下の3種類を試してみました

  • DTD
  • CIFAR-10
  • KMNIST

コード

CIFAR-10の場合です

import torch
import torchvision
from torchvision import transforms
from diffusers import DDPMScheduler, UNet2DModel
import os

def main():
    # データセットの読み込み
    train_data = torchvision.datasets.CIFAR10(
        root="./data_cifar", train=True, download=True,
        transform=transforms.Compose([
            transforms.CenterCrop(32),
            transforms.ToTensor()
        ]))
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True, num_workers=4)

    device = "cuda:1"
    torch.random.manual_seed(1234)
    torch.cuda.manual_seed(1234)
    validation_seed = torch.randn(16, 3, 32, 32).to(device)

    # https://github.com/huggingface/diffusion-models-class
    # モデルの定義
    model = UNet2DModel(
        sample_size=32,  # the target image resolution
        in_channels=3,  # the number of input channels, 3 for RGB images
        out_channels=3,  # the number of output channels
        layers_per_block=2,  # how many ResNet layers to use per UNet block
        block_out_channels=(64, 128, 128, 256),  # More channels -> more parameters
        down_block_types=(
            "DownBlock2D",  # a regular ResNet downsampling block
            "DownBlock2D",
            "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
            "AttnDownBlock2D",
        ),
        up_block_types=(
            "AttnUpBlock2D",
            "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
            "UpBlock2D",
            "UpBlock2D",  # a regular ResNet upsampling block
        ),
    )
    model.to(device)

    # 損失関数
    criterion = torch.nn.MSELoss()

    # 最適化手法
    optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)

    # ノイズスケジューラー
    noise_scheduler = DDPMScheduler(
        num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2")

    # 学習ループ
    for epoch in range(100):
        running_loss = 0.0
        for i, (clean_images, _) in enumerate(train_loader):
            clean_images = clean_images.to(device)
            noise = torch.randn(clean_images.shape).to(device)

            # Random timestep for each image
            timesteps = torch.randint(
                0, noise_scheduler.num_train_timesteps, 
                (noise.shape[0],), device=device)
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

            # Get noise prediction
            noise_pred = model(noisy_images, timesteps, return_dict=False)[0]

            # Loss
            loss = criterion(noise_pred, noise)
            optimizer.zero_grad()
            loss.backward()
            running_loss += loss.item()
            optimizer.step()
        epoch_train_loss = running_loss / len(train_loader)
        print(f"Epoch = {epoch} | loss = {epoch_train_loss:.4f}")

        # サンプリング
        with torch.no_grad():
            sample = validation_seed.clone()
            for i, t in enumerate(noise_scheduler.timesteps):
                residual = model(sample, t).sample
                sample = noise_scheduler.step(residual, t, sample).prev_sample
            os.makedirs("sampling_cifar", exist_ok=True)
            torchvision.utils.save_image(sample, f"sampling_cifar/{epoch:03}.jpg", quality=90, nrow=4)

if __name__ == "__main__":
    main()

学習ループ内では、ランダムなタイムステップに対して、元画像(clean_images)からノイズへの拡散過程を行い、サンプルの中ではその逆をやって画像を生成しています。

U-Net1個入れるだけ、EncoderもDecoderも何もない非常に単純なモデルです。Diffuserだとよしなにやってくれるので簡単ですね。

結果

すべて100エポック回した例です

DTD

46エポック目

99エポック目

もともとこれはテクスチャのデータセットです。最初はテクスチャっぽいのが出てきたのですが、学習が進むについれてただのベタ塗りになってしまいました。何も条件付けしていないこんな単純な設定で、スクラッチ訓練はさすがに無理があったのでしょう。

CIFAR-10

100エポック目

Unconditionalにしてはなんかそれっぽい感じになってます。GANも割りとこんな感じでした。

KMNIST

52エポック目

100エポック目

さすがにMNISTシリーズぐらい簡単なデータになると、明らかに字っぽいのができています。

うまくいかなかったこと

  • Trainの前処理にRandomResizeCropを入れる
  • RandomHorizontalFlipを入れる

本来これはうまくいくのでしょうが、スクラッチから訓練する場合は収束が遅くなって外しました。これを外したらそれっぽい出力になりました。

終わりに

DDPMもDiffusersがサクッとできて便利!



Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内

技術書コーナー

北海道の駅巡りコーナー


Add a Comment

メールアドレスが公開されることはありません。 が付いている欄は必須項目です