こしあん
2023-05-05

DiffusersでAttention Coupleを実装して領域別プロンプトを適用


Pocket
LINEで送る
Delicious にシェア

2.7k{icon} {views}


Stable DiffusionでAttention Couple(Attentionレイヤーをハックして領域別にプロンプトを適用する手法)をスマートな方法で実装できないかなと思ってやってみました。LoRAの実装を参考にして、AttnProcessorを書き換えるとスマートで、Attentionを自在にハッキングできるため将来性ありそうです。

Attention Coupleとは

※この記事は、相当機械学習詳しい人向けなので、WebUIをちょこっと触ってみたぐらいの99%はポカーンとなると思います。GUIがいい方は適当にWebUIの拡張機能でも探してください

領域別にプロンプトをかけるやり方の通称です。Latent Coupleの発展形です。

  • Latent Couple:Guidance Scaleの適用時にノイズの推定値を領域別にマスクをかけてかける
  • Attention Couple:これをAttentionレイヤーの内部で行い、Cross Attention適用後の特徴マップを領域別マスクをかける

詳しくはこちらを読んで下さい。こちらにLatent Coupleとの両方の実装が載っています。Latent Coupleの場合は、Stable DiffusionのPipeline内の

                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    n = noise_pred_text.shape[0] // 2
                    diff = noise_pred_text - noise_pred_uncond
                    diff[:n] = diff[:n] * spatial_mask + diff[n:] * (1-spatial_mask)
                    noise_pred = noise_pred_uncond + guidance_scale * diff

のようにすればいいので簡単です。ただ、Attention Coupleの場合は、Attentionレイヤーの中身をハックしないといけないのでちょっと大変です。

参考記事

この記事や実装が非常に参考になりました。

ただこの実装のままいくと、Diffusersのバージョンが上がってうまくいかなかったので今のバージョン(0.16.1)に対応可能な実装を目指しました。

また、Attentionの内部の値を取り出す方法については以下のツイートを参考にしました。

また、本実装はPyTorch 2.0.0以上を対象にしています。

Attentionをハックする

DiffusersでAttentionをハックするのは正直にやると大変です。なぜかというと、U-Netの中にあるAttentionモジュールを書き換えないといけないためです。U-NetとAttentionの両方を書き換えないといけなく、実装量が多いです。

そこで、Attentionの中だけアドホックで書き換える方法を考えます。ここで、LoRAの実装を参考にします。DiffusersでLoRAを適用する場合は、

pipe.unet.load_attn_procs(model_path)

公式ドキュメントより。LoRAはU-NetのAttentionの部分をハックする実装を行っています。つまり、LoRAの実装を参考にして、Attention Coupleを実装すればいいということになります。

DiffusersのU-Netのソースを解読していると、基底クラスの「UNet2DConditionLoadersMixin」では、LoRAのようなAttentionのモンキーパッチを受け付けています。

https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders.py#L96

load_attn_procsがLoRAとCustomDiffusionという2モデルにしか現在対応しておらず、Attention CoupleのようなアドホックなAttentionを適用することができません。

そこで、この実装を参考にしつつ独自のAttnProcessorを定義します。

Attention Processorとは

LoRA関係なくStable DiffusionのPipeline内にあるAttentionの処理の定義箇所です。例えば、普通のSDでも取り出すことができ、

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16)
print(pipe.unet.attn_processors)

この結果は、以下のようになります。結果はDictになります。

{'down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor': <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x000001A338F4F430>, 'down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor': <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x000001A338F4E770>, 'down_blocks.0.attentions.1.transformer_blocks.0.attn1.processor': <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x000001A338F2F430>, 'down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor': <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x000001A338F2E050>, 'down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor': <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x000001A338F2D930>, 'down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x000001A338F2CC40>, 'down_blocks.1.attentions.1.transformer_blocks.0.attn1.processor': <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x000001A339034460>, 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x000001A339037220>, 'down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor': <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x000001A338FF6F80>, 'down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor': <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x000001A338FF5930>, 'down_blocks.2.attentions.1.transformer_blocks.0.attn1.processor': <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x000001A338FF5C90>, 'down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor': <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x000001A338FF4F40>, 'up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor': <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x000001A3390AE4D0>, 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x000001A3390AEC20>, 'up_blocks.1.attentions.1.transformer_blocks.0.attn1.processor': <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x000001A338F983D0>, 'up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x000001A338F98B80>, 'up_blocks.1.attentions.2.transformer_blocks.0.attn1.processor': <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x000001A338F9ABF0>, 'up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor': <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x000001A338F9B2E0>, 'up_blocks.2.attentions.0.transformer_blocks.0.attn1.processor': <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x000001A33A5E93C0>, 'up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor': <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x000001A33A5E9A20>, 'up_blocks.2.attentions.1.transformer_blocks.0.attn1.processor': <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x000001A33A5EB3A0>, 'up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor': <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x000001A33A5EBA00>, 'up_blocks.2.attentions.2.transformer_blocks.0.attn1.processor': <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x000001A33A6193C0>, 'up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor': <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x000001A33A619A20>, 'up_blocks.3.attentions.0.transformer_blocks.0.attn1.processor': <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x000001A33A61B6A0>, 'up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor': <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x000001A33A61BAF0>, 'up_blocks.3.attentions.1.transformer_blocks.0.attn1.processor': <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x000001A33A654100>, 'up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor': <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x000001A33A654790>, 'up_blocks.3.attentions.2.transformer_blocks.0.attn1.processor': <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x000001A33A654310>, 'up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor': <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x000001A33A654FA0>, 'mid_block.attentions.0.transformer_blocks.0.attn1.processor': <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x000001A338F76AA0>, 'mid_block.attentions.0.transformer_blocks.0.attn2.processor': <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x000001A338F771F0>}

AttnProcessor2_0というのは、PyTorch 2の場合のAttention Processorで、Ver1の場合は別のAttention Processorが表示されていると思います。Attentionハッキングしたい場合は、この全てが書き換え対象なります。

例えば、適当なLoRAを読み込んでみると、

pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16)
pipe.unet.load_attn_procs("pytorch_lora_weights.bin", use_safetensors=False) # LoRAの読み込み
print(pipe.unet.attn_processors)
{'down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor': LoRAAttnProcessor(
  (to_q_lora): LoRALinearLayer(
    (down): Linear(in_features=320, out_features=4, bias=False)
    (up): Linear(in_features=4, out_features=320, bias=False)
  )
  (to_k_lora): LoRALinearLayer(
    (down): Linear(in_features=320, out_features=4, bias=False)
    (up): Linear(in_features=4, out_features=320, bias=False)
  )
  (to_v_lora): LoRALinearLayer(
    (down): Linear(in_features=320, out_features=4, bias=False)
    (up): Linear(in_features=4, out_features=320, bias=False)
  )
  (to_out_lora): LoRALinearLayer(
    (down): Linear(in_features=320, out_features=4, bias=False)
    (up): Linear(in_features=4, out_features=320, bias=False)
  )
), 'down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor': LoRAAttnProcessor(
  (to_q_lora): LoRALinearLayer(
    (down): Linear(in_features=320, out_features=4, bias=False)
    (up): Linear(in_features=4, out_features=320, bias=False)
  )
  (to_k_lora): LoRALinearLayer(
    (down): Linear(in_features=768, out_features=4, bias=False)
    (up): Linear(in_features=4, out_features=320, bias=False)
  )
  (to_v_lora): LoRALinearLayer(
    (down): Linear(in_features=768, out_features=4, bias=False)
    (up): Linear(in_features=4, out_features=320, bias=False)
  )
  (to_out_lora): LoRALinearLayer(
    (down): Linear(in_features=320, out_features=4, bias=False)
    (up): Linear(in_features=4, out_features=320, bias=False)
  )
), '
(以下略)

AttnProcessor2_0はLoRAAttnProcessorという専用のAttention Processorに変化しています。これをやりたいのです。

Attention Processorの抽出

unet.attn_processorsの実装を参考にします。

https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py#L482

    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "set_processor"):
                processors[f"{name}.processor"] = module.processor

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

このように再帰的にモジュールを抽出してDictに格納しています。格納する場合はこれをSet方向にやればいいですね。set_attn_processorという関数もあるので、うまく使えばこちらでもできるかもしれません。

実装

from diffusers import StableDiffusionPipeline, UniPCMultistepScheduler
import torch
from diffusers.models.attention_processor import AttnProcessor2_0, Attention
import torch.nn.functional as F
from torch import FloatTensor
import numpy as np

class AttentionCoupleProcessor(AttnProcessor2_0):
    def __init__(self, width, height, twoshot_weight=0.8):
        super().__init__()
        self.orig_height = height
        self.orig_width = width
        self.twoshot_weight = twoshot_weight

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        inner_dim = hidden_states.shape[-1]
        is_cross_attn = encoder_hidden_states is not None

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)
        if is_cross_attn:
            # copy query
            query_uncond, query_cond = query.chunk(2)
            query = torch.cat([query_uncond, query_cond, query_cond, query_cond], dim=0)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        head_dim = inner_dim // attn.heads
        query = query.view(query.shape[0], -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(key.shape[0], -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(value.shape[0], -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(hidden_states.shape[0], -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        ## cond * (1-w) + [left right] * w
        if is_cross_attn:
            rate = int((self.orig_height * self.orig_width // hidden_states.shape[1]) ** 0.5)
            height = self.orig_height // rate
            width = self.orig_width // rate            

            # cond, left, right, uncond = hidden_states.chunk(4)
            uncond, cond, left, right = hidden_states.chunk(4)
            # Convert to picture space
            left = left.reshape(left.shape[0], height, width, left.shape[2]) # (B, H, W, C)
            right = right.reshape(right.shape[0], height, width, right.shape[2]) # (B, H, W, C)
            # Couple : left right
            couple = torch.cat([ left[:,:,:width//2,:], right[:,:,width//2:,:]], dim=2)
            couple = couple.reshape(cond.shape[0], -1, cond.shape[2])

            cond = cond * (1-self.twoshot_weight) + couple * self.twoshot_weight
            hidden_states = torch.cat([uncond, cond])

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states

class AttentionCoupleDiffusionPipeline(StableDiffusionPipeline):
    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, 
                       negative_prompt=None, prompt_embeds: FloatTensor | None = None, negative_prompt_embeds: FloatTensor | None = None):
        negative_prompt = [negative_prompt for i in range(len(prompt))]
        prompt_embeds = super()._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds, negative_prompt_embeds)
        uncond_prompt, cond_prompt = prompt_embeds.chunk(2) # 1+3
        n = uncond_prompt.shape[0] // 3
        prompt_embeds = torch.cat([uncond_prompt[:n], cond_prompt], dim=0)
        return prompt_embeds

    def _hack_attention_processor(self, name, module, processors, width, height):
        if hasattr(module, "set_processor"):
            processors[f"{name}.processor"] = module.processor
            module.set_processor(AttentionCoupleProcessor(width, height))

        for sub_name, child in module.named_children():
            self._hack_attention_processor(f"{name}.{sub_name}", child, processors, width, height)

        return processors

    def enable_attention_couple(self, width, height):
        processors = {}
        for name, module in self.unet.named_children():
            self._hack_attention_processor(name, module, processors, width, height)

なるべく追加コード量をへらすようにしました。AttnProcessor自体は特に学習可能なWeightを持っていないので、扱い方はかなり楽です。純粋に処理だけを書けばいいです。

私はAttnProcessor2_0の実装をコピペしてきて、こちらのNotebookの実装を追加しました。PyTorch Ver1の場合は、対応するAttention Processorを使ってください。

注意点なのですが、このAttention Processorはモデル内でSelf AttentionとCross Attentionの両方で呼ばれるので、Cross Attentionのときのみ適用してあげるように改造する必要があります。

Cross Attentionの場合は、プロンプトのText EmbeddingがKey, Valueに格納されますが、Self Attentionの場合は、Inputの特徴量がKey, Value格納されます(Self AttentionとCross Attentionの定義から)。Cross Attentionかどうかの判定は、単にencoder_hidden_statesがNoneでないかで判定すればいいです。

Attention Coupleの場合は、Attention内で座標計算する必要があるので、解像度等の情報はインスタンス変数に持たせるといいかもしれないですね。DownブロックやUpブロックがあるので、呼び出し時の特徴量の解像度は動的になります。

また、パイプライン側はStable Diffusionのパイプラインを継承して書きました。encode_promptの部分を継承してごちゃごちゃ!とすれば、Prompt=3、Negative Prompt=1というな状況にも対応できます。

プロンプトは、「全体、左側、右側」の順で入れます。

実行

def main(prompt,
         negative_prompt,
         device="cuda:1",
         width=640, height=896):
    pipe = AttentionCoupleDiffusionPipeline.from_pretrained("hakurei/waifu-diffusion", torch_dtype=torch.float16)
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    # Disabling safety checker
    if pipe.safety_checker is not None:
        pipe.safety_checker = lambda images, **kwargs: (images, False)
    # AttentionCoupleの有効化
    pipe.enable_attention_couple(width, height)
    pipe.to(device)

    # Latentで指定
    generator = torch.Generator(device)
    generator.manual_seed(4545)
    latents = torch.randn((4, 4, height//8, width//8), generator=generator, device=device, dtype=torch.float16)

    image = pipe(prompt, negative_prompt=negative_prompt, latents=latents,
                    num_inference_steps=50, guidance_scale=7.5, num_images_per_prompt=4, output_type="numpy").images

    image = np.concatenate(image, axis=0).reshape(2, 2, height, width, 3).swapaxes(1, 2).reshape(2*height, 2*width, 3)
    with Image.fromarray((image*255.0).astype(np.uint8)) as img:
        img.save("result.png")


if __name__ == "__main__":
    prompt = [
        "masterpiece, best quality, 2girls",
        "masterpiece, best quality, 2girls, black hair, red eyes, school uniform, blue sailor collar, blue skirt",
        "masterpiece, best quality, 2girls, white hair, blue eyes, maid, maid headdress, skirt",
    ]
    negative_prompt = "worst quality, low quality, medium quality, deleted, lowres, comic, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, jpeg artifacts, signature, watermark, username, blurry"    
    main(prompt, negative_prompt)

参考記事の例をそのままお借りしました。

Latentの乱数を直接指定してあげると、プロンプトの数の整合性を取るのがちょっと楽です。

結果

Attentionハッキング、やり方がわかれば楽しいですね。今は左右のプロンプトのみですが、任意の数のレイヤーに拡張できるので、もっと凝ったことができるかもしれません。

ぱっと思いつくのは、Multi Diffusionのようなパノラマ生成をAttentionベースで行うものです(多分普通にMulti Diffusionよりいいと思います)。マスクを用意すればいいだけなのでこの応用ですね。

マージモデルでもいける

試しにマージモデルでも試してみましたが、なかなかいい感じになりました。

import copy

def merge_network(pipe_source, pipe_merge, attr, ratio):
    merge_net = copy.deepcopy(getattr(pipe_source, attr))
    pipe_source_params = dict(getattr(pipe_source, attr).named_parameters())
    pipe_merge_params = dict(getattr(pipe_merge, attr).named_parameters())
    for key, param in merge_net.named_parameters():
        x = pipe_source_params[key] * (1-ratio) + pipe_merge_params[key] * ratio
        param.data = x
    return merge_net

def main2(width=1920//2, height=1024//2):
    device = "cuda:1"    
    sd_pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
    pipe_pastel = StableDiffusionPipeline.from_pretrained("andite/pastel-mix",
                                                           torch_dtype=torch.float16)
    pipe = AttentionCoupleDiffusionPipeline.from_pretrained(
        "prompthero/openjourney-v4", torch_dtype=torch.float16)
    pipe.unet = merge_network(pipe, pipe_pastel, "unet", 0.75)
    pipe.text_encoder = sd_pipe.text_encoder
    pipe.vae = sd_pipe.vae
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.safety_checker = lambda images, **kwargs: (images, False)
    pipe.enable_vae_tiling()
    pipe.enable_attention_couple(width, height)
    pipe.to(device)

    generator = torch.Generator(device)
    generator.manual_seed(4545)
    latents = torch.randn((1, 4, height//8, width//8), generator=generator, device=device, dtype=torch.float16)

    prompt = [
        "a witch girl fighting a red dragon in galaxy space",
        "a red dragon in galaxy space",
        "1girl, cute girl, a witch of 12 years old battle in galaxy space, blond hair"
    ]
    for i in range(len(prompt)):
        prompt[i] += ", best quality, extremely detailed"
    negative_prompt = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, " \
                    "fewer digits, cropped, worst quality, low quality"

    image = pipe(prompt, negative_prompt=negative_prompt, latents=latents,
                    num_inference_steps=50, guidance_scale=7.5, num_images_per_prompt=1).images
    image[0].save("result2.png")


if __name__ == "__main__":
    main2()

色リークもおきてなくてなかなか楽しい手法ですね。

Pocket
Delicious にシェア



Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内

技術書コーナー

北海道の駅巡りコーナー


Add a Comment

メールアドレスが公開されることはありません。 が付いている欄は必須項目です