Kohya版のLoRAをDockerで訓練する
LoRAのデファクトスタンダードとなりつつあるKohya版のLoRA学習スクリプトをDockerで動かしてみました。Diffusersとの連携も可能で、DiffusersオフィシャルのLoRAよりも高品質な学習が期待できます。PyTorch2の環境でも動かすことができました。
目次
はじめに
Civitaiなどで配布されているLoRAでデファクトスタンダード形式となりつつあるKohya版のLoRAをWebUIなしで訓練します。ベースのモデルはDiffusers形式でもOKで、Diffusersの連携もできそうです。
LoRAの訓練コードの違い
この世界には画像生成モデルのLoRA訓練コードが多数溢れています。いくつか例を挙げますと、
- Hugging FaceのDiffusersライブラリに登録されている訓練コード
- dreamboothとtext_to_imageの2種類
- Hugging FaceのPEFTライブラリに登録されている訓練コード
- 基本はDiffusers版の古いバージョンが入っているだけです
- Kohya版のLoRA訓練コード
- WebUI界隈でデファクトスタンダードとなっているもの
PEFTはDiffusersの古い版だったので除外してもいいと思います。DiffusersかKohya版かのどちらかになると思いますが、結論からいうとKohya版のほうが高度な訓練しています。わかりやすい点をいくつか挙げると、
- Diffusers版はData Augmentationが簡単なフリップ程度であるのに対し、Kohya版はData Augmentationの専用のライブラリのAlbumentationsを使っている
- bitsandbytesというライブラリを使うことで、AdamWや最近出たLionなどなどの発展的なオプティマイザの8bit訓練ができる
- 公式解説が日本語記事
- safetensorsで吐き出しできる
などです。ただデメリットとして、使っているライブラリが多い分依存関係がややこしいというのがあります。そこを整理してDockerfileにしたというのが今回の記事の目的です。
ディレクトリ構成
- docker
- Dockerfile
- requirements.txt
- workdir
- data
- lora_output
- dataset.toml
このような構成にします。dockerフォルダにはDockerビルドに必要なもの、workdirはDockerイメージにマウントさせる作業ディレクトリです。データ類もこちらに入れます。
Dockerイメージ
公式の環境構築だとPyTorch1系統で書いていますが、PyTorch2にしてもLoRAの訓練程度だと動きました。
Dockerfile
FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu20.04
RUN apt-get update
ENV TZ=Asia/Tokyo
ENV LANG=en_US.UTF-8
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
RUN apt-get install -yq --no-install-recommends python3-pip \
python3-dev \
wget \
git \
libopencv-dev \
tzdata && apt-get upgrade -y && apt-get clean
RUN ln -s /usr/bin/python3 /usr/bin/python
COPY requirements.txt .
RUN pip install -U pip &&\
pip install --no-cache-dir torch==2.0.0+cu118 torchvision==0.15.1+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
RUN pip install --upgrade --no-cache-dir -r requirements.txt
RUN git clone https://github.com/kohya-ss/sd-scripts.git -b v0.6.4 # ここのタグは入れ替えてね
最後のGit cloneはブランチを指定していて、バージョン0.6.4のものをCloneするようにしています。ここが不要でしたら-b以降をとってください。
ベースのイメージはPyTorch2.0.0としました。timmの関係上、現在(2023/6/3)はPyTorch2.0.1は対応していませんでした。
requirements.txt
accelerate==0.15.0
xformers==0.0.19
transformers==4.26.0
ftfy==6.1.1
albumentations==1.3.0
opencv-python==4.7.0.68
einops==0.6.0
diffusers[torch]==0.10.2
pytorch-lightning==1.9.0
bitsandbytes==0.39.0
tensorboard==2.10.1
safetensors==0.2.6
altair==4.2.2
easygui==0.98.3
toml==0.10.2
voluptuous==0.13.1
# for BLIP captioning
requests==2.28.2
timm==0.9.2
fairscale==0.4.13
# for WD14 captioning
# tensorflow==2.10.1
huggingface-hub==0.13.3
requirements.txtのバージョンは公式と一部変えています。キャプション生成のためだけにTensorFlow入れたくなかったので、外しました。必要だったら入れてください。
普通にWSLなどでDockerをビルドすればOKです。カレントディレクトリは「docker」で
docker build -t kohya-ss .
データセットの準備
東北ずん子のLoRAの学習データを例にとります。これは最初からキャプションデータが用意されているので簡単です。
https://drive.google.com/drive/folders/1NIZcBRvr5i8YfPsPYwvVMC7SoH-cWLIk
今回は「01_LoRA学習用データA氏提供版背景白」から「kiritan」を対象とします。
ダウンロードするとkiritanフォルダ直下にpngとtxtが
- kiri(1).png
- kiri(1).txt
- kiri(2).png
- kiri(2).txt
のように入っています。これを「data」フォルダ直下に移します。
次にKohya版LoRAの訓練に必要なtomlを記述します。これはyamlの拡張版で実際に記述するものはただのテキストです。
[general]
enable_bucket = true # Aspect Ratio Bucketingを使うか否か
[[datasets]]
resolution = 512 # 学習解像度
batch_size = 1 # バッチサイズ
[[datasets.subsets]]
image_dir = '/workdir/data/kiritan' # 学習用画像を入れたフォルダを指定
caption_extension = '.txt' # キャプションファイルの拡張子 .txt を使う場合には書き換える
num_repeats = 10 # 学習用画像の繰り返し回数
このように記述しました。「image_dir」の部分だけ注意が必要で、Dockerコンテナ内でのパスになります。このケースでは、ローカルで「workdir」というフォルダを、「/workdir」というパスにマウントさせるので、「/workdir/…」と書きます。
コンテナの起動
ビルドが終わったらコンテナを起動します。カレントディレクトリを「docker」の一個上にして、
docker run --rm -it --gpus all -v $(pwd)/workdir:/workdir kohya-ss
次にコンテナ内に移動するので、スクリプトのパスに移動します。
cd sd-scripts
accelerateの設定をします。面倒だったらローカルで定義したyamlをマウントしてもいいかもしれません。
------------------------------------------------------------------------------------------------------------------------In which compute environment are you running?
This machine
------------------------------------------------------------------------------------------------------------------------Which type of machine are you using?
No distributed training
Do you want to run your training on CPU only (even if a GPU is available)? [yes/NO]:no
Do you wish to optimize your script with torch dynamo?[yes/NO]:no
Do you want to use DeepSpeed? [yes/NO]: no
What GPU(s) (by id) should be used for training on this machine as a comma-seperated list? [all]:1
------------------------------------------------------------------------------------------------------------------------Do you wish to use FP16 or BF16 (mixed precision)?
fp16
このようにしました。GPU IDの指定はallでもいいのですが、2個GPUがあって2個目のみ使いたかったので「1」と指定しました。
LoRAの訓練コマンド
accelerate launch --num_cpu_threads_per_process 1 train_network.py \
--pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
--dataset_config="/workdir/kiritan_training.toml" \
--output_dir="/workdir/lora_output" \
--output_name="sd15_lora" \
--save_model_as=safetensors \
--prior_loss_weight=1.0 \
--max_train_epochs=10 \
--learning_rate=1e-4 \
--optimizer_type="Lion8bit" \
--xformers \
--mixed_precision="fp16" \
--cache_latents \
--gradient_checkpointing \
--save_every_n_epochs=2 \
--seed=0 \
--network_module=networks.lora
このようにします。改行するときはバックスラッシュをお忘れなく!(これを忘れて設定が読み込まれずに2時間ぐらいハマりました)
オプティマイザは公式例にあるような「AdamW8bit」だとエラーになってしまったので、「Lion8bit」にしました。こっちのほうが高度なオプティマイザなのでこれでいいかなと思います。AdamWを使いたい場合は8bitやめて「AdamW」だと動きました。
ここがKohya版の強いところなのですが、「pretrained_model_name_or_path」はDiffusersのモデルIDを指定しても動きます。しかもそれはローカルにある必要はなく、Hugging Faceからよしなにダウンロードしてくれます。訓練コードのモデルの読み込みがDiffusersなので、このへんは得意です。
2080Tiで40分ぐらいで終わりました。もっとエポック数減らしてもいいかなと思います。
このようによく見慣れたLoRAができあがりました。Dockerでディレクトリごとマウントしているので、Dockerで吐き出したファイルを、ローカルからも使えます(永続化されます)。
推論してみる
safetensorsのLoRAと扱いは同じになりますので、オンザフライでマージさせます。やり方はこの記事参照。
from diffusers import StableDiffusionPipeline, UniPCMultistepScheduler
import torch
import matplotlib.pyplot as plt
from safetensors.torch import load_file
def load_safetensors_lora(pipeline,
checkpoint_path,
LORA_PREFIX_UNET = "lora_unet",
LORA_PREFIX_TEXT_ENCODER = "lora_te",
alpha = 0.75):
# load LoRA weight from .safetensors
state_dict = load_file(checkpoint_path)
visited = []
# directly update weight in diffusers model
for key in state_dict:
# it is suggested to print out the key, it usually will be something like below
# "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
# as we have set the alpha beforehand, so just skip
if ".alpha" in key or key in visited:
continue
if "text" in key:
layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
curr_layer = pipeline.text_encoder
else:
layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
curr_layer = pipeline.unet
# find the target layer
temp_name = layer_infos.pop(0)
while len(layer_infos) > -1:
try:
curr_layer = curr_layer.__getattr__(temp_name)
if len(layer_infos) > 0:
temp_name = layer_infos.pop(0)
elif len(layer_infos) == 0:
break
except Exception:
if len(temp_name) > 0:
temp_name += "_" + layer_infos.pop(0)
else:
temp_name = layer_infos.pop(0)
pair_keys = []
if "lora_down" in key:
pair_keys.append(key.replace("lora_down", "lora_up"))
pair_keys.append(key)
else:
pair_keys.append(key)
pair_keys.append(key.replace("lora_up", "lora_down"))
# update weight
if len(state_dict[pair_keys[0]].shape) == 4:
weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
weight_up = state_dict[pair_keys[0]].to(torch.float32)
weight_down = state_dict[pair_keys[1]].to(torch.float32)
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)
# update visited list
for item in pair_keys:
visited.append(item)
return pipeline
def main():
lora_path = "workdir/lora_output/sd15_lora.safetensors"
device = "cuda:0"
model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id,
torch_dtype=torch.float16, safety_checker=None)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe = load_safetensors_lora(pipe,
lora_path,
alpha=0.2)
pipe.to(device, torch.float16)
prompt = "kiritan, look at viewer, 1girl, best quality"
negative_prompt = "nsfw, longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
generator = torch.Generator(device).manual_seed(1234)
image = pipe(prompt=prompt, negative_prompt=negative_prompt, generator=generator,
num_inference_steps=30, num_images_per_prompt=9).images
fig = plt.figure(figsize=(12, 12))
for i in range(9):
ax = fig.add_subplot(3, 3, i+1)
ax.imshow(image[i])
ax.axis("off")
plt.show()
if __name__ == "__main__":
main()
αをかなり落とさないといけなかったので、もしかしたらかなり過学習してしまったかもしれません。エポック数をもっと減らしても良かったかもしれませんね。
推論結果
推論時のプロンプトは「kiritan, look at viewer, 1girl, best quality」です。
SD1.5で訓練し、SD1.5で推論
α=0.2
元がアニメモデルでないのでこんな感じかなという具合です
SD1.5で訓練し、SomethingV2で推論
α=0.2
アニメモデルのSomethingV2_2で推論した結果です。結果はかなりうまくいっていて、SD1.5で訓練して別モデルに移植してもうまくいっているのが確認できます。
SomethingV2で訓練し、SD1.5で推論
α=0.3
こちらはLoRAの学習をアニメモデル(SomethingV2)で行い、SD1.5で推論した場合です。アニメ→きりたんだけでは、アニメ方向へのタスクの学習が行われず、LoRAの適用結果は実写のままでした。αを大きくすると画像が破綻しました。
SomethingV2で訓練し、SomethingV2で推論
こちらはLoRAの学習モデルと推論モデルが一致した結果です。これは分布シフトがないので一番きれいな結果になります。ここまでできるKohya版のLoRAすごいですね。
Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内
技術書コーナー
北海道の駅巡りコーナー