GitHub

Deep Reinforcement Learning-based Image Captioning with Embedding Reward

タグ: Caption Generation Reinforcement Learning

概要

Deep Reinforcement Learning-based Image Captioning with Embedding Reward

  • 従来のキャプション生成はencoder-decoderモデルが多いが、
  • “Policy Network”と”Value Network”を用いた強化学習によるキャプションの生成
  • 性能良い

手法

framework

  • S={w1,w2,,wT} : sentence (文のstep:tにおける生成過程を、状態Sと定義)
  • at=wt+1 (actionは次の単語)

Policy Network,Value Networkは共に1段階目はは別々で学習させて、次に結合させて学習

Policy Network


Policy-Network

pπ(at|st) : 状態Stにおいてどのようなactionをとるべきかを予測

探索の木の幅を絞る役割

  • 1. 画像をCNNpによってencodeし、それをRNNpに入力するための次元に変換して入力
  • 2. RNNp の出力を単語の次元に変換して、方策 pπ(at|st) を求める

1段階目の学習 : 教師ありでcross-entropy lossによって学習させておく

Lp=Tt=1logpπ(atst)

Value Network


Value-Network

vθ(st) : 状態Stの価値を見積もる
探索の木の深さを絞る役割

  • 1. 画像を CNNv によってencodeする
  • 2. RNNv によって現在までのsemanticな情報を出力
  • 3. 上の2つをconcatして MLPv にかけ、状態Stの価値を推定する

1段階目の学習 : まずランダムに選択した生成過程における価値が最終的なReward(文とキャプションの対応度合いを表す)と一致するように学習

||vθ(si)r||2

を最小化

Rewardについて

画像と文章のペアが以下によって同じsemantic embedding spaceに写像されるように学習させたものを用意

  • 画像 : CNNeから抽出した特徴量vをfe(linear mapping layer)にかける
  • 文章 : RNNeの最後の出力hT(S)を出す

以下のロスにて学習

Le=vSmax(0,βfe(v)hT(S)+fe(v)hT(S))+Svmax(0,βhT(S)fe(v)+hTfe(v))

S : 画像に対応していない文のこと
v : 同様

  • 画像とキャプションが対応 -> 内積を大きく
  • 画像とキャプションが対応していない -> 内積を小さく

(交差検証によって決定したmargin βを設ける)

また、Rewardは以下の用に計算

r=fe(v)hT(ˆS)||fe(v)||||hT(ˆS)||

Policy NetworkとValue Networkをつなげて学習


partially observable Markov decesion processと見て、以下の用に勾配を計算

πJ=Tt=1πlogpπ(at|st)(rvθ(st))θJ=θvθ(st)(rvθ(st))

これはactor-critic(pπ:actor,vθ:critic)の関係としても見ることができる。

しかし、やはりactionの候補が単語の種類分あると、選択肢が膨大になりすぎて学習がうまく行かない。

  • そこで、Curriculum Learning の導入

初めの(T-i×Δ)単語はクロスエントロピーで、残ったi×Δ単語を強化学習で (i=1,2,…と増やしていき、徐々に全文を強化学習で生成するようにシフトしていく)

Lookahead inference with policy network and value network


Beam Searchを用いる(従来のキャプション生成にもあったが、幅優先探索を行いつつ、ビーム幅個のノードを保持して評価値の低いものは捨てていくもの)

W[t] : t単語目でのBeam Searchで残っているもの(t単語のかたまりがbeam幅数分) 定式化すると、

W[t]={w1,[t],...,wB,[t]}wb,[t]={wb,1,...,wb,t}

となる。このもとで、次のbeam sequenceは以下の用に定義できる。

W[t+1]=argtopBWb,[t+1]wt+1S(wb,[t+1]),s.t.wi,[t+1]wj,[t+1]

従来のものは、このS(・)は生成された一連のlog確率を表していた。 これは、良いキャプションの中の全ての単語のlog確率は、このtop群に含まれるという仮定に基づくが、必ずしもそれは正しくない。 (ALphaGoも低い確率の行動も選択していた)

そこで、PolicyとValueを組み合わせた、先読みをした推論を以下の用に定式化

S(wb,[t+1])=S(wb,[t])+λlogpπ(at|st)+(1λ)vθ({st,wb,t+1})

wb,[t]のsequenceに対するスコアに、wb,t+1の単語が加わったもののスコアを求めるので、

  • Policyによるwb,t+1の確信度
  • Valueによるwb,t+1が加わったときの評価

の2つをλ:1-λの割合で加えたもの

result sample


result