ジョイジョイジョイ

ジョイジョイジョイジョイジョイ

Generative Adversarial Nets

Generative Adversarial Nets [1] を chainer で実装しました。

いわゆる GAN です。

最近いろいろ派生系が出ています。画像の生成モデルはほとんどこれの派生な気がします。

画像を生成する Generator (以下 G)と、画像が本物か G が生成したものかを識別する Discriminator (以下 D) という 2 つのモデルを同時に学習させていきます。

論文に書いてあるように、G は偽札製造業者、 D は警察という類推がわかりやすいです。

業者は警察にバレないように紙幣にできるだけ似せたものの作り方を学ぶ一方で、警察はより精巧な偽札を本物と区別できるように学習する、ということを繰り返すうちに、互いの技術は向上していきます。

これと同様に、G は生成するデータと教師データが D に区別できないようなパラメータを学習する一方で、 D は本物のデータか G が生成したデータかを正しく識別するようパラメータを学習していきます。

具体的には、以下の式で最適化します。

 \displaystyle min_{G} max_{D} V(D,G) = \mathbb{E}_{x \sim p_{data}(x)}[log D(x)] + \mathbb{E}_{z \sim p_{x}(z)}[log(1-D(G(z)))]

 p_x(z) は G に食わせるデータです。一様分布から生成されるノイズなどを使う場合が多いようです。

教師データを正例、G の生成したデータを負例としたときの交差エントロピー誤差をお互い最適化するということです。

論文に証明が書いてある通り、この関数の大域最適解において、G が生成するデータの分布と、教師データの分布は一致します。

あとはこれを逐次的に最適化していけばよいです。

今回の実装では、教師データに MNIST を使いました。

ソースコードは以下の通りです。

gist.github.com

生成した画像は以下の通りです。(学習の最後に出力された 10 枚を掲載します)

f:id:joisino:20170703180931p:plainf:id:joisino:20170703180935p:plainf:id:joisino:20170703180939p:plainf:id:joisino:20170703180943p:plainf:id:joisino:20170703180947p:plainf:id:joisino:20170703180950p:plainf:id:joisino:20170703180954p:plainf:id:joisino:20170703180957p:plainf:id:joisino:20170703181000p:plainf:id:joisino:20170703181006p:plain

周りにもやもやがついているのは残念ですが数字であることはわかります。

今回は G は線形 + relu + sigmoid で D は線形 + relu + batch normalization + sigmoid です。

G に batch normalization した方がよいという記事をネット上でいくつか見かけたのですが、実験してみたところ微妙でした。

以下のような感じです。

G: bn あり D: bn あり

f:id:joisino:20170703181429p:plain

G: bn なし D: bn なし

f:id:joisino:20170703181604p:plain

G: bn あり D: bn なし

f:id:joisino:20170703181537p:plain

G の bn を有効にすると数字がぼやけてしまい、 D の bn を有効にすると周辺のもやは取れるのですが学習がうまく進みませんでした。

これらの結果はハイパーパラメータの値のチューニング次第かもしれません。

今回は linear でやりましたが deep convolutional なものもまた実装したいです。

参考文献

[1] Goodfellow, I. J., et.al. (2014). Generative Adversarial Nets