技術は使ってなんぼ

自分が得たものを誰かの役に立てたい

【続・SSD】CutMixを複数合成できるようにして物体検出したら更に精度があがった話

背景

前回の記事の続きです。
yonesuke0716.hatenablog.com


前回の記事中にあった、

『まぁその辺はbboxのサイズが閾値以上に限定したり、複数のbboxで行う等の改良をすれば解決するかと思いますが、一旦上記の方法で実装して効果を検証します。』


これをやってみたら、想定以上に効果があったので、記事にしました。

ソースコード

img, bbox, label = in_data
# 0. CutMix
if np.random.randint(1,7) % 2 == 0:
    idx = random.randint(0, len(train_dataset)-1)
    cut_img, bbox_list, label_list =  train_dataset[idx]
    idx_list = list(range(len(bbox_list)))
    random_idx = random.sample(idx_list, int(len(bbox_list)/2))

    for i in random_idx:
        cut_bbox = bbox_list[i]
        cut_label = label_list[i]
        cut_img = cut_img.astype(np.int32)
        cut_bbox = cut_bbox.astype(np.int32)
        cut_label = cut_label.astype(np.int32)

        def cutmix(img_1, img_2, bbox_1, bbox_2, label_1, label_2):
            bx1, by1, bx2, by2 = bbox_2
            img_1[:, bx1:bx2, by1:by2] = img_2[:, bx1:bx2, by1:by2]
            new_label = np.append(label_1, label_2)
            new_bbox = np.append(bbox_1, [list(bbox_2)], axis=0)
            return img_1, new_bbox, new_label

        img, bbox, label = cutmix(img, cut_img, bbox, cut_bbox, label, cut_label)

解説は前回のソースコードとの差分だけ解説します。

解説

img, bbox, label = in_data
# 0. CutMix
if np.random.randint(1,7) % 2 == 0:
    idx = random.randint(0, len(train_dataset)-1)
    cut_img, bbox_list, label_list =  train_dataset[idx]
    idx_list = list(range(len(bbox_list)))
    random_idx = random.sample(idx_list, int(len(bbox_list)/2))

    for i in random_idx:
        cut_bbox = bbox_list[i]
        cut_label = label_list[i]
        cut_img = cut_img.astype(np.int32)
        cut_bbox = cut_bbox.astype(np.int32)
        cut_label = cut_label.astype(np.int32)

前回は複数のbboxのひとつを合成するだけでしたが、bbox総数の半数(len(bbox_list)/2)をターゲットにしているところがポイントです。

後はfor文でひとつずつ実行し、appendしてやるという寸法です。

結果

評価結果比較


なんと1~1.5%程度も精度が向上しています!
わずか10epochの学習でこの改善は大きい。前回が0.2%であったことを考えると、5倍以上の効果を発揮しています。


ただしこのデータセット、よくみてみるとかなり偏りがあることがわかります。


まぁ白血球が多いのは人がそういうものなので仕方のないことですが。。


血小板や赤血球の精度が向上した理由として、もともと少ない血小板や赤血球のデータが増えたことも要因として考えられます。


Augmentationのもともとの狙いはデータ拡張ですので、データセットにバリエーションをもたせる狙いがあります。


今回のBCDDというデータセット364個のデータ数と非常に小さなデータセットです。


なので、Augmentationの効果が如実に出る可能性が高いデータセットであるともいえます。


とはいえ人の細胞の拡大写真はオクルージョンが多いので、CutMixが効果を発揮したのも直感的に理解できます。


このように、データセットに合わせてAugmentationも合わせて行うと効果を発揮するということが、本事例で感覚的にもご理解いただけたかと思います。