読者です 読者をやめる 読者になる 読者になる

にほんごのれんしゅう

日本語として伝えるための訓練を兼ねたテクログ

alternative illustration2vec(高次元タグ予想器)について

alternative illustration2vec(高次元タグ予想器)について

f:id:catindog:20170311231233p:plain

図1. 予想結果のサンプル

はじめに

今回はillustration2vecを去年10月に知り、実装法を模索していたが、Kerasでの転移学習と、目的関数を調整することで同様の結果が得られるのではないかという仮説に基づいて、検証実験を行った。

illustration2vecのような画像のベクトル化技術に関してはアプローチは複数用意されており、どのような方法がデファクトかつ、もっとも精度が良いのかわかっていない。

以下、私が考えた3つの方法を記す。

  • 1. VGG16などの学習済みモデルの出力部分のみを独自ネットワークの入力にすることで、タグ予想問題に切り替える
  • 2. 上記のアプローチをとるが、入力に途中のネットワークのレイヤのベクトルも入力に加える
  • 3. キャラクタ判別問題などにタスクを切り替えて、タスク完了後ネットワークを学習させないようにフリーズして、途中のレイヤの出力、ボトルネック層の出力などを素性にしたlogistic regressionにて学習する

f:id:catindog:20170311231843p:plain

図2. 今回用いた転移学習の例

1,2は転移学習とよばれる学習済みのモデルを利用して再度学習する方法であるが、ネットワークの出力に近い部分を取り外して任意のネットワークに置換することで別の問題をとくことが可能になる。
この時、よく調整された評価用のVGG16やResNetなどは卓越した特徴量抽出機能を備えている。下手に独自の学習で悪影響を及ぼすより、フリーズと呼ばれる学習を反映しない状態に変換して特徴量抽出器のような使い方をすることができる。

タグの確率表現としての妥当性や多数でのパラメータでの多様性を考えると3が最も良いように見える。これはディープラーニングと既存の実績のある判別問題のアプローチを合成したハイブリットなアプローチである。

先行研究

オリジナルペーパ、やってる概要はつかめたのだが、学習用の実装が見当たらず、妄想であれこれいけるんじゃないかと検証していた

タグ予想という名前でプロダクトを出しており、githubで中間層のデータを参照するようなコードの断片を見つけて、自分の考えがそんなに間違いでないことを知る
なお、400次元に出力を限定しているがGPUのメモリを考えれば5000次元ぐらい行けるのになぜ?

実験

  • GPUの性能の関係でVGG16を改造することに決定
  • 1,2のアプローチの後、性能が出なければ、3のアプローチを導入(1がうまく行ったため行っていない)
  • http://danbooru.donmai.us/ というサイトから、scraperを回して画像とタグ情報を139万枚取得。評価者に馴染み深いドメインである必要があったため、艦これのタグが入っているものを優先。
  • オプショナルな情報として、アダルトスコアも習得
  • 学習時間の関係から5万枚を学習に使用した
  • VGG16を用いるため、150*150*3の画像が入力となる
  • 出力次元には特に制限を設ける必要性を感じなかったために4096次元にした
  • 多くの例ではVGG16の15層までをフリーズするが、性能試験を行ったところ、12層ぐらいのフリーズの方が今回のタスクでは性能が向上した
  • Dropoutは性能が改悪することがあったため、BN(BatchNormalization)で対応した
  • Activation関数に関しては、linearで出力した後Sigmoidをかけている(意味ない可能性ある)
  • 評価関数はbinary_crossentropyを用いている
  • optimizerはadam
  • 評価いはニコニコ静画の画像をもって行うものとする(学習データに含まれていない必要があるため、最新の投稿を用いる)

ネットワーク

  • (2に関して結果から言うと、どうしてもロス率の下がりが悪く、計算リソースの関係で諦めた)
  • 1のネットワークに関しての図は、図2を用いた。
  • Kerasでのモデル定義のコードを記す
def build_model():
  input_tensor = Input(shape=(150, 150, 3))
  vgg16_model = VGG16(include_top=False, weights='imagenet', input_tensor=input_tensor)
  dense  = Flatten()( \
             Dense(2048, activation='relu')( \
               BN()( \
                 vgg16_model.layers[-1].output ) ) )
  result = Activation('sigmoid')(\
             Activation('linear')( \
               Dense(4096)(\
                 dense) ) )
  model = Model(input=vgg16_model.input, output=result)
  for i in range(len(model.layers)):
    print(i, model.layers[i])
  for layer in model.layers[:12]: # default 15
    layer.trainable = False
  model.compile(loss='binary_crossentropy', optimizer='adam')
  return model

結果

良い結果になった。
以下に三つの例を載せる。
(画像の著作権ニコニコ静画及び絵を書いた人に帰属します)

f:id:catindog:20170311232415p:plain

図3.複数人数を把握することも可能

f:id:catindog:20170311232526p:plain

図4.楽器などの特殊な状況に対応できるか不安だったが正しく認識しているようだ

f:id:catindog:20170311232631p:plain

図5.トップに載せいている電ちゃんと同一(電ちゃんかわいい)

コード

今回は試行錯誤がたくさんあり、コードに微調整の後が大量に見て取れるかもしれない。赦してほしい。
github.com
git cloneする。

git clone https://github.com/GINK03/alt-i2v

画像のデータが必要な方はスクレイピング
(通信速度と頻度に関しては十分注意してください)
bautfulsoupとよばれるモジュールが必要です。

python3 danbooru_datasetgenerator.py --mode=scrape

ダウンロードが完了したら、タグ情報の前処理

python3 alt_i2v.py --maeshori 

高速に画像を処理するために、画像をベクトル化してKVSに格納する
numpy, pillow, msgpack, msgpack-numpy, plyvel+leveldbが必要です

python3 alt_i2v.py --build

トレイン

python3 alt_i2v.py --train

タグ情報を予想してみる

python3 alt_i2v.py --pred foojpg bar.jpg

結論

  • タグの予想は意外と簡単にできるし、精度も悪くない。
  • 今回作成したネットワークを用いることで類似画像検索や、その画像が何を示しているのかを概念的に把握することができる。これはSingle Shot DetectionやYolo V2とは異なったアプローチで状況の把握が可能になることを示している
  • 別に400次元に出力を限定する意味を感じなかったが正しかったようだ。RNNやってれば10000次元に到達することなんてザラだし。

謝辞

- 小林さんちのメイドラゴンのED 「イシュカン・コミュニケーション
「なんでルールはきゅうくつ?胸がしまっちゃうね、下等で愚かな価値観。だめって決定する、倫理なんていりませんよ」の流れ最高ですね。