拙著『深層ニューラルネットワークの高速化』が重版して第 2 刷となりました。皆さまありがとうございます!
もはや恒例、重版に感謝して書き下ろし専門記事をお届けします。
本稿では、SNS などでもたびたび話題になるトランスフォーマーは RNN であるという話をします。本稿では単に形式的に包含性を指摘するだけでなく、トランスフォーマーと RNN はどの程度似ているのかや、そこから導かれる応用上の意味についても詳しくご紹介します。
本稿は『深層ニューラルネットワークの高速化』の第 6.3 節と第 7.2 節に基づいています。
過去回
- 拡散モデルと最適輸送(最適輸送第 5 刷)
- GNN の最新動向(グラフニューラルネットワーク第 3 刷)
- 深層学習で部分空間を扱うときは射影行列を考えるとよい(グラフニューラルネットワーク第 5 刷)
目次
- 過去回
- 目次
- トランスフォーマーは RNN である
- この議論の問題点
- 注意機構の復習
- カーネル法
- カーネル法と注意機構
- トランスフォーマーは RNN である(再)
- 文脈内学習との関係
- 有限次元へ
- 高速化への応用
- トランスフォーマーは RNN である(再々)
- おわりに
トランスフォーマーは RNN である
まず、トランスフォーマーは RNN であることを示します。
結論から言うと、トランスフォーマーはキー・バリューキャッシュを状態とした RNN です。
注意層における過去トークンのキーベクトルとバリューベクトルさえ保存しておけば、過去のトークンが具体的に何であったかや、入力テキストそのもの自体は保存しておかなくても、次トークンの予測が可能です。なので、トランスフォーマーにとってはキー・バリューキャッシュは必要十分な「状態」です。
トークン を入力すると、 これまでのキーベクトルとバリューベクトルのリスト に基づいてトランスフォーマーの規則 が適用されます。これにより次時刻 のトークンが得られ、キー・バリューキャッシュには今時刻で計算したキーベクトルとバリューベクトルが追加されて新しい状態 となり、また次の時刻へと移ります。この計算を行う規則 自体は全ての時刻で共通なので、この計算は「再帰的」です。
この議論の問題点
とはいえ、この議論には少し無理やりな部分があります。
第一は、普通の RNN のようなベクトル形式の状態 とは異なり、リストのペアという構造のある「状態」 を用いている点です。
第二に、普通の RNN のような固定次元のベクトル とは異なり、この「状態」は時刻が進むにつれて大きくなっていきます。
あくまで抽象的な議論をする上ではこれらはあまり問題にはなりませんが、応用上は実装や効率の観点で大きな違いを生みます。
以下では、これらの問題はかなりの程度解消でき、先ほど述べた議論以上にトランスフォーマーは RNN であることを見ていきます。
注意機構の復習
これらの問題に具体的に取り組むために、注意機構について軽く復習を行います。
注意機構への入力は、クエリと呼ばれるベクトル と、キーと呼ばれるベクトルのリスト と、バリューと呼ばれるベクトルのリスト であり、出力は
です。分母は正規化をしているだけなので、一旦無視して、 という注意機構をしばらく考えます。注意機構は、クエリ に対して、内積の意味で似ているキー を探し、そのキーに対応するバリュー に高い重みを付けて足し合わせます。
カーネル法
式 (1) はカーネル法として解釈できます。
ここでは、カーネル法に馴染みのない読者を想定して、カーネル法についての基本を紹介します。
カーネル法とは、カーネル関数をもとにしたパターン認識や機械学習手法の総称です。カーネル関数とは 2 つのデータ を受け取り実数値類似度を返す関数 のことです。
本当は類似度関数がカーネル関数として認められるためには条件(1 変数目と 2 変数目を入れ替えても結果が変わらないことなど)があるのですが、少し複雑なので詳しいことは [赤穂 2008] や [瀬戸+ 2021] などのカーネル法の教科書を参照してください。さしあたり本稿ではカーネル関数は「2 つのデータを受け取り実数値類似度を返す関数である」とだけ認識していれば十分です。
最も有名なカーネル関数はガウスカーネル
です。ガウスカーネルを類似度を測るために使う場面はしばしば見たことがあるかと思います。カーネル法の核心は以下の定理です。
定理 関数 がカーネル関数であるときかつそのときのみ、再生核ヒルベルト空間 と関数 が存在して、任意の について
が成り立つ。
用語と記法がややこしいので整理しましょう。
再生核ヒルベルト空間とは、特定の条件(公理を満たす内積が定まっているなど)を満たすベクトル空間です。しかつめらしいですが、重要なことは
- はベクトル空間であることと、
- このベクトル空間の要素の間には内積 が定義されていること
です。
上記の定理の のことを特徴マップと呼びます。特徴マップはデータを再生核ヒルベルト空間 内の点(ベクトル)に変換します。
上記の定理を言葉で説明すると、
- どんなカーネル関数も、特徴マップでデータを変換してから内積を取るという二段階で計算ができる
- 逆に、特徴マップでデータを変換してから内積を取って測る類似度はカーネル関数である
ということです。
この定理は当然ガウスカーネルにも適用でき、ガウスカーネルは特徴マップ を用いて
と内積形式で書き表すことができます。カーネル法と注意機構
注意機構で出てくる類似度 はカーネル関数であることを示します。
まず、ガウスカーネルを式変形すると、
つまり、特徴マップ より定まるカーネル関数 は注意機構の重み に他なりません。
この結果を用いて式 (1) を
と書き換えると、注意機構は、カーネル関数で重みづけてバリューベクトルを足し合わせていると解釈できます。ここまではただの言い換えに過ぎませんが、重要なのは、内積には線形性があり、和と内積の順序を入れ替えられるということです。このことを利用すると、式 (1) はさらに、
トランスフォーマーは RNN である(再)
この書き換えにより、冒頭で述べた問題点が(ほぼ)解決しました。
トランスフォーマーは という状態を持つ RNN です。
新しいトークンを入力すると、
という規則で出力が計算され、 という規則で状態が更新されます。状態更新の規則は既存の状態にベクトルが足し合わせられるという典型的な RNN の振る舞いであり、出力規則も入力特徴と状態の内積というシンプルなものです。
ここでは簡単のため、単一の注意層のみを考えていますが、注意層が複数積み重なったり、間に MLP など別の構成要素が挟まっても同じ議論が可能です。
文脈内学習との関係
トランスフォーマーの状態 の更新式 (6) を展開すると、状態ベクトルは
というように過去のトークンの影響の和で書き表すことができます。これはつまり、 に似たクエリが来ると に似た出力をするべしということが逐次的に「学習」されていっていることになります。学習の結果、線形モデルの重み が得られます。
これは文脈内学習 (in-context learning) に対応しています。トークンを処理するごとに、こういうキーではこういうバリューであるという入力と出力の対が出来上がっていき、新しい入力ではそれらの対が根拠となって新たな出力が決まっていきます。この流れは通常の学習と同じです。クエリはテストデータ、キーは訓練データの入力、バリューは訓練データの出力に対応しています。ただし、文脈内学習ではモデル自体のパラメータは変更しません。パラメータではなく、内部状態への情報の蓄積という形で学習が起こります。トランスフォーマーはトークンを処理しながらクエリとキーとバリューの関係を「学習」しているということです。特に、プロンプト内に「こういう例ではこういう出力をして」と書いておくと、その対応が「学習」され、その後の出力において、これらの「訓練データ」に基づいて出力が行えるようになります。
モデルパラメータの算術 - ジョイジョイジョイ や『深層ニューラルネットワークの高速化』の第 9.2.4 節でも書きましたが、通常の学習、すなわち重み内学習 (in-weight learning) でも、重みは同様に特徴マップの和で書くことができます。完全性のため、ここでも簡単にご紹介します。
パラメータ初期値を とし、特徴マップを
とするカーネル を考えます。このカーネルをニューラルタンジェントカーネル (neural tangent kernel; NTK) といいます。ニューラルタンジェントカーネルによる類似度が大きいデータ は、勾配が似ているので、モデル更新に与える影響という意味で似ていると解釈できます。確率的勾配降下法 (SGD) により訓練を行い、
と更新していきます。ここで、 は関数出力についての損失の勾配 と学習率の積です(一般性を失うことなく の出力は一次元であると仮定しています。)学習率は小さい正の値であり、 の正負は損失関数によって異なります。その訓練サンプルにとって、出力を大きくするべきであれば は正の値を取り、出力を小さくするべきであれば は負の値を取ります。訓練によりパラメータ が得られたとします。訓練の過程であまりパラメータが動かなかったとし、 の周りの についてのテイラー展開により一次近似すると、パラメータ が表す関数は
また、式 (8) において、
とおくと、訓練後のモデルの出力は となります。式 (7) と式 (10) はとても似ています。重み内学習にしろ、文脈内学習にしろ、学習とは、過去のデータとどの程度似ているかというのをカーネルにより測定し、似ている過去のデータにおいて出力はどうあるべきかをもとに、新たな出力を決定する手続きとみることができます。
有限次元へ
問題はほとんど解決しましたが、これまでの議論にも少し無理やりな部分があります。
ここまで、 と を具体的にどう計算するかは述べてきませんでした。これらはたしかに「ベクトル」ではあるのですが、これらは実は無限次元の「ベクトル」であり、明示的に計算することができません。
明示的に計算できないのであれば応用上は何もかも意味が無くなってしまいそうですが、実はこれはさほど大きな問題ではなく、少し工夫すると応用も可能になります。
第一に、 や は有限次元の特徴マップで精度良く近似できます。色々と方法はありますが、最も簡単な方法はランダム特徴量を利用することです。 というランダムなベクトルをあらかじめサンプリングし、
例えば、 で 200 次元の特徴マップを用いたときのコードと結果は以下の通りです。
厳密な注意重み (= 無限次元の特徴マップ)の計算結果が 1.13 のとき 100 次元の特徴マップによる近似値が 1.35 など、それなりの精度で近似できています。次元を大きくするにつれて近似は正確になります。
よって、以上の議論において (無限次元)が登場した箇所を全て (有限次元)で置き換えると、有限次元で計算可能な RNN となります。
以上はデータの性質を全く仮定せずに構築した一般的かつ簡便な近似法ですが、データの性質が良いときにはより低次元でより高精度に近似することもできます。考え方は主成分分析で次元圧縮をすることと同じです。見かけ上 1000 次元のデータでも、本質的な次元数が小さければほとんど情報を失うことなく主成分分析で 30 次元にすることができます。つまり 1000 次元から 30 次元に圧縮が可能です。再生核ヒルベルト空間も所詮ベクトル空間なので、これとほとんど同じようなことができ、無限次元から 30 次元に圧縮する、というようなことが可能です。これにより、例えばトランスフォーマーを 30 次元の RNN で表現する、といったことも可能です。そのように、データの本質的な低次元性をもとに特徴マップを近似する手法にはナイストローム近似などがあります。
第二に、 や 自体は厳密計算できないものの、これらに固執する理由はありません。注意機構の本質は
というようにカーネルで重みづけてバリューベクトルを足し合わせることであり、カーネルとして を使う必然性はありません。カーネルの中には有限次元の特徴マップで表せるものも数多くあり、そのようなカーネルを用いるとこの問題は無くなります。注意機構においてどのようなカーネルを使うと、精度と効率が両立できるかを考えることは、現在でも盛んに研究されているトピックです。線形トランスフォーマー [Katharopoulos+ ICML 2020] などが成果の代表例です。高速化への応用
以上より、トランスフォーマーには従来的なトランスフォーマーモード式 (1) と、RNN モード式 (5, 6) という、二つの等価な見方ができます。
よく、「トランスフォーマーは RNN よりも並列的な訓練がしやすい」「トランスフォーマーは二乗時間かかり RNN よりも計算量が大きい」などと言われますが、これは見方をどちらか一方に固定したときの話であり、トランスフォーマーの内在的な性質ではありません。同じトランスフォーマーモデルでも、トランスフォーマーモード式 (1) と、RNN モード式 (5, 6) の両方で走らすことができます。訓練時には並列的な訓練がしやすいトランスフォーマーモードで動かし、推論時にはメモリ消費量と計算量の小さい RNN モードで動かす、というように、同一のモデルに対しても柔軟な対応が可能です。
分母の扱い
細かい点ですが、ここまで分母
を無視してきました。この項も、カーネル法の見方を身に付けた今、トランスフォーマーは RNN である(再々)
以上の議論より、トランスフォーマーと RNN は二項対立的なものではなく、むしろ同質のものを異なる角度から眺めたものであると考えることができます。
モデルの定量的な性質としては、状態 の次元が大きな影響を持ちます。ガウスカーネルや、トランスフォーマーでよく用いられる内積の指数を用いると、無限次元のベクトル空間となります。この空間は非常に広大なので、原理的には無限のことを記憶できますが、その分、効率的に計算するのが難しかったり、過学習をしたり、訓練に多くのデータが必要だったりします。他のカーネルを使ったり、主成分分析の要領で近似を行うと、有限次元になり、次元を下げるにつれて計算効率とサンプル効率が向上します。これらについても、二者択一というよりは連続的なスペクトルをなしており、一方の極端に、従来のトランスフォーマーで用いられている内積の指数によるカーネルがあり、そこに高次元の RNN、中程度の次元の RNN、低次元の RNN が連なっているとみることができます。
データの性質によってカーネルや圧縮度合いを選択することができます。例えば、テキストは離散的であり、各トークンの粒度が大きく、少しでもトークンを忘れてしまうと将来の応答で問題が生じる可能性があります。よって、テキストではあまり圧縮は行わず、従来のトランスフォーマーのような無限次元やそれに近い状態空間を用いて、入力をほとんどそのままの形で記憶すると良いと考えられます。一方、動画や音声は連続的であり、一部のフレームをが抜け落ちても問題なく、かなりの程度圧縮ができるので、そのような場合にはより RNN 的な、圧縮率の高い、次元の低い状態空間を用いて、圧縮しながら処理すると良いと考えられます。同じドメインの中でも、分布やタスクの性質によって、どの程度まで圧縮できそうかを考えてこのスペクトルの中の位置を決めることができます。
ニューラルネットワークの損失地形 - Speaker Deck などでもお話しましたが、この「圧縮」と「汎化」と「効率」の関係はモデルの振る舞いを考えるうえでとても重要です。トランスフォーマーを RNN として見ることで、圧縮度合いを状態の次元数という形で陽に表すことができるようになり、この三者関係を意識的に取り扱えるようになるという点でも、この見方を身に付けることは有用です。
おわりに
本稿に興味を持った方はぜひ『深層ニューラルネットワークの高速化』も読んでいただけると嬉しいです。
連絡先: @joisino_ / https://joisino.net