技術は使ってなんぼ

自分が得たものを誰かの役に立てたい

【Keras】DLでCNNのアップサンプリングでconcatしたいのにTensorのshapeが合わない時の対策

課題

ディープラーニングでモデルを作成してる時に、concatでTensorのサイズが合わない時ってありません?


特に画像データを学習させるようなモデルの場合、代表的なモデルとしてU-Net等が挙げられます。


例えばこんなエラーメッセージ。

A `Concatenate` layer requires inputs with matching shapes except for the concat axis. Got inputs shapes: [(None, 22, 42, 4096), (None, 21, 41, 4096)]


テンソルの形状が合ってないから結合できねぇよ、ってエラーです。


このエラーが出る度に心の中で舌打ちしながら、ゼロパディングやフィルターサイズを調整してます。


ゼロパディング?フィルターサイズ?なにそれおいしいの?って方は以下のリンクを参照ください。
deepage.net

この調整作業、そろそろ本気でめんどくさくなってきたので、shapeの差分をゼロパディングで自動計算して埋められる関数作りました。


ネットに落ちてねぇかなぁと思ったのですが、意外とないんですよね。。


なので今後も自分が使うためと、他の開発者のリソース削減に役立てればと思い作成しました。


kerasかTensorflowを自動でパディング調整してくれるように仕様変更してくんねぇかなぁ。。


まぁでもそんなことしたら考えない開発者を増やしてしまうのか・・?

対策

def adj_concat(base, target):
    base_h = base._keras_shape[1]
    base_w = base._keras_shape[2]
    target_h = target._keras_shape[1]
    target_w = target._keras_shape[2]
    
    # shapeの上下方向と左右方向の差分を確認
    diff_h = base_h - target_h
    diff_w = base_w - target_w
    
    # 上下方向
    if diff_h != 0:
        pad_h = abs(diff_h)
        # 差分が偶数→上下ゼロパディング
        if (diff_h % 2) == 0:
            # baseの方が大きい→targetを調整
            if diff_h > 0:
                target = L.ZeroPadding2D(padding=(int(pad_h/2), 0))(target)
            # targetの方が大きい→baseを調整
            else:
                base = L.ZeroPadding2D(padding=(int(pad_h/2), 0))(base)
        # 差分が奇数→上側だけゼロパディング
        else:
            # baseの方が大きい→targetを調整
            if diff_h > 0:
                for i in range(pad_h):
                    if i % 2:
                        target = L.ZeroPadding2D(padding=((1, 0), (0, 0)))(target)
                    else:
                        target = L.ZeroPadding2D(padding=((0, 1), (0, 0)))(target)
            # targetの方が大きい→baseを調整
            else:
                for i in range(pad_h):
                    if i % 2:
                        base = L.ZeroPadding2D(padding=((1, 0), (0, 0)))(base)
                    else:
                        base = L.ZeroPadding2D(padding=((0, 1), (0, 0)))(base)
    # 左右方向
    if diff_w != 0
        pad_w = abs(diff_w)
        # 差分が偶数→左右ゼロパディング
        if (diff_w % 2) == 0:
            # baseの方が大きい→targetを調整
            if diff_w > 0:
                target = L.ZeroPadding2D(padding=(0, int(pad_w/2)))(target)
            # targetの方が大きい→baseを調整
            else:
                base = L.ZeroPadding2D(padding=(0, int(pad_w/2)))(base)
        # 差分が奇数→左側だけゼロパディング
        else:
            # baseの方が大きい→targetを調整
            if diff_w > 0:
                for i in range(pad_w):
                    if i % 2:
                        target = L.ZeroPadding2D(padding=((0, 0), (1, 0)))(target)
                    else:
                        target = L.ZeroPadding2D(padding=((0, 0), (0, 1)))(target)
            # targetの方が大きい→baseを調整
            else:
                for i in range(pad_w):
                    if i % 2:
                        base = L.ZeroPadding2D(padding=((0, 0), (1, 0)))(base)
                    else:
                        base = L.ZeroPadding2D(padding=((0, 0), (0, 1)))(base)
                
    return base, target

解説

引数のbaseとtargetはTensorオブジェクト(keras.layer)を想定しています。


肝となる考え方を中心に解説します。


上記エラーメッセージのように、concatターゲットとなる[(None, 22, 42, 4096), (None, 21, 41, 4096)]のような、二つのTensorがあります。


この2つのTensorのshapeを一致させたいので、両者がどれだけ離れているかの差分が必要となります。


この場合、22と21、42と41で比較し、少ない方を1増やすパディングが必要と考えます。


「え?1だけ増やすパディングって、そんなんできたっけ?」と思ったあなた。


あなたの脳内パディングのイメージは正しいと思います。


普通ゼロパディングというと、画像データの周りを0で囲むようなイメージかと思います。


なので、よくある書き方として、

outputs = L.ZeroPadding2D(padding=1)(inputs)

これはinputsの上下左右に1列1行ずつ追加して周りを0で囲む処理になります。


引数を以下のように変えると、片側だけに0で埋めることもできます。


それがソースコード内の

target = L.ZeroPadding2D(padding=((pad_h, 0), (0, 0)))(target)

今回の関数では、上下方向に1つだけ増やしたい場合は上側、左右方向に1つだけ増やしたい場合は左側にパディングしてます。


paddingには2種類のタプルを渡すことができます。
①(symmetric_height_pad, symmetric_width_pad)のタプルを渡せば、上下と左右で指定ができ、


②( (top_pad, bottom_pad), (left_pad, right_pad) )のタプルを渡せば、上下左右のいずれかで指定ができます。


参考にkeras公式のリンク張っときます。ZeroPadding2Dのところに書いてます。
keras.io


このタプルパターン2種類を全部使い切らないと、本題の課題は解決できません。


なぜなら、shapeの差分は偶数の場合と奇数の場合があり、偶数の場合は①、奇数の場合は②となります。


ソースコードで偶数とか奇数とか書いてるのは、shape差分に応じてゼロパディングを変化させるためです。


あとは上下方向の差分(diff_h)と左右方向の差分(diff_w)が!=0(一致しない)時に実行するような仕様です。


ソースコードの先頭で、shapeの情報を得るのに、「_keras_shape」を使ってるのに違和感を感じた方はいるかもしれません。


「あれ、Tensorってshapeメソッドで取れなかったっけ?」


取れます。実際keras layer でもConv2Dとか普通に取得できます。


が、keras layerで取れないオブジェクトが存在します。私も動かしてから気づきました。。


それが、Conv2DTranspose(Deconv2D)です。こいつにshapeメソッドつけると(?, ?, ?, 4096)とかってなります。


keras公式でも放置されてるバグらしいです。


現在はTensorflow内のkerasという位置づけですが、最新のTensorflowでは問題なくshapeメソッドで取得できるらしいです。


ただ私のように、Tensorflow2.0以前のバージョンで開発してる方からすると、残念ながら自前での対応が必要になるかと思います。


そういう意味でも、この関数が何かのお役に立てればいいなと思います。


もし使ってみたよーって方や、こんな方法もあるで?みたいなご意見やご質問・ご感想等あればコメント欄にお願いします!