Galapagos Tech Blog

株式会社ガラパゴスのメンバーによる技術ブログです。

TensorFlowで単純なseq2seqモデルとattention seq2seqモデルを比較してみた

こんにちは、
11月8日にAWSxBot勉強会*1で無事に発表してきた@vanhuyzです。 発表資料も上げましたので、来られなかった方は是非チェックしてください!

さて、今回の内容です。 最近、Google翻訳先生がすごくなるよという話題がありますね。 手法はこの論文*2に詳しく書いてありますが、基本的にSequence to Sequenceモデル (seq2seq) を使っているそうです。

背景

まず、seq2seqモデルを少しまとめてみます。

単純なseq2seqモデル

f:id:glpgsinc:20161128113836p:plain

seq2seqモデルは名前の通り、入力シーケンスから出力シーケンスに変換するモデルです。シーケンスはテキストでも良し、画像や音声でも構いません。そのため、seq2seqは機械翻訳、会話生成、画像キャプション、音声認識などに広く応用されています。

上の図にはABCというシーケンスからWXYZというシーケンスに変換しますね。 基本的にseq2seqは2つのRNNから構成されています。1つ目のRNNはencoderと呼ばれていて、入力シーケンスから固定サイズのベクトルに変換します。次に、その固定サイズベクトルを2目のRNN(decoderと呼ばれる)に入れて、出力シーケンスを生成します。

Attention Seq2seqモデル

f:id:glpgsinc:20161128113852p:plain

単純なseq2seqモデルは固定サイズベクトルを使用するため、複雑なタスクに対して表現できないことがあります。そのため、入力シーケンスの各ステップの出力(または状態)の情報を利用すると、より良い出力シーケンスを生成できるという考えです。これは注目(attention)メカニズムと呼ばれています。

今回は単純なseq2seqモデルとattention seq2seqモデルを比較しようと思います。

実験

問題設定

今回は機械翻訳タスクを検証してみたいです。 実際の言語を試したいですが、今回はデモという目的なので、1つの言語を作ります。 例えば、ある国の言語は数字だけで表記します。Numberishを呼びましょう。

この国の言葉は英語から翻訳すると、こういうルールとなります:

i → 1
you → 3
english → 7
Pneumonoultramicroscopicsilicovolcanoconiosis → 45

見たら分かる通り単純に文字数です。

次は文法です。文法は英語とちょっと逆で、例えば

i love you (→ love i you) → 413
this is a long sentence (→ is this a sentence long) → 24184

となっています。大体って言えば3,6,9…番目の単語をそのままにして、残りはとなりの単語と取り替えるということです。もし最終の単語が取替相手がない場合、そのままにします。コードで表すと、

def grammar(length):
  mygrammar = [1, 0, 2]
  if length <= 0:
    raise ValueError('Length should be >= 1') 
  if length == 1:
    return [0]
  if length == 2:
    return [1,0]
  for i in range(3,length):
    if length % 3 == 1 and i == length - 1:
      next_pos = length - 1
    else:
      next_pos = mygrammar[i-3] + 3
    mygrammar.append(next_pos)
  return mygrammar

となります。

>>> print(grammar(1))
[0]
>>> print(grammar(2))
[1, 0]
>>> print(grammar(10))
[1, 0, 2, 4, 3, 5, 7, 6, 8, 9]
>>> print(grammar(11))
[1, 0, 2, 4, 3, 5, 7, 6, 8, 10, 9]

Numberishの特徴:

  • 曖昧さがある
  • 単語間の区切りなし(日本語と同じ)
  • もともとの英語文の長さと違うことがある
    • 例えば: neural machine translation(長さ3)→ 7611(長さ4)

こういう特徴で本来の機械翻訳タスクの難しさとちょうど良いと思います。

実装

教師データ生成

教師データは英文とNumberishに訳した文のペアが必要です。 例えば、「i have a pen」→「4113」というペアです。

Numberishの先生はルールがわかりましたので、大量のデータを作ることができます。 データだけモデルに与えて、もちろん文法のルールを一切教えないのです。

def encode(text):
  """ Numberize a sequence """
  words = text.split()
  new_text = ''
  for i in grammar(len(words)):
    new_text += str(len(words[i]))
  return new_text
>>> test = 'i have a pen'
>>> print(encode(test))
4113

学習モデルの定義

単純なseq2seqモデルとattention seq2seqモデルはTensorFlowが提供するのでそれらを使います。

  • 単純なseq2seq:tf.nn.seq2seq.embedding_rnn_seq2seq
  • Attention seq2seq:tf.nn.seq2seq.embedding_attention_seq2seq

ソースコードGitHubに上げましたので、興味ある方は是非チェックしてください。 github.com

結果

まずはloss値です。

f:id:glpgsinc:20161124211010p:plain

f:id:glpgsinc:20161124211023p:plain

グラフを見たらわかるように、単純seq2seqのlossは0.3に収束しますが、attention seq2seqは0.2までいけました。 0.1の差は小さいように見えますが、実はかなり大きいです。 この結果を見たら、機械翻訳タスクに対しては確かにattention seq2seqの方が優勝だとわかりました。

では、実際のテスト文を入れてみましょう。この文は学習データに含まらないのです。

Input                   :  neural machine translation by jointly learning to align and translate
Correct output          :  76117285239
Simple seq2seq output   :  126112853298
Attention seq2seq output:  76117285236

この例を見たら、単純seq2seqは明らかにダメのに対して、attention seq2seqは90%当たりましたね。最後のtranslateは9文字なのに6に訳してしまうのがちょっと残念です。

Attention Seq2seqの面白いことは可視化することができます。以下はattention matrixの可視化です。attention matrixは元の英文と訳した文の間、どういう関係があるのかを表す行列です。 x軸は元の英文で、y軸はNumberishに訳した文です。英語の単語がNumberishのどの単語に相当するのかモデルがよく判断できましたね。 例えば、「neural」→「6」、「translation」→ 「11」となっています。

文法が違って、単語の位置がずれてもちゃんと訳できるのですごく感動しました。

f:id:glpgsinc:20161124193902p:plain

ちなみに、attention matrixの値はTensorFlowのAPIから取れず、元の論文*3を参考しながらTensorFlowソースコードをいじりました。

最後に

株式会社ガラパゴスでは、新しい技術が好きな方を絶賛大募集中です!

RECRUIT | 株式会社ガラパゴス iPhone/iPad/Androidのスマートフォンアプリ開発

*1:AWS×BOT】TechTalk #3 https://lig.connpass.com/event/41826/

*2:Neural Machine Translation by Jointly Learning to Align and Translate https://arxiv.org/abs/1409.0473

*3:Grammar as a Foreign Language https://arxiv.org/abs/1412.7449