Diffuserでのモデルマージを検証します。Stable Diffusionには、CLIP、U-Net、VAEの3つのモデルからなり、それぞれをマージの効果が違うので、面白い結果になります。また、CLIP(Text Encoder)のマージについても、OpenAIのものと比較検討してみます。
目次
Stable DiffusionでどうもFine-tuningしたモデルをマージするのが流行っているらしいので試してみました。Diffusersでやる場合は、「ただPyTorchで係数マージするだけなので楽だろう」とおもったらかなり楽でした。
モデルマージというのは実務的には結構よくやられていることですが、研究的にはまあまああります。私の知っている限りですが、
Stabel Diffusionのモデルは、以下の3つのモデルからなります
T2I-Adapterの図がわかりやすかったので、T2I-Adapterのリポジトリから引用します。
SDのCLIPはプロンプトを特徴量に変換しているだけなので、基本はFrozenではないかと思います(→実はそうではないことがわかりました)
マージできるかどうかの判定は「モデル構造が一緒ならマージ可能」です。モデル構造が一緒かどうかはパラメーター数を見るのが簡単で、
def count_params(pipe):
num_params = 0
for k, v in pipe.unet.named_parameters():
p = np.prod([x for x in v.shape])
num_params += p
return num_params
このようにStable DiffusionのPipelineのU-Netのパラメーターを計算するのがわかりやすいと思います。SD1.5とSD2.0系では、U-Netのパラメーター数が異なるためマージできません。
直接絵に関わりそうなのは、U-NetとVAEなので、まずはそこの部分をマージしてみます。コードは以下の通りです。
from diffusers import StableDiffusionPipeline, UniPCMultistepScheduler
import torch
import numpy as np
import copy
import torchvision
from compel import Compel
def count_params(pipe):
num_params = 0
for k, v in pipe.unet.named_parameters():
p = np.prod([x for x in v.shape])
num_params += p
return num_params
def merge_network(pipe_source, pipe_merge, attr, ratio):
merge_net = copy.deepcopy(getattr(pipe_source, attr))
pipe_source_params = dict(getattr(pipe_source, attr).named_parameters())
pipe_merge_params = dict(getattr(pipe_merge, attr).named_parameters())
for key, param in merge_net.named_parameters():
x = pipe_source_params[key] * (1-ratio) + pipe_merge_params[key] * ratio
param.data = x
return merge_net
def main():
merge_pairs = {
"journey-counter": ["prompthero/openjourney-v4", "gsdf/Counterfeit-V2.0"],
"journey-sd": ["prompthero/openjourney-v4", "runwayml/stable-diffusion-v1-5"],
"counter-journey": ["gsdf/Counterfeit-V2.0", "prompthero/openjourney-v4"],
"counter-sd": ["gsdf/Counterfeit-V2.0", "runwayml/stable-diffusion-v1-5"],
"sd-journey": ["runwayml/stable-diffusion-v1-5", "prompthero/openjourney-v4"],
"sd-counter": ["runwayml/stable-diffusion-v1-5", "gsdf/Counterfeit-V2.0"],
}
for i, (merge_key, merge_pair) in enumerate(merge_pairs.items()):
pipe1 = StableDiffusionPipeline.from_pretrained(merge_pair[0], torch_dtype=torch.float16, use_safetensors=True)
pipe1.scheduler = UniPCMultistepScheduler.from_config(pipe1.scheduler.config, use_safetensors=True)
pipe2 = StableDiffusionPipeline.from_pretrained(merge_pair[1], torch_dtype=torch.float16, use_safetensors=True)
pipe2.scheduler = UniPCMultistepScheduler.from_config(pipe2.scheduler.config, use_safetensors=True)
images = []
for unet_ratio in [0, 0.5, 1]:
for vae_ratio in [0, 0.5, 1]:
merge_unet = merge_network(pipe1, pipe2, "unet", unet_ratio)
merge_vae = merge_network(pipe1, pipe2, "vae", vae_ratio)
merge_pipe = StableDiffusionPipeline(
vae=merge_vae,
text_encoder=pipe1.text_encoder,
tokenizer=pipe1.tokenizer,
unet=merge_unet,
scheduler=UniPCMultistepScheduler.from_config(pipe1.scheduler.config),
safety_checker=lambda images, **kwargs: (images, False),
feature_extractor=pipe1.feature_extractor
)
device = "cuda:1"
generator = torch.Generator(device).manual_seed(1234)
prompt = "1girl++++, solo, traveler, rule of thirds, teens, dress++, silver hair, side view, face in profile, praying, anime face, stand on a high hill, "\
"looking down into distance, pasture, flower garden, small town, sunshine, clear sky, flowing clouds, masterpiece, best quality, detailed face"
negative_prompt = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, multiple girls, multiple people"
merge_pipe.to(device)
compel_proc = Compel(tokenizer=merge_pipe.tokenizer, text_encoder=merge_pipe.text_encoder)
conditioning = compel_proc(prompt)
neg_conditioning = compel_proc(negative_prompt)
image = merge_pipe(prompt_embeds=conditioning, negative_prompt_embeds=neg_conditioning, generator=generator, output_type="numpy", num_images_per_prompt=1, width=512, height=768).images
images.append(image)
images = np.concatenate(images)
images = torch.from_numpy(images).permute(0, 3, 1, 2)
torchvision.utils.save_image(images, f"merge_{i}_{merge_key}.jpg", quality=92, nrow=3, normalize=False)
if __name__ == "__main__":
main()
Compelを入れてプロンプトに重み付けできるようにしています。
結果の見方は以下の通りです。
「A-B」という形でモデルマージしたとして、
VAE 0 | VAE 0.5 | VAE 1 | |
---|---|---|---|
U-Net 0 | U-Net A:B=1:0 VAE A:B=1:0 |
U-Net A:B=1:0 VAE A:B=0.5:0.5 |
U-Net A:B=1:0 VAE A:B=0:1 |
U-Net 0.5 | U-Net A:B=0.5:0.5 VAE A:B=1:0 |
U-Net A:B=0.5:0.5 VAE A:B=0.5:0.5 |
U-Net A:B=0.5:0.5 VAE A:B=0:1 |
U-Net 1 | U-Net A:B=0:1 VAE A:B=1:0 |
U-Net A:B=0:1 VAE A:B=0.5:0.5 |
U-Net A:B=0:1 VAE A:B=0:1 |
とします。横がVAEの変化で、縦がU-Netの変化です。
3つのモデルを全パターンでマージしてみます。
これらはすべてモデル構造が等しいので、マージ可能です。
merge_pairs = {
"journey-counter": ["prompthero/openjourney-v4", "gsdf/Counterfeit-V2.0"],
"journey-sd": ["prompthero/openjourney-v4", "runwayml/stable-diffusion-v1-5"],
"counter-journey": ["gsdf/Counterfeit-V2.0", "prompthero/openjourney-v4"],
"counter-sd": ["gsdf/Counterfeit-V2.0", "runwayml/stable-diffusion-v1-5"],
"sd-journey": ["runwayml/stable-diffusion-v1-5", "prompthero/openjourney-v4"],
"sd-counter": ["runwayml/stable-diffusion-v1-5", "gsdf/Counterfeit-V2.0"],
}
このように6通りの組み合わせが考えられます。
OpenJourneyのハイコントラストな感じから、Counterfeitの淡い感じに移っていきました。U-Netを半々ぐらいでマージすると背景がいい感じになりますね。
ちょっとこれがおかしな結果です。Stable Diffusionでは実写のほうが入っているため、右下のようなケースはかなり実写に近くならなんくてはいけません。なので、マージしていないCLIPの部分になにかトリックがあるのだと思います。
ケース1のSourceとTargetを入れ替えた例です。同じアニメが得意なもの同士なので、違和感は特にないです。CounterfeitはデフォルトのVAEがあまりよくなく、OpenJourneyのVAEを使ってあげるといい感じになりそうですね。VAEはモデル関係なく好きなの使っても良さそうです。
フルマージすると、実写にいきました。OpenJourneyがちょっとおかしいのかもしれません。SDのVAEはなかなかで、実写っぽいタッチのアニメ絵を書きたいときはSDのVAEを使ってもいいかもしれません。
ベースがSDなので、左上→右下でアニメ調に映るはずです。このケースでは、CLIP(Text Encoder)はSD1.5のを使っていますが、アニメに特化したCLIPでなくても大丈夫そうですね。
CLIPって広範囲のデータで学習したほうがいいものができる(と勝手に思っている)ので、ここをFine-tuningしてしまうと、破滅的忘却の影響が出てあんまりよくないと思います。
ケース2の逆のケースですが、こっちのほうがプロンプトの意味をちゃんと解釈してそうなので、OpenJourneyのCLIPがちょっとおかしいのかもしれません。
これも納得行く結果になりました。Counterfeit-V2.0はVAEを別なのに置き換えてみると良さそうです。
以下のことがわかりました
今度はU-NetとCLIPでマージしてみます。結果の見方は以下のようになります。VAEはStable Diffusion1.5ので固定します。
CLIP 0 | CLIP 0.5 | CLIP 1 | |
---|---|---|---|
U-Net 0 | U-Net A:B=1:0 CLIP A:B=1:0 |
U-Net A:B=1:0 CLIP A:B=0.5:0.5 |
U-Net A:B=1:0 CLIP A:B=0:1 |
U-Net 0.5 | U-Net A:B=0.5:0.5 CLIP A:B=1:0 |
U-Net A:B=0.5:0.5 CLIP A:B=0.5:0.5 |
U-Net A:B=0.5:0.5 CLIP A:B=0:1 |
U-Net 1 | U-Net A:B=0:1 CLIP A:B=1:0 |
U-Net A:B=0:1 CLIP A:B=0.5:0.5 |
U-Net A:B=0:1 CLIP A:B=0:1 |
コードはこのようになります
def main_clip():
merge_pairs = {
"journey-counter": ["prompthero/openjourney-v4", "gsdf/Counterfeit-V2.0"],
"journey-sd": ["prompthero/openjourney-v4", "runwayml/stable-diffusion-v1-5"],
"counter-journey": ["gsdf/Counterfeit-V2.0", "prompthero/openjourney-v4"],
"counter-sd": ["gsdf/Counterfeit-V2.0", "runwayml/stable-diffusion-v1-5"],
"sd-journey": ["runwayml/stable-diffusion-v1-5", "prompthero/openjourney-v4"],
"sd-counter": ["runwayml/stable-diffusion-v1-5", "gsdf/Counterfeit-V2.0"],
}
base_pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True)
for i, (merge_key, merge_pair) in enumerate(merge_pairs.items()):
pipe1 = StableDiffusionPipeline.from_pretrained(merge_pair[0], torch_dtype=torch.float16, use_safetensors=True)
pipe1.scheduler = UniPCMultistepScheduler.from_config(pipe1.scheduler.config, use_safetensors=True)
pipe2 = StableDiffusionPipeline.from_pretrained(merge_pair[1], torch_dtype=torch.float16, use_safetensors=True)
pipe2.scheduler = UniPCMultistepScheduler.from_config(pipe2.scheduler.config, use_safetensors=True)
images = []
for unet_ratio in [0, 0.5, 1]:
for clip_ratio in [0, 0.5, 1]:
merge_unet = merge_network(pipe1, pipe2, "unet", unet_ratio)
merge_clip = merge_network(pipe1, pipe2, "text_encoder", clip_ratio)
merge_pipe = StableDiffusionPipeline(
vae=base_pipe.vae,
text_encoder=merge_clip,
tokenizer=pipe1.tokenizer,
unet=merge_unet,
scheduler=UniPCMultistepScheduler.from_config(pipe1.scheduler.config),
safety_checker=lambda images, **kwargs: (images, False),
feature_extractor=pipe1.feature_extractor
)
device = "cuda:1"
generator = torch.Generator(device).manual_seed(1234)
prompt = "1girl++++, solo, traveler, rule of thirds, teens, dress++, silver hair, side view, face in profile, praying, anime face, stand on a high hill, "\
"looking down into distance, pasture, flower garden, small town, sunshine, clear sky, flowing clouds, masterpiece, best quality, detailed face"
negative_prompt = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, multiple girls, multiple people"
merge_pipe.to(device)
compel_proc = Compel(tokenizer=merge_pipe.tokenizer, text_encoder=merge_pipe.text_encoder)
conditioning = compel_proc(prompt)
neg_conditioning = compel_proc(negative_prompt)
image = merge_pipe(prompt_embeds=conditioning, negative_prompt_embeds=neg_conditioning, generator=generator, output_type="numpy", num_images_per_prompt=1, width=512, height=768).images
images.append(image)
images = np.concatenate(images)
images = torch.from_numpy(images).permute(0, 3, 1, 2)
torchvision.utils.save_image(images, f"merge_{i+10}_{merge_key}.jpg", quality=92, nrow=3, normalize=False)
先ほどは(U-NetとVAEのマージ)、
に一貫性がありませんでした。今回はどうでしょうか?
このようにそれぞれ反転した結果となり,一貫性が保証されました。つまり、CLIPのFine-tuningも行われており、CLIP空間の変化が画風や構図・細部の変化をもたらしていたというこのを表します。
面白いのがケース2の下段のケースで、U-Netを完全にSDのものにしても(マージ比率1)、CLIPがOpenJourneyのもの(マージ比率0)、半々マージ(マージ比率0.5)ならアニメ調のままということです。
もちろんマージのケースでは、前景の芝桜の部分が後ろの背景に転移していたり、ちょっとおかしいところはあります。しかし、画風のコントロールの支配的な部分はU-Netではなく、CLIPということが示唆されます。したがって、CLIP部分にAdapter的なアフィン変換のレイヤーを追加して、他全部フリーズして訓練してもある程度行けるのではないかと思われます。
他のケースも見ていきましょう。ソースとターゲットを反転しても同じなので、特定の組み合わせのみみます。
U-NetとCLIPの効きが面白いように相互作用かかっています。U-Netでカメラアングルが変わるのは納得ですが、CLIPを変えると髪型や服装が微妙に変わっていくのが面白いですね。
CLIPがポーズに効いているのが面白いですね。多分モデルのバイアスなのではないかと思います。
2段目の右のように、CLIPを完全にSDにしてしまうと、牧草地のプロンプトがもう少し強く出ています。
ここであることに気づきます。元のOpenAIのCLIPってものすごいEmbeddingいいんだから、そっちとマージしちゃえばいいやん。
SD1.5で使っているCLIPはなにかを探す必要があります。いまいち情報なくてソースコード買得したのですが、「openai/clip-vit-large-patch14」でした。
class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.transformer = CLIPTextModel.from_pretrained(version)
self.device = device
self.max_length = max_length
self.freeze()
SD1.5のソースよりです。見てわかるように、本来のSDってCLIP部分はフリーズしているので、CLIP部分に影響が出ることは基本ないはずなんですよね…
マージできるか調べていきましょう。このようにStable DiffusionとOpenAI CLIP間で、係数とパラメーター名があっているかどうかチェックします。
from diffusers import StableDiffusionPipeline, UniPCMultistepScheduler
import torch
import numpy as np
import copy
import torchvision
from compel import Compel
from transformers import CLIPModel
def count_params(pipe):
num_params = 0
for k, v in pipe.named_parameters():
p = np.prod([x for x in v.shape])
num_params += p
return num_params
def check_models():
pipe1 = StableDiffusionPipeline.from_pretrained("prompthero/openjourney-v4", torch_dtype=torch.float16, use_safetensors=True)
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
print("--num parameters --")
print(count_params(model.text_model))
print(count_params(pipe1.text_encoder))
print("--named parameters --")
param_name_sd = list(k for k, _ in pipe1.text_encoder.named_parameters())
param_name_clip = list(k for k, _ in model.text_model.named_parameters())
print(param_name_sd)
print(param_name_clip)
print("-- is same name ?")
print(all([s == "text_model."+c for s, c in zip(param_name_sd, param_name_clip)]))
結果は以下のようになります。
--num parameters --
123060480
123060480
--named parameters --
['text_model.embeddings.token_embedding.weight', 'text_model.embeddings.position_embedding.weight', 'text_model.encoder.layers.0.self_attn.k_proj.weight', 'text_model.encoder.layers.0.self_attn.k_proj.bias', 'text_model.encoder.layers.0.self_attn.v_proj.weight', 'text_model.encoder.layers.0.self_attn.v_proj.bias', 'text_model.encoder.layers.0.self_attn.q_proj.weight', 'text_model.encoder.layers.0.self_attn.q_proj.bias', 'text_model.encoder.layers.0.self_attn.out_proj.weight', 'text_model.encoder.layers.0.self_attn.out_proj.bias', 'text_model.encoder.layers.0.layer_norm1.weight', 'text_model.encoder.layers.0.layer_norm1.bias', 'text_model.encoder.layers.0.mlp.fc1.weight', 'text_model.encoder.layers.0.mlp.fc1.bias', 'text_model.encoder.layers.0.mlp.fc2.weight', 'text_model.encoder.layers.0.mlp.fc2.bias', 'text_model.encoder.layers.0.layer_norm2.weight', 'text_model.encoder.layers.0.layer_norm2.bias', 'text_model.encoder.layers.1.self_attn.k_proj.weight', 'text_model.encoder.layers.1.self_attn.k_proj.bias', 'text_model.encoder.layers.1.self_attn.v_proj.weight', 'text_model.encoder.layers.1.self_attn.v_proj.bias', 'text_model.encoder.layers.1.self_attn.q_proj.weight', 'text_model.encoder.layers.1.self_attn.q_proj.bias', 'text_model.encoder.layers.1.self_attn.out_proj.weight', 'text_model.encoder.layers.1.self_attn.out_proj.bias', 'text_model.encoder.layers.1.layer_norm1.weight', 'text_model.encoder.layers.1.layer_norm1.bias', 'text_model.encoder.layers.1.mlp.fc1.weight', 'text_model.encoder.layers.1.mlp.fc1.bias', 'text_model.encoder.layers.1.mlp.fc2.weight', 'text_model.encoder.layers.1.mlp.fc2.bias', 'text_model.encoder.layers.1.layer_norm2.weight', 'text_model.encoder.layers.1.layer_norm2.bias', 'text_model.encoder.layers.2.self_attn.k_proj.weight', 'text_model.encoder.layers.2.self_attn.k_proj.bias', 'text_model.encoder.layers.2.self_attn.v_proj.weight', 'text_model.encoder.layers.2.self_attn.v_proj.bias', 'text_model.encoder.layers.2.self_attn.q_proj.weight', 'text_model.encoder.layers.2.self_attn.q_proj.bias', 'text_model.encoder.layers.2.self_attn.out_proj.weight', 'text_model.encoder.layers.2.self_attn.out_proj.bias', 'text_model.encoder.layers.2.layer_norm1.weight', 'text_model.encoder.layers.2.layer_norm1.bias', 'text_model.encoder.layers.2.mlp.fc1.weight', 'text_model.encoder.layers.2.mlp.fc1.bias', 'text_model.encoder.layers.2.mlp.fc2.weight', 'text_model.encoder.layers.2.mlp.fc2.bias', 'text_model.encoder.layers.2.layer_norm2.weight', 'text_model.encoder.layers.2.layer_norm2.bias', 'text_model.encoder.layers.3.self_attn.k_proj.weight', 'text_model.encoder.layers.3.self_attn.k_proj.bias', 'text_model.encoder.layers.3.self_attn.v_proj.weight', 'text_model.encoder.layers.3.self_attn.v_proj.bias', 'text_model.encoder.layers.3.self_attn.q_proj.weight', 'text_model.encoder.layers.3.self_attn.q_proj.bias', 'text_model.encoder.layers.3.self_attn.out_proj.weight', 'text_model.encoder.layers.3.self_attn.out_proj.bias', 'text_model.encoder.layers.3.layer_norm1.weight', 'text_model.encoder.layers.3.layer_norm1.bias', 'text_model.encoder.layers.3.mlp.fc1.weight', 'text_model.encoder.layers.3.mlp.fc1.bias', 'text_model.encoder.layers.3.mlp.fc2.weight', 'text_model.encoder.layers.3.mlp.fc2.bias', 'text_model.encoder.layers.3.layer_norm2.weight', 'text_model.encoder.layers.3.layer_norm2.bias', 'text_model.encoder.layers.4.self_attn.k_proj.weight', 'text_model.encoder.layers.4.self_attn.k_proj.bias', 'text_model.encoder.layers.4.self_attn.v_proj.weight', 'text_model.encoder.layers.4.self_attn.v_proj.bias', 'text_model.encoder.layers.4.self_attn.q_proj.weight', 'text_model.encoder.layers.4.self_attn.q_proj.bias', 'text_model.encoder.layers.4.self_attn.out_proj.weight', 'text_model.encoder.layers.4.self_attn.out_proj.bias', 'text_model.encoder.layers.4.layer_norm1.weight', 'text_model.encoder.layers.4.layer_norm1.bias', 'text_model.encoder.layers.4.mlp.fc1.weight', 'text_model.encoder.layers.4.mlp.fc1.bias', 'text_model.encoder.layers.4.mlp.fc2.weight', 'text_model.encoder.layers.4.mlp.fc2.bias', 'text_model.encoder.layers.4.layer_norm2.weight', 'text_model.encoder.layers.4.layer_norm2.bias', 'text_model.encoder.layers.5.self_attn.k_proj.weight', 'text_model.encoder.layers.5.self_attn.k_proj.bias', 'text_model.encoder.layers.5.self_attn.v_proj.weight', 'text_model.encoder.layers.5.self_attn.v_proj.bias', 'text_model.encoder.layers.5.self_attn.q_proj.weight', 'text_model.encoder.layers.5.self_attn.q_proj.bias', 'text_model.encoder.layers.5.self_attn.out_proj.weight', 'text_model.encoder.layers.5.self_attn.out_proj.bias', 'text_model.encoder.layers.5.layer_norm1.weight', 'text_model.encoder.layers.5.layer_norm1.bias', 'text_model.encoder.layers.5.mlp.fc1.weight', 'text_model.encoder.layers.5.mlp.fc1.bias', 'text_model.encoder.layers.5.mlp.fc2.weight', 'text_model.encoder.layers.5.mlp.fc2.bias', 'text_model.encoder.layers.5.layer_norm2.weight', 'text_model.encoder.layers.5.layer_norm2.bias', 'text_model.encoder.layers.6.self_attn.k_proj.weight', 'text_model.encoder.layers.6.self_attn.k_proj.bias', 'text_model.encoder.layers.6.self_attn.v_proj.weight', 'text_model.encoder.layers.6.self_attn.v_proj.bias', 'text_model.encoder.layers.6.self_attn.q_proj.weight', 'text_model.encoder.layers.6.self_attn.q_proj.bias', 'text_model.encoder.layers.6.self_attn.out_proj.weight', 'text_model.encoder.layers.6.self_attn.out_proj.bias', 'text_model.encoder.layers.6.layer_norm1.weight', 'text_model.encoder.layers.6.layer_norm1.bias', 'text_model.encoder.layers.6.mlp.fc1.weight', 'text_model.encoder.layers.6.mlp.fc1.bias', 'text_model.encoder.layers.6.mlp.fc2.weight', 'text_model.encoder.layers.6.mlp.fc2.bias', 'text_model.encoder.layers.6.layer_norm2.weight', 'text_model.encoder.layers.6.layer_norm2.bias', 'text_model.encoder.layers.7.self_attn.k_proj.weight', 'text_model.encoder.layers.7.self_attn.k_proj.bias', 'text_model.encoder.layers.7.self_attn.v_proj.weight', 'text_model.encoder.layers.7.self_attn.v_proj.bias', 'text_model.encoder.layers.7.self_attn.q_proj.weight', 'text_model.encoder.layers.7.self_attn.q_proj.bias', 'text_model.encoder.layers.7.self_attn.out_proj.weight', 'text_model.encoder.layers.7.self_attn.out_proj.bias', 'text_model.encoder.layers.7.layer_norm1.weight', 'text_model.encoder.layers.7.layer_norm1.bias', 'text_model.encoder.layers.7.mlp.fc1.weight', 'text_model.encoder.layers.7.mlp.fc1.bias', 'text_model.encoder.layers.7.mlp.fc2.weight', 'text_model.encoder.layers.7.mlp.fc2.bias', 'text_model.encoder.layers.7.layer_norm2.weight', 'text_model.encoder.layers.7.layer_norm2.bias', 'text_model.encoder.layers.8.self_attn.k_proj.weight', 'text_model.encoder.layers.8.self_attn.k_proj.bias', 'text_model.encoder.layers.8.self_attn.v_proj.weight', 'text_model.encoder.layers.8.self_attn.v_proj.bias', 'text_model.encoder.layers.8.self_attn.q_proj.weight', 'text_model.encoder.layers.8.self_attn.q_proj.bias', 'text_model.encoder.layers.8.self_attn.out_proj.weight', 'text_model.encoder.layers.8.self_attn.out_proj.bias', 'text_model.encoder.layers.8.layer_norm1.weight', 'text_model.encoder.layers.8.layer_norm1.bias', 'text_model.encoder.layers.8.mlp.fc1.weight', 'text_model.encoder.layers.8.mlp.fc1.bias', 'text_model.encoder.layers.8.mlp.fc2.weight', 'text_model.encoder.layers.8.mlp.fc2.bias', 'text_model.encoder.layers.8.layer_norm2.weight', 'text_model.encoder.layers.8.layer_norm2.bias', 'text_model.encoder.layers.9.self_attn.k_proj.weight', 'text_model.encoder.layers.9.self_attn.k_proj.bias', 'text_model.encoder.layers.9.self_attn.v_proj.weight', 'text_model.encoder.layers.9.self_attn.v_proj.bias', 'text_model.encoder.layers.9.self_attn.q_proj.weight', 'text_model.encoder.layers.9.self_attn.q_proj.bias', 'text_model.encoder.layers.9.self_attn.out_proj.weight', 'text_model.encoder.layers.9.self_attn.out_proj.bias', 'text_model.encoder.layers.9.layer_norm1.weight', 'text_model.encoder.layers.9.layer_norm1.bias', 'text_model.encoder.layers.9.mlp.fc1.weight', 'text_model.encoder.layers.9.mlp.fc1.bias', 'text_model.encoder.layers.9.mlp.fc2.weight', 'text_model.encoder.layers.9.mlp.fc2.bias', 'text_model.encoder.layers.9.layer_norm2.weight', 'text_model.encoder.layers.9.layer_norm2.bias', 'text_model.encoder.layers.10.self_attn.k_proj.weight', 'text_model.encoder.layers.10.self_attn.k_proj.bias', 'text_model.encoder.layers.10.self_attn.v_proj.weight', 'text_model.encoder.layers.10.self_attn.v_proj.bias', 'text_model.encoder.layers.10.self_attn.q_proj.weight', 'text_model.encoder.layers.10.self_attn.q_proj.bias', 'text_model.encoder.layers.10.self_attn.out_proj.weight', 'text_model.encoder.layers.10.self_attn.out_proj.bias', 'text_model.encoder.layers.10.layer_norm1.weight', 'text_model.encoder.layers.10.layer_norm1.bias', 'text_model.encoder.layers.10.mlp.fc1.weight', 'text_model.encoder.layers.10.mlp.fc1.bias', 'text_model.encoder.layers.10.mlp.fc2.weight', 'text_model.encoder.layers.10.mlp.fc2.bias', 'text_model.encoder.layers.10.layer_norm2.weight', 'text_model.encoder.layers.10.layer_norm2.bias', 'text_model.encoder.layers.11.self_attn.k_proj.weight', 'text_model.encoder.layers.11.self_attn.k_proj.bias', 'text_model.encoder.layers.11.self_attn.v_proj.weight', 'text_model.encoder.layers.11.self_attn.v_proj.bias', 'text_model.encoder.layers.11.self_attn.q_proj.weight', 'text_model.encoder.layers.11.self_attn.q_proj.bias', 'text_model.encoder.layers.11.self_attn.out_proj.weight', 'text_model.encoder.layers.11.self_attn.out_proj.bias', 'text_model.encoder.layers.11.layer_norm1.weight', 'text_model.encoder.layers.11.layer_norm1.bias', 'text_model.encoder.layers.11.mlp.fc1.weight', 'text_model.encoder.layers.11.mlp.fc1.bias', 'text_model.encoder.layers.11.mlp.fc2.weight', 'text_model.encoder.layers.11.mlp.fc2.bias', 'text_model.encoder.layers.11.layer_norm2.weight', 'text_model.encoder.layers.11.layer_norm2.bias', 'text_model.final_layer_norm.weight', 'text_model.final_layer_norm.bias']
['embeddings.token_embedding.weight', 'embeddings.position_embedding.weight', 'encoder.layers.0.self_attn.k_proj.weight', 'encoder.layers.0.self_attn.k_proj.bias', 'encoder.layers.0.self_attn.v_proj.weight', 'encoder.layers.0.self_attn.v_proj.bias', 'encoder.layers.0.self_attn.q_proj.weight', 'encoder.layers.0.self_attn.q_proj.bias', 'encoder.layers.0.self_attn.out_proj.weight', 'encoder.layers.0.self_attn.out_proj.bias', 'encoder.layers.0.layer_norm1.weight', 'encoder.layers.0.layer_norm1.bias', 'encoder.layers.0.mlp.fc1.weight', 'encoder.layers.0.mlp.fc1.bias', 'encoder.layers.0.mlp.fc2.weight', 'encoder.layers.0.mlp.fc2.bias', 'encoder.layers.0.layer_norm2.weight', 'encoder.layers.0.layer_norm2.bias', 'encoder.layers.1.self_attn.k_proj.weight', 'encoder.layers.1.self_attn.k_proj.bias', 'encoder.layers.1.self_attn.v_proj.weight', 'encoder.layers.1.self_attn.v_proj.bias', 'encoder.layers.1.self_attn.q_proj.weight', 'encoder.layers.1.self_attn.q_proj.bias', 'encoder.layers.1.self_attn.out_proj.weight', 'encoder.layers.1.self_attn.out_proj.bias', 'encoder.layers.1.layer_norm1.weight', 'encoder.layers.1.layer_norm1.bias', 'encoder.layers.1.mlp.fc1.weight', 'encoder.layers.1.mlp.fc1.bias', 'encoder.layers.1.mlp.fc2.weight', 'encoder.layers.1.mlp.fc2.bias', 'encoder.layers.1.layer_norm2.weight', 'encoder.layers.1.layer_norm2.bias', 'encoder.layers.2.self_attn.k_proj.weight', 'encoder.layers.2.self_attn.k_proj.bias', 'encoder.layers.2.self_attn.v_proj.weight', 'encoder.layers.2.self_attn.v_proj.bias', 'encoder.layers.2.self_attn.q_proj.weight', 'encoder.layers.2.self_attn.q_proj.bias', 'encoder.layers.2.self_attn.out_proj.weight', 'encoder.layers.2.self_attn.out_proj.bias', 'encoder.layers.2.layer_norm1.weight', 'encoder.layers.2.layer_norm1.bias', 'encoder.layers.2.mlp.fc1.weight', 'encoder.layers.2.mlp.fc1.bias', 'encoder.layers.2.mlp.fc2.weight', 'encoder.layers.2.mlp.fc2.bias', 'encoder.layers.2.layer_norm2.weight', 'encoder.layers.2.layer_norm2.bias', 'encoder.layers.3.self_attn.k_proj.weight', 'encoder.layers.3.self_attn.k_proj.bias', 'encoder.layers.3.self_attn.v_proj.weight', 'encoder.layers.3.self_attn.v_proj.bias', 'encoder.layers.3.self_attn.q_proj.weight', 'encoder.layers.3.self_attn.q_proj.bias', 'encoder.layers.3.self_attn.out_proj.weight', 'encoder.layers.3.self_attn.out_proj.bias', 'encoder.layers.3.layer_norm1.weight', 'encoder.layers.3.layer_norm1.bias', 'encoder.layers.3.mlp.fc1.weight', 'encoder.layers.3.mlp.fc1.bias', 'encoder.layers.3.mlp.fc2.weight', 'encoder.layers.3.mlp.fc2.bias', 'encoder.layers.3.layer_norm2.weight', 'encoder.layers.3.layer_norm2.bias', 'encoder.layers.4.self_attn.k_proj.weight', 'encoder.layers.4.self_attn.k_proj.bias', 'encoder.layers.4.self_attn.v_proj.weight', 'encoder.layers.4.self_attn.v_proj.bias', 'encoder.layers.4.self_attn.q_proj.weight', 'encoder.layers.4.self_attn.q_proj.bias', 'encoder.layers.4.self_attn.out_proj.weight', 'encoder.layers.4.self_attn.out_proj.bias', 'encoder.layers.4.layer_norm1.weight', 'encoder.layers.4.layer_norm1.bias', 'encoder.layers.4.mlp.fc1.weight', 'encoder.layers.4.mlp.fc1.bias', 'encoder.layers.4.mlp.fc2.weight', 'encoder.layers.4.mlp.fc2.bias', 'encoder.layers.4.layer_norm2.weight', 'encoder.layers.4.layer_norm2.bias', 'encoder.layers.5.self_attn.k_proj.weight', 'encoder.layers.5.self_attn.k_proj.bias', 'encoder.layers.5.self_attn.v_proj.weight', 'encoder.layers.5.self_attn.v_proj.bias', 'encoder.layers.5.self_attn.q_proj.weight', 'encoder.layers.5.self_attn.q_proj.bias', 'encoder.layers.5.self_attn.out_proj.weight', 'encoder.layers.5.self_attn.out_proj.bias', 'encoder.layers.5.layer_norm1.weight', 'encoder.layers.5.layer_norm1.bias', 'encoder.layers.5.mlp.fc1.weight', 'encoder.layers.5.mlp.fc1.bias', 'encoder.layers.5.mlp.fc2.weight', 'encoder.layers.5.mlp.fc2.bias', 'encoder.layers.5.layer_norm2.weight', 'encoder.layers.5.layer_norm2.bias', 'encoder.layers.6.self_attn.k_proj.weight', 'encoder.layers.6.self_attn.k_proj.bias', 'encoder.layers.6.self_attn.v_proj.weight', 'encoder.layers.6.self_attn.v_proj.bias', 'encoder.layers.6.self_attn.q_proj.weight', 'encoder.layers.6.self_attn.q_proj.bias', 'encoder.layers.6.self_attn.out_proj.weight', 'encoder.layers.6.self_attn.out_proj.bias', 'encoder.layers.6.layer_norm1.weight', 'encoder.layers.6.layer_norm1.bias', 'encoder.layers.6.mlp.fc1.weight', 'encoder.layers.6.mlp.fc1.bias', 'encoder.layers.6.mlp.fc2.weight', 'encoder.layers.6.mlp.fc2.bias', 'encoder.layers.6.layer_norm2.weight', 'encoder.layers.6.layer_norm2.bias', 'encoder.layers.7.self_attn.k_proj.weight', 'encoder.layers.7.self_attn.k_proj.bias', 'encoder.layers.7.self_attn.v_proj.weight', 'encoder.layers.7.self_attn.v_proj.bias', 'encoder.layers.7.self_attn.q_proj.weight', 'encoder.layers.7.self_attn.q_proj.bias', 'encoder.layers.7.self_attn.out_proj.weight', 'encoder.layers.7.self_attn.out_proj.bias', 'encoder.layers.7.layer_norm1.weight', 'encoder.layers.7.layer_norm1.bias', 'encoder.layers.7.mlp.fc1.weight', 'encoder.layers.7.mlp.fc1.bias', 'encoder.layers.7.mlp.fc2.weight', 'encoder.layers.7.mlp.fc2.bias', 'encoder.layers.7.layer_norm2.weight', 'encoder.layers.7.layer_norm2.bias', 'encoder.layers.8.self_attn.k_proj.weight', 'encoder.layers.8.self_attn.k_proj.bias', 'encoder.layers.8.self_attn.v_proj.weight', 'encoder.layers.8.self_attn.v_proj.bias', 'encoder.layers.8.self_attn.q_proj.weight', 'encoder.layers.8.self_attn.q_proj.bias', 'encoder.layers.8.self_attn.out_proj.weight', 'encoder.layers.8.self_attn.out_proj.bias', 'encoder.layers.8.layer_norm1.weight', 'encoder.layers.8.layer_norm1.bias', 'encoder.layers.8.mlp.fc1.weight', 'encoder.layers.8.mlp.fc1.bias', 'encoder.layers.8.mlp.fc2.weight', 'encoder.layers.8.mlp.fc2.bias', 'encoder.layers.8.layer_norm2.weight', 'encoder.layers.8.layer_norm2.bias', 'encoder.layers.9.self_attn.k_proj.weight', 'encoder.layers.9.self_attn.k_proj.bias', 'encoder.layers.9.self_attn.v_proj.weight', 'encoder.layers.9.self_attn.v_proj.bias', 'encoder.layers.9.self_attn.q_proj.weight', 'encoder.layers.9.self_attn.q_proj.bias', 'encoder.layers.9.self_attn.out_proj.weight', 'encoder.layers.9.self_attn.out_proj.bias', 'encoder.layers.9.layer_norm1.weight', 'encoder.layers.9.layer_norm1.bias', 'encoder.layers.9.mlp.fc1.weight', 'encoder.layers.9.mlp.fc1.bias', 'encoder.layers.9.mlp.fc2.weight', 'encoder.layers.9.mlp.fc2.bias', 'encoder.layers.9.layer_norm2.weight', 'encoder.layers.9.layer_norm2.bias', 'encoder.layers.10.self_attn.k_proj.weight', 'encoder.layers.10.self_attn.k_proj.bias', 'encoder.layers.10.self_attn.v_proj.weight', 'encoder.layers.10.self_attn.v_proj.bias', 'encoder.layers.10.self_attn.q_proj.weight', 'encoder.layers.10.self_attn.q_proj.bias', 'encoder.layers.10.self_attn.out_proj.weight', 'encoder.layers.10.self_attn.out_proj.bias', 'encoder.layers.10.layer_norm1.weight', 'encoder.layers.10.layer_norm1.bias', 'encoder.layers.10.mlp.fc1.weight', 'encoder.layers.10.mlp.fc1.bias', 'encoder.layers.10.mlp.fc2.weight', 'encoder.layers.10.mlp.fc2.bias', 'encoder.layers.10.layer_norm2.weight', 'encoder.layers.10.layer_norm2.bias', 'encoder.layers.11.self_attn.k_proj.weight', 'encoder.layers.11.self_attn.k_proj.bias', 'encoder.layers.11.self_attn.v_proj.weight', 'encoder.layers.11.self_attn.v_proj.bias', 'encoder.layers.11.self_attn.q_proj.weight', 'encoder.layers.11.self_attn.q_proj.bias', 'encoder.layers.11.self_attn.out_proj.weight', 'encoder.layers.11.self_attn.out_proj.bias', 'encoder.layers.11.layer_norm1.weight', 'encoder.layers.11.layer_norm1.bias', 'encoder.layers.11.mlp.fc1.weight', 'encoder.layers.11.mlp.fc1.bias', 'encoder.layers.11.mlp.fc2.weight', 'encoder.layers.11.mlp.fc2.bias', 'encoder.layers.11.layer_norm2.weight', 'encoder.layers.11.layer_norm2.bias', 'final_layer_norm.weight', 'final_layer_norm.bias']
-- is same name ?
True
係数サイズも名前も一緒なので(ちょっといじる必要ありますが)、マージできるということになります。
やり方はこれまでと同じです。
def merge_clip_parameters(pipe_source, openai_clip, ratio):
merge_net = copy.deepcopy(getattr(pipe_source, "text_encoder"))
pipe_source_params = dict(getattr(pipe_source, "text_encoder").named_parameters())
pipe_merge_params = dict(getattr(openai_clip, "text_model").named_parameters())
for key, param in merge_net.named_parameters():
x = pipe_source_params[key] * (1-ratio) + pipe_merge_params[key.replace("text_model.", "")] * ratio
param.data = x
return merge_net
def merge_network(pipe_source, pipe_merge, attr, ratio):
merge_net = copy.deepcopy(getattr(pipe_source, attr))
pipe_source_params = dict(getattr(pipe_source, attr).named_parameters())
pipe_merge_params = dict(getattr(pipe_merge, attr).named_parameters())
for key, param in merge_net.named_parameters():
x = pipe_source_params[key] * (1-ratio) + pipe_merge_params[key] * ratio
param.data = x
return merge_net
def merge_clip():
merge_pairs = {
"journey-journey": ["prompthero/openjourney-v4", "prompthero/openjourney-v4"],
"journey-counter": ["prompthero/openjourney-v4", "gsdf/Counterfeit-V2.0"],
"journey-sd": ["prompthero/openjourney-v4", "runwayml/stable-diffusion-v1-5"],
"counter-counter": ["gsdf/Counterfeit-V2.0", "gsdf/Counterfeit-V2.0"],
"counter-sd": ["gsdf/Counterfeit-V2.0", "runwayml/stable-diffusion-v1-5"],
"sd-sd": ["runwayml/stable-diffusion-v1-5", "runwayml/stable-diffusion-v1-5"],
}
device = "cuda:1"
sd_pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True)
openai_clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16)
for i, (merge_key, merge_pair) in enumerate(merge_pairs.items()):
pipe1 = StableDiffusionPipeline.from_pretrained(merge_pair[0], torch_dtype=torch.float16, use_safetensors=True)
pipe1.scheduler = UniPCMultistepScheduler.from_config(pipe1.scheduler.config, use_safetensors=True)
pipe2 = StableDiffusionPipeline.from_pretrained(merge_pair[1], torch_dtype=torch.float16, use_safetensors=True)
pipe2.scheduler = UniPCMultistepScheduler.from_config(pipe2.scheduler.config, use_safetensors=True)
images = []
for clip_ratio in [0, 0.33, 0.66, 1]:
merge_clip = merge_clip_parameters(pipe1, openai_clip, clip_ratio)
merge_unet = merge_network(pipe1, pipe2, "unet", 0.5)
merge_pipe = StableDiffusionPipeline(
vae=sd_pipe.vae,
text_encoder=merge_clip,
tokenizer=pipe1.tokenizer,
unet=merge_unet,
scheduler=UniPCMultistepScheduler.from_config(pipe1.scheduler.config),
safety_checker=lambda images, **kwargs: (images, False),
feature_extractor=pipe1.feature_extractor
)
generator = torch.Generator(device).manual_seed(1234)
prompt = "1girl++++, solo, traveler, rule of thirds, teens, dress++, silver hair, side view, face in profile, praying, anime face, stand on a high hill, "\
"looking down into distance, pasture, flower garden, small town, sunshine, clear sky, flowing clouds, masterpiece, best quality, detailed face"
negative_prompt = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, multiple girls, multiple people"
merge_pipe.to(device)
compel_proc = Compel(tokenizer=merge_pipe.tokenizer, text_encoder=merge_pipe.text_encoder)
conditioning = compel_proc(prompt)
neg_conditioning = compel_proc(negative_prompt)
image = merge_pipe(prompt_embeds=conditioning, negative_prompt_embeds=neg_conditioning, generator=generator, output_type="numpy", num_images_per_prompt=1, width=512, height=768).images
images.append(image)
images = np.concatenate(images)
images = torch.from_numpy(images).permute(0, 3, 1, 2)
torchvision.utils.save_image(images, f"merge_{i+20}_{merge_key}.jpg", quality=92, nrow=4, normalize=False)
ここで行う実験は以下のとおりです。
横軸はCLIPのマージ比率で、6パターン実験します。各パターンはU-Netのマージ比率の違いです。
OpenJourney自体がText EncoderをFine-tuningしているとしか考えられません。ただ、CLIPが0か1にしてしまうと背景が似たような感じになるのはちょっと興味深いです。
SD全く関係ないモデルでも、SDが使っているOpenAI CLIPをマージしてうまくいくのが面白いです
これはCLIP関係なく、SD1.5とマージしたときと同じ結果になるはずです。CLIPのマージ比率を0.66まで上げてもアニメ調が維持されるの面白いですね。
Counterfeit単体だとそこまでCLIPの影響はないように見えます。
しかし、U-Netにマージが入ると、異なる結果になりました。おそらくU-Netの係数がマージで曖昧になったので、Text Embeddingの影響が大きくなったのではないかと思います。CLIPをOpenAIのものに置き換えてもあまり悪影響はないので、置き換えてしまうのも一つの手かと思います。
SD1.5のCLIPとOpenAI CLIPは係数が同じなので、影響がないはずです。それが実証されました。
この実験からわかることは以下の通りです。
ただあくまでこれはSD1.5までの話で、SD2系は使っているCLIPが変わるので、また別途検証する必要がありそうです。
また、スタイルを限定するために特定のワードを入れるモデルでは検証していないので、CLIPを入れ替えると大きく結果が異なる可能性があります。特定ワードの部分がCLIPで吸収しているのか、U-NetのAttentionで吸収しているのかよくわかりませんが、CLIPを入れ替えるとスタイルの限定がうまくいかなかくなるということも十分考えられます。
逆の見方をすれば、各Fine-tuningモデルのCLIPの部分だけとってきて、プロンプトなり単語なり入れて、どのように変わったのか定量化するというこもと可能なはずです。
View Comments