こしあん
2023-06-03

Kohya版のLoRAをDockerで訓練する


3.7k{icon} {views}


LoRAのデファクトスタンダードとなりつつあるKohya版のLoRA学習スクリプトをDockerで動かしてみました。Diffusersとの連携も可能で、DiffusersオフィシャルのLoRAよりも高品質な学習が期待できます。PyTorch2の環境でも動かすことができました。

はじめに

Civitaiなどで配布されているLoRAでデファクトスタンダード形式となりつつあるKohya版のLoRAをWebUIなしで訓練します。ベースのモデルはDiffusers形式でもOKで、Diffusersの連携もできそうです。

LoRAの訓練コードの違い

この世界には画像生成モデルのLoRA訓練コードが多数溢れています。いくつか例を挙げますと、

PEFTはDiffusersの古い版だったので除外してもいいと思います。DiffusersかKohya版かのどちらかになると思いますが、結論からいうとKohya版のほうが高度な訓練しています。わかりやすい点をいくつか挙げると、

  • Diffusers版はData Augmentationが簡単なフリップ程度であるのに対し、Kohya版はData Augmentationの専用のライブラリのAlbumentationsを使っている
  • bitsandbytesというライブラリを使うことで、AdamWや最近出たLionなどなどの発展的なオプティマイザの8bit訓練ができる
  • 公式解説が日本語記事
  • safetensorsで吐き出しできる

などです。ただデメリットとして、使っているライブラリが多い分依存関係がややこしいというのがあります。そこを整理してDockerfileにしたというのが今回の記事の目的です。

ディレクトリ構成

  • docker
    • Dockerfile
    • requirements.txt
  • workdir
    • data
    • lora_output
    • dataset.toml

このような構成にします。dockerフォルダにはDockerビルドに必要なもの、workdirはDockerイメージにマウントさせる作業ディレクトリです。データ類もこちらに入れます。

Dockerイメージ

公式の環境構築だとPyTorch1系統で書いていますが、PyTorch2にしてもLoRAの訓練程度だと動きました。

Dockerfile

FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu20.04

RUN apt-get update
ENV TZ=Asia/Tokyo
ENV LANG=en_US.UTF-8
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
RUN apt-get install -yq --no-install-recommends python3-pip \
        python3-dev \
        wget \
        git  \
        libopencv-dev \
        tzdata && apt-get upgrade -y && apt-get clean

RUN ln -s /usr/bin/python3 /usr/bin/python

COPY requirements.txt .
RUN pip install -U pip &&\
  pip install --no-cache-dir torch==2.0.0+cu118 torchvision==0.15.1+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
RUN pip install --upgrade --no-cache-dir -r requirements.txt
RUN git clone https://github.com/kohya-ss/sd-scripts.git -b v0.6.4 # ここのタグは入れ替えてね
Docker

最後のGit cloneはブランチを指定していて、バージョン0.6.4のものをCloneするようにしています。ここが不要でしたら-b以降をとってください。

ベースのイメージはPyTorch2.0.0としました。timmの関係上、現在(2023/6/3)はPyTorch2.0.1は対応していませんでした。

requirements.txt

accelerate==0.15.0
xformers==0.0.19
transformers==4.26.0
ftfy==6.1.1
albumentations==1.3.0
opencv-python==4.7.0.68
einops==0.6.0
diffusers[torch]==0.10.2
pytorch-lightning==1.9.0
bitsandbytes==0.39.0
tensorboard==2.10.1
safetensors==0.2.6
altair==4.2.2
easygui==0.98.3
toml==0.10.2
voluptuous==0.13.1
# for BLIP captioning
requests==2.28.2
timm==0.9.2
fairscale==0.4.13
# for WD14 captioning
# tensorflow==2.10.1
huggingface-hub==0.13.3
Docker

requirements.txtのバージョンは公式と一部変えています。キャプション生成のためだけにTensorFlow入れたくなかったので、外しました。必要だったら入れてください。

普通にWSLなどでDockerをビルドすればOKです。カレントディレクトリは「docker」で

docker build -t kohya-ss .
Console

データセットの準備

東北ずん子のLoRAの学習データを例にとります。これは最初からキャプションデータが用意されているので簡単です。

https://drive.google.com/drive/folders/1NIZcBRvr5i8YfPsPYwvVMC7SoH-cWLIk

今回は「01_LoRA学習用データA氏提供版背景白」から「kiritan」を対象とします。

ダウンロードするとkiritanフォルダ直下にpngとtxtが

  • kiri(1).png
  • kiri(1).txt
  • kiri(2).png
  • kiri(2).txt

のように入っています。これを「data」フォルダ直下に移します。

次にKohya版LoRAの訓練に必要なtomlを記述します。これはyamlの拡張版で実際に記述するものはただのテキストです。

[general]
enable_bucket = true                        # Aspect Ratio Bucketingを使うか否か

[[datasets]]
resolution = 512                            # 学習解像度
batch_size = 1                              # バッチサイズ

  [[datasets.subsets]]
  image_dir = '/workdir/data/kiritan'                     # 学習用画像を入れたフォルダを指定
  caption_extension = '.txt'            # キャプションファイルの拡張子 .txt を使う場合には書き換える
  num_repeats = 10                          # 学習用画像の繰り返し回数
Console

このように記述しました。「image_dir」の部分だけ注意が必要で、Dockerコンテナ内でのパスになります。このケースでは、ローカルで「workdir」というフォルダを、「/workdir」というパスにマウントさせるので、「/workdir/…」と書きます。

コンテナの起動

ビルドが終わったらコンテナを起動します。カレントディレクトリを「docker」の一個上にして、

docker run --rm -it --gpus all -v $(pwd)/workdir:/workdir kohya-ss
Console

次にコンテナ内に移動するので、スクリプトのパスに移動します。

cd sd-scripts
Console

accelerateの設定をします。面倒だったらローカルで定義したyamlをマウントしてもいいかもしれません。

------------------------------------------------------------------------------------------------------------------------In which compute environment are you running?
This machine
------------------------------------------------------------------------------------------------------------------------Which type of machine are you using?
No distributed training
Do you want to run your training on CPU only (even if a GPU is available)? [yes/NO]:no
Do you wish to optimize your script with torch dynamo?[yes/NO]:no
Do you want to use DeepSpeed? [yes/NO]: no
What GPU(s) (by id) should be used for training on this machine as a comma-seperated list? [all]:1
------------------------------------------------------------------------------------------------------------------------Do you wish to use FP16 or BF16 (mixed precision)?
fp16
Console

このようにしました。GPU IDの指定はallでもいいのですが、2個GPUがあって2個目のみ使いたかったので「1」と指定しました。

LoRAの訓練コマンド

accelerate launch --num_cpu_threads_per_process 1 train_network.py \
    --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
    --dataset_config="/workdir/kiritan_training.toml" \
    --output_dir="/workdir/lora_output" \
    --output_name="sd15_lora" \
    --save_model_as=safetensors \
    --prior_loss_weight=1.0 \
    --max_train_epochs=10 \
    --learning_rate=1e-4 \
    --optimizer_type="Lion8bit" \
    --xformers \
    --mixed_precision="fp16" \
    --cache_latents \
    --gradient_checkpointing \
    --save_every_n_epochs=2 \
    --seed=0 \
    --network_module=networks.lora
Console

このようにします。改行するときはバックスラッシュをお忘れなく!(これを忘れて設定が読み込まれずに2時間ぐらいハマりました)

オプティマイザは公式例にあるような「AdamW8bit」だとエラーになってしまったので、「Lion8bit」にしました。こっちのほうが高度なオプティマイザなのでこれでいいかなと思います。AdamWを使いたい場合は8bitやめて「AdamW」だと動きました。

ここがKohya版の強いところなのですが、「pretrained_model_name_or_path」はDiffusersのモデルIDを指定しても動きます。しかもそれはローカルにある必要はなく、Hugging Faceからよしなにダウンロードしてくれます。訓練コードのモデルの読み込みがDiffusersなので、このへんは得意です。

2080Tiで40分ぐらいで終わりました。もっとエポック数減らしてもいいかなと思います。

このようによく見慣れたLoRAができあがりました。Dockerでディレクトリごとマウントしているので、Dockerで吐き出したファイルを、ローカルからも使えます(永続化されます)。

推論してみる

safetensorsのLoRAと扱いは同じになりますので、オンザフライでマージさせます。やり方はこの記事参照

from diffusers import StableDiffusionPipeline, UniPCMultistepScheduler
import torch
import matplotlib.pyplot as plt
from safetensors.torch import load_file

def load_safetensors_lora(pipeline,
                          checkpoint_path,
                          LORA_PREFIX_UNET = "lora_unet",
                          LORA_PREFIX_TEXT_ENCODER = "lora_te",
                          alpha = 0.75):
    # load LoRA weight from .safetensors
    state_dict = load_file(checkpoint_path)

    visited = []
    # directly update weight in diffusers model
    for key in state_dict:
        # it is suggested to print out the key, it usually will be something like below
        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"

        # as we have set the alpha beforehand, so just skip
        if ".alpha" in key or key in visited:
            continue

        if "text" in key:
            layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
            curr_layer = pipeline.text_encoder
        else:
            layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
            curr_layer = pipeline.unet

        # find the target layer
        temp_name = layer_infos.pop(0)
        while len(layer_infos) > -1:
            try:
                curr_layer = curr_layer.__getattr__(temp_name)
                if len(layer_infos) > 0:
                    temp_name = layer_infos.pop(0)
                elif len(layer_infos) == 0:
                    break
            except Exception:
                if len(temp_name) > 0:
                    temp_name += "_" + layer_infos.pop(0)
                else:
                    temp_name = layer_infos.pop(0)

        pair_keys = []
        if "lora_down" in key:
            pair_keys.append(key.replace("lora_down", "lora_up"))
            pair_keys.append(key)
        else:
            pair_keys.append(key)
            pair_keys.append(key.replace("lora_up", "lora_down"))

        # update weight
        if len(state_dict[pair_keys[0]].shape) == 4:
            weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
            weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
        else:
            weight_up = state_dict[pair_keys[0]].to(torch.float32)
            weight_down = state_dict[pair_keys[1]].to(torch.float32)
            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)

        # update visited list
        for item in pair_keys:
            visited.append(item)

    return pipeline

def main():
    lora_path = "workdir/lora_output/sd15_lora.safetensors"
    device = "cuda:0"
    model_id = "runwayml/stable-diffusion-v1-5"
    pipe = StableDiffusionPipeline.from_pretrained(model_id, 
                                                   torch_dtype=torch.float16, safety_checker=None)    
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = load_safetensors_lora(pipe,
                                 lora_path,
                                 alpha=0.2)
    pipe.to(device, torch.float16)

    prompt = "kiritan, look at viewer, 1girl, best quality"
    negative_prompt = "nsfw, longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
    generator = torch.Generator(device).manual_seed(1234)
    image = pipe(prompt=prompt, negative_prompt=negative_prompt, generator=generator,
                 num_inference_steps=30, num_images_per_prompt=9).images

    fig = plt.figure(figsize=(12, 12))
    for i in range(9):
        ax = fig.add_subplot(3, 3, i+1)
        ax.imshow(image[i])
        ax.axis("off")

    plt.show()

if __name__ == "__main__":
    main()
Python

αをかなり落とさないといけなかったので、もしかしたらかなり過学習してしまったかもしれません。エポック数をもっと減らしても良かったかもしれませんね。

推論結果

推論時のプロンプトは「kiritan, look at viewer, 1girl, best quality」です。

SD1.5で訓練し、SD1.5で推論

α=0.2

元がアニメモデルでないのでこんな感じかなという具合です

SD1.5で訓練し、SomethingV2で推論

α=0.2

アニメモデルのSomethingV2_2で推論した結果です。結果はかなりうまくいっていて、SD1.5で訓練して別モデルに移植してもうまくいっているのが確認できます。

SomethingV2で訓練し、SD1.5で推論

α=0.3

こちらはLoRAの学習をアニメモデル(SomethingV2)で行い、SD1.5で推論した場合です。アニメ→きりたんだけでは、アニメ方向へのタスクの学習が行われず、LoRAの適用結果は実写のままでした。αを大きくすると画像が破綻しました。

SomethingV2で訓練し、SomethingV2で推論

こちらはLoRAの学習モデルと推論モデルが一致した結果です。これは分布シフトがないので一番きれいな結果になります。ここまでできるKohya版のLoRAすごいですね。



Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内

技術書コーナー

  
Terraformで学ぶAWS(1):サーバーレスから始める再利用可能なインフラストラクチャ
  
AIアートの新時代2:Stable Diffusionの課題と動画生成の新潮流
  
コーディング侍:Pythonで学ぶ機械学習ソフトウェア開発の極意
  
AIアートの新時代:CLIPとStable Diffusionを活用した画像生成技術とその応用

Add a Comment

メールアドレスが公開されることはありません。 が付いている欄は必須項目です