こしあん
2021-02-24

ごちうさで始める線画の自動着色(2)~TFRecordの作成~


1.6k{icon} {views}

KaggleにあったGochiUsa_Facesデータセットを使って、ごちうさキャラの線画の自動着色で遊んでみました。この投稿では訓練の前段階としてTFRecordを作成していきます。

はじめに

前回のポストで、GochiUsa_FacesデータセットのEDAをしましたが、今回はモデル訓練のためのTFRecord作成をやっていきます。

ごちうさで始める線画の自動着色(1)~データセットのEDA~
https://blog.shikoan.com/gochiusa-01/

  • 今回のやること:TFRecordの作成
  • 次回以降やること:モデルの作成、訓練

なぜTFRecordなのか

今回の目的がTFRecordを使って実践的な訓練をしてみたかったという点があります(これが大きいです)。TensorFlowの推奨データ形式がTFRecordという点もあります。諸事情により、TFRecordを使わないで訓練するという選択肢もありですが、何かと避けがちなTFRecordに一度しっかり向き合ってみたかったのです。

作りたいモデル

線画を入力データとし、着色した画像を出力とするようなモデルを作ります。線画はデータセットに付属していませんが、カラー画像から人工的に作ります。「画像がどのキャラか」というラベル情報も使います。

今回TFRecordに記録する情報は次の3つです。

  1. もとのカラー画像
  2. 線画(1から人工的に作る)
  3. ラベル

Canny法によるエッジ検出

問題となるのは「どのようにカラー画像から線画を作るか」です。線画を作るとはエッジ検出の問題と捉えられるので、Canny法によるエッジ検出を使います。これはOpenCVに関数があります。

Canny法によるエッジ検出は次のようにします。Canny法には2つのスレッショルドがあり、ヒステリシスが存在する処理の最小値、最大値を示します(直感的な理解は後述)。

edge = cv2.Canny(img, 50, 200) # スレッショルドが2つある
Python

GochiUsa_Facesデータセット内の適当な画像をCanny法でエッジ検出し、元画像と並べて表示するコードは次の通りです。

def canny_simple():
    # ランダムに画像を1個選ぶ
    np.random.seed(21)
    files = sorted(glob.glob("archive/main_dataset/main_dataset/Chino/*"))
    np.random.shuffle(files)

    # オリジナル画像
    img = cv2.imread(files[0])
    # エッジ画像
    edge = cv2.Canny(img, 50, 200) # スレッショルドが2つある

    # 表示
    fig = plt.figure(figsize=(10, 5))
    ax = fig.add_subplot(1, 2, 1)
    ax.imshow(img[...,::-1]) # BGR->RGB
    ax = fig.add_subplot(1, 2, 2)
    ax.imshow(edge, cmap="gray", vmin=0, vmax=255)

    plt.show()
Python

結果は次の通り。このようにエッジ検出できています。

スレッショルドの直感的な理解

Canny法の2つのスレッショルドを変更したらどうなるでしょうか。

def canny_threshold():
    # ランダムに画像を1個選ぶ
    np.random.seed(100)
    files = sorted(glob.glob("archive/main_dataset/main_dataset/Chino/*"))
    np.random.shuffle(files)

    # オリジナル画像
    img = cv2.imread(files[0])
    # エッジ画像
    thresholds = [50, 100, 150, 200, 400]

    fig = plt.figure(figsize=(12, 12))
    for i in range(5):
        for j in range(5):
            x = None
            if j > i:
                x = cv2.Canny(img, thresholds[i], thresholds[j])  # スレッショルドが2つある
            elif i == 3 and j == 1:
                x = img[...,::-1]  # 元画像

            if x is not None:
                ax = fig.add_subplot(4, 4, 4 * i + j)
                if x.ndim == 2:
                    ax.imshow(x, cmap="gray", vmin=0, vmax=255)
                    ax.set_title(f"{thresholds[i], thresholds[j]}")
                elif x.ndim == 3:
                    ax.imshow(x)
                ax.axis("off")
    plt.show()
Python

スレッショルドを「50、100、150、200、400」と変更してみました。これは実質的にハイパーパラメータとなるので、頑張るのだったらチューニングしてもいいと思います。

直感的な理解としては、第1と第2の値差が小さいほど細かい線を拾う傾向があります。ただ(50, 100)のケースだと背景の線まで拾うので、どこまで拾うのがよいのかは難しいところです。突っ込んでやるんだったら線画の抽出もNNベースになるのでしょうね。

この値が正しいかどうかはわかりませんが、今回は(50, 400)で固定してみました。

TFRecordを記録する

GochiUsa_Facesは画像解像度が無数にありますが、ほとんど正方形画像なので、固定サイズにリサイズしてから記録します。320×320にリサイズして、Data Augmentationをはさみながら256×256でNNに食わせます。短辺の解像度が320未満のサンプルは捨てます。

解像度さえ決まれば簡単ですね。サンプル単位で次のようなスキームにします。

  • カラー画像:uint8型 (320, 320, 3)
  • エッジ画像:uint8型 (320, 320, 1)
  • ラベル:float32型 (9, ) ←クラス数が9だから

ラベルは後々楽したいので、One-hotエンコーディングして記録します。TFRecordの作成コードは次の通り。

import tensorflow as tf
import numpy as np
import cv2
import glob

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

CHARACTERS = ["Blue Mountain", "Chino", "Chiya", "Cocoa", "Maya", "Megumi", "Mocha", "Rize", "Sharo"]

def serialize_sample(image, edge, class_idx):
    label_binary = np.eye(len(CHARACTERS)).astype(np.float32)[class_idx].tobytes()
    image_list = _bytes_feature(image.tobytes())
    edge_list = _bytes_feature(edge.tobytes())
    label_list = _bytes_feature(label_binary)
    proto = tf.train.Example(features=tf.train.Features(feature={
        "image": image_list,  # uint8, (320, 320, 3)
        "edge": edge_list, # uint8, (320, 320, 1)
        "label": label_list # float32, (9, )
    }))
    return proto.SerializeToString()

def write_record(root_dir, min_image_size):
    dirs = sorted(glob.glob(root_dir + "/*"))
    record_name = root_dir.replace("\\", "/").split("/")[-1]
    paths, classes = [], []
    for d in dirs:
        files = sorted(glob.glob(d + "/*"))
        base_dir = d.replace("\\", "/").split("/")[-1]
        class_idx = CHARACTERS.index(base_dir)
        if class_idx < 0:
            raise IndexError(f"class_idx < 0. class name = {base_dir}")

        for f in files:
            img = cv2.imread(f)
            if min(img.shape[:2]) < min_image_size:
                continue
            paths.append(f)
            classes.append(class_idx)

    np.random.seed(123)
    indices = np.random.permutation(len(paths))

    with tf.io.TFRecordWriter(record_name + ".tfrecord") as writer:
        for i in indices:
            img = cv2.imread(paths[i])
            img = cv2.resize(img, (min_image_size, min_image_size), interpolation=cv2.INTER_LANCZOS4)
            edge = cv2.Canny(img, 50, 400)[:,:,None]
            example = serialize_sample(img[:,:,::-1], edge, classes[i])
            writer.write(example)
Python

やり方はこの記事で行った通りです。NumPy配列を.tobytes()して記録しています。

途中でシャッフルを入れているのがポイントかもしれません。tf.data.Datasetでshuffleすると、ある局所的な区間のみシャッフルします。チノちゃんの画像が1万枚、シャロの画像が7000枚のように大量に連続していると、訓練中にシャッフルしてもキャラごとに団子になってしまうので、TFRecordの作成時にシャッフルを入れました。

write_recordを「write_record(“archive/main_dataset/main_dataset”, 320)」のようにデータセットのディレクトリに対して適用します。

TFRecordの可視化

TFRecordを可視化してみましょう。

import matplotlib.pyplot as plt

def deserialize_example(serialized_string):
    image_feature_description = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'edge': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.string),
    }
    example = tf.io.parse_single_example(serialized_string, image_feature_description)
    image = tf.reshape(tf.io.decode_raw(example["image"], tf.uint8), (320, 320, 3))
    edge = tf.reshape(tf.io.decode_raw(example["edge"], tf.uint8), (320, 320, 1))
    label = tf.io.decode_raw(example["label"], tf.float32)
    return image, edge, label

def read_record():
    dataset = tf.data.TFRecordDataset("main_dataset.tfrecord").map(deserialize_example).batch(4)
    for x in dataset:
        for i in range(4):
            fig = plt.figure(figsize=(10, 5))
            ax = fig.add_subplot(1, 2, 1)
            ax.imshow(x[0][i])
            ax = fig.add_subplot(1, 2, 2)
            ax.imshow(x[1][i], cmap="gray", vmin=0, vmax=255)
            ax.set_title(CHARACTERS[np.argmax(x[2][i])])
            plt.show()
Python

このようにしてTFRecordDatasetで使える形式になりました。

予告

次回の投稿では、モデルの設計や訓練をしていきます。

次回はこちら:https:blog.shikoan.com/gochiusa-03



Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内

技術書コーナー

  
Terraformで学ぶAWS(1):サーバーレスから始める再利用可能なインフラストラクチャ
  
AIアートの新時代2:Stable Diffusionの課題と動画生成の新潮流
  
コーディング侍:Pythonで学ぶ機械学習ソフトウェア開発の極意
  
AIアートの新時代:CLIPとStable Diffusionを活用した画像生成技術とその応用

Add a Comment

メールアドレスが公開されることはありません。 が付いている欄は必須項目です