こしあん
2022-06-24

YOLOX+ByteTrackでお手軽トラッキングやってみた


7.4k{icon} {views}

YOLOX+ByteTrackでトラッキングをやってみました。「とりあえずサクッと訓練済みモデルでトラッキングしたい」というときに使える手法です。

背景

トラッキング、需要ある割に個人的に食わず嫌いしていて、「ちょっと動かしておかないとまずいよなー」と思ってサクッと訓練済みモデルを動かしてみました。

参考

高橋かずひとさんのブログとリポジトリがとても便利なので、このサンプルコードをアレンジして動かしてみました。

https://kazuhito00.hatenablog.com/entry/2022/02/19/001641
https://github.com/Kazuhito00/yolox-bytetrack-mcmot-sample

想定環境

Google ColabのCPUインスタンスを想定します。

判定対象動画

以下の動画をトラッキングします。フリー素材からとってきたものです。

セットアップ

前提ライブラリをインストールします

!pip install onnxruntime lap cython_bbox

リポジトリをクローンします

!git clone https://github.com/Kazuhito00/yolox-bytetrack-mcmot-sample
%cd yolox-bytetrack-mcmot-sample

これはどちらでもいいですが、デフォルトのモデルがYOLOX-Tinyで、今回の用途としては軽すぎるので、YOLOX-Sを使います。YOLOXの公式リポジトリからダウンロードします。もともとこのリポジトリも同じ場所からダウンロードしているらしく、任意のサイズのYOLOXを利用できます。

!wget https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_s.onnx -P model

サンプルプログラムの改変

コードのsample.pyですが、ウィンドウ操作が入っておりColabで動かすとエラーになってしまうため、動画を書き出す方式に書き換えます。また、どの程度の精度で見ているのか詳しく見たかったので、トラッキングID単位で動画を切り出してみました。

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import copy
import time
import argparse

import cv2
import numpy as np
import os

from yolox.yolox_onnx import YoloxONNX
from bytetrack.mc_bytetrack import MultiClassByteTrack


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--device", type=int, default=0)
    parser.add_argument("--movie", type=str, default=None)
    parser.add_argument("--width", help='cap width', type=int, default=960)
    parser.add_argument("--height", help='cap height', type=int, default=540)

    # YOLOX parameters
    parser.add_argument(
        "--yolox_model",
        type=str,
        default='model/yolox_nano.onnx',
    )
    parser.add_argument(
        '--input_shape',
        type=str,
        default="416,416",
        help="Specify an input shape for inference.",
    )
    parser.add_argument(
        '--score_th',
        type=float,
        default=0.3,
        help='Class confidence',
    )
    parser.add_argument(
        '--nms_th',
        type=float,
        default=0.45,
        help='NMS IoU threshold',
    )
    parser.add_argument(
        '--nms_score_th',
        type=float,
        default=0.1,
        help='NMS Score threshold',
    )
    parser.add_argument(
        "--with_p6",
        action="store_true",
        help="Whether your model uses p6 in FPN/PAN.",
    )

    # motpy parameters
    parser.add_argument(
        "--track_thresh",
        type=float,
        default=0.5,
    )
    parser.add_argument(
        "--track_buffer",
        type=int,
        default=30,
    )
    parser.add_argument(
        "--match_thresh",
        type=float,
        default=0.8,
    )
    parser.add_argument(
        "--min_box_area",
        type=int,
        default=10,
    )
    parser.add_argument(
        "--mot20",
        action="store_true",
    )

    args = parser.parse_args()

    return args


class dict_dot_notation(dict):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__ = self


def main():
    # 引数解析 #################################################################
    args = get_args()
    cap_device = args.device
    cap_width = args.width
    cap_height = args.height

    if args.movie is not None:
        cap_device = args.movie

    # YOLOX parameters
    model_path = args.yolox_model
    input_shape = tuple(map(int, args.input_shape.split(',')))
    score_th = args.score_th
    nms_th = args.nms_th
    nms_score_th = args.nms_score_th
    with_p6 = args.with_p6

    # ByteTrack parameters
    track_thresh = args.track_thresh
    track_buffer = args.track_buffer
    match_thresh = args.match_thresh
    min_box_area = args.min_box_area
    mot20 = args.mot20

    # カメラ準備 ###############################################################
    cap = cv2.VideoCapture(cap_device)
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, cap_width)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, cap_height)
    cap_fps = cap.get(cv2.CAP_PROP_FPS)

    # 出力ファイル
    if args.movie is not None:
        outpath = os.path.splitext(args.movie)[0] + "_tracked.webm"
    else:
        outpath = "camera.webm"

    writer = cv2.VideoWriter(
        outpath,
        cv2.VideoWriter_fourcc(*'vp80'),
        cap_fps, (cap_width, cap_height)
    )

    # モデルロード #############################################################
    yolox = YoloxONNX(
        model_path=model_path,
        input_shape=input_shape,
        class_score_th=score_th,
        nms_th=nms_th,
        nms_score_th=nms_score_th,
        with_p6=with_p6,
        providers=['CPUExecutionProvider'],
    )

    # ByteTrackerインスタンス生成
    tracker = MultiClassByteTrack(
        fps=cap_fps,
        track_thresh=track_thresh,
        track_buffer=track_buffer,
        match_thresh=match_thresh,
        min_box_area=min_box_area,
        mot20=mot20,
    )

    # トラッキングID保持用変数
    track_id_dict = {}

    # COCOクラスリスト読み込み
    with open('coco_classes.txt', 'rt') as f:
        coco_classes = f.read().rstrip('\n').split('\n')

    # トラッカー別のフレーム保持用
    track_id_frames = {}

    while True:
        start_time = time.time()

        # カメラキャプチャ ################################################
        ret, frame = cap.read()
        if not ret:
            break
        debug_image = copy.deepcopy(frame)

        # 推論実施 ########################################################
        # Object Detection
        bboxes, scores, class_ids = yolox.inference(frame)

        # Multi Object Tracking
        t_ids, t_bboxes, t_scores, t_class_ids = tracker(
            frame,
            bboxes,
            scores,
            class_ids,
        )

        # トラッキングIDと連番の紐付け
        for trakcer_id, bbox in zip(t_ids, bboxes):
            if trakcer_id not in track_id_dict:
                new_id = len(track_id_dict)
                track_id_dict[trakcer_id] = new_id

        elapsed_time = time.time() - start_time

        # デバッグ描画
        debug_image = draw_debug(
            debug_image,
            elapsed_time,
            score_th,
            t_ids,
            t_bboxes,
            t_scores,
            t_class_ids,
            track_id_dict,
            coco_classes,
            track_id_frames
        )

        # 書き込み
        writer.write(debug_image)

    cap.release()
    writer.release()

    # トラッキングID単位の動画
    write_track_id_wise_video(
        track_id_frames, 
        outpath.replace(".webm", ""),
        cap_fps)

def get_id_color(index):
    temp_index = abs(int(index)) * 3
    color = ((37 * temp_index) % 255, (17 * temp_index) % 255,
             (29 * temp_index) % 255)
    return color


def write_track_id_wise_video(
    track_id_frames_dict,
    base_dir, cap_fps
):
    if not os.path.exists(base_dir):
        os.makedirs(base_dir)

    for track_id, track_frames in track_id_frames_dict.items():
        width, height = 0, 0
        for frame in track_frames:
            height = max(frame.shape[0], height)
            width = max(frame.shape[1], width)

        writer = cv2.VideoWriter(
            f"{base_dir}/{track_id}.webm",
            cv2.VideoWriter_fourcc(*'vp80'),
            cap_fps, (width, height)
        )
        for frame in track_frames:
            pad_left = (width - frame.shape[1])//2
            pad_right = width - frame.shape[1] - pad_left
            pad_top = (height - frame.shape[0])//2
            pad_bottom = height - frame.shape[0] - pad_top
            frame_pad = cv2.copyMakeBorder(frame, 
                            pad_top, pad_bottom, pad_left, pad_right, 
                            cv2.BORDER_CONSTANT, (0, 0, 0))
            writer.write(frame_pad)
        writer.release()


def draw_debug(
    image,
    elapsed_time,
    score_th,
    trakcer_ids,
    bboxes,
    scores,
    class_ids,
    track_id_dict,
    coco_classes,
    tracker_id_frames,
):
    debug_image = copy.deepcopy(image)

    for tracker_id, bbox, score, class_id in zip(trakcer_ids, bboxes, scores,
                                                 class_ids):
        x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])

        if score_th > score:
            continue

        color = get_id_color(int(track_id_dict[tracker_id]))

        # バウンディングボックス
        debug_image = cv2.rectangle(
            debug_image,
            (x1, y1),
            (x2, y2),
            color,
            thickness=2,
        )

        # トラックID、スコア
        score_txt = str(round(score, 2))
        text = 'Track ID:%s(%s)' % (int(track_id_dict[tracker_id]), score_txt)
        debug_image = cv2.putText(
            debug_image,
            text,
            (x1, y1 - 30),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.7,
            color,
            thickness=2,
        )
        # クラスID
        text = 'Class ID:%s(%s)' % (class_id, coco_classes[class_id])
        debug_image = cv2.putText(
            debug_image,
            text,
            (x1, y1 - 10),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.7,
            color,
            thickness=2,
        )

        # ID別のフレームを記録
        crop = debug_image[max(y1-30,0):y2, x1:x2, :]
        if not tracker_id in tracker_id_frames.keys():
            tracker_id_frames[tracker_id] = [crop]
        else:
            tracker_id_frames[tracker_id].append(crop)

    # 推論時間
    text = 'Elapsed time:' + '%.0f' % (elapsed_time * 1000)
    text = text + 'ms'
    debug_image = cv2.putText(
        debug_image,
        text,
        (10, 30),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.7,
        (0, 255, 0),
        thickness=2,
    )

    return debug_image


if __name__ == '__main__':
    main()

ID単位での動画の書き出したものがいらなければ「write_track_id_wise_video」などを削除してください(私が追加した部分は適当に書いたコードです)

実行

!python sample.py --movie ../pexels-grace-wu-6022014.mp4 --width 1280 --height 720 --yolox_model model/yolox_s.onnx --input_shape 640,640

このようにバックエンドのYOLOXのモデルも変えられます。

できたもの

※51MBもあるので注意

ぱっと見るといい感じでした。しかし、ID単位のビデオを見ると、「全体の流れを考え、できるだけ少ないIDで統合している」というところに弱さを感じました。おそらくYOLOXの閾値を調整すればある程度カバーできるでしょうが、小さいオブジェクトのトラッキングはやはり物体検出同様難しそうです。ここの求められるさじ加減ユースケースによると思います。

以下の画像は、私が手動で細切れの動画をつなぎ合わせたときのタイムラインです。1秒に満たない小さな動画で細分化されているのがわかります。こういう小さなIDはルールベースで弾いてもいいかもしれません。

とりあえずさっくりトラッキングできたので、当初の目標はクリアできました。とても使いやすいリポジトリだったので高橋かずひとさん感謝です!



Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内

技術書コーナー

北海道の駅巡りコーナー


One Comment

Add a Comment

メールアドレスが公開されることはありません。 が付いている欄は必須項目です