Sparse and Constrained Attention for Neural Machine Translation を読んでわかったところをまとめた。
これを使うとseq2seqで生成したシーケンスに繰り返しが少なくなるという目論見がある。
前提知識
-
集合の記号がわかっていること
-
arg min f(x)の意味 がわかる
-
n次元ベクトル の定義を知っている
-
sparsemaxでsparseな活性関数 を見ておく
csparsemaxの定義
- 確率単体を定義する
$\begin{aligned} \Delta^J &:= \left\{ \alpha \in \mathbb{R}^J | \alpha \geq 0, \Sigma_j \alpha_j = 1 \right\} \end{aligned}$
- 式本体の定義
${{\bf csparsemax}({\bf z;u}) := \underset{ \alpha \in \Delta^{J} }{argmin} \|\alpha-z\|^2 }$
- これはsparsemaxととても似ている
- とは言っても、似ているのは定義だけで実装は全く異なる
${{\bf sparsemax}({\bf z}) := \underset{ p \in \Delta^{K-1} }{argmin} \|p-z\|^2 }$
$\alpha^*$を以下のように定義する
$\begin{aligned} \alpha^* &= csparsemax(z;u) \\ sets A &= \{j \in [J] \ | \ 0 < \alpha^*_j < u_j\}, \\ A_L &= \{j \in [J] \ | \ \alpha^*_j = 0\}, and \\ A_R &= \{j \in [J] \ | \ \alpha^*_j = u_j\} \end{aligned}$
今気づいたが$u$の定義がないじゃん…どうなってるの? → 下記で解決
$u$ (upper bounds)の定義
- これについては、論文には間接的にしか説明がなく、以下の動画で詳細がわかった
- Sparse and Constrained Attention for Neural Machine Translation
- ざっくり言うと、"fertilities"という仕組みを使って$u$を計算する
- 単語の位置(t = timestep)に従って $\alpha_\tau$ の総和を求める, $\alpha$はもちろん上記で求めたもの
式は以下のようになる
$\begin{aligned} & \alpha_t = {\bf csparsemax}(z_t\ ;\ \underbrace{f - \beta_{t-1} }_{u_t}) \\ & \beta_{t-1} := \Sigma_{\tau=1}^{t-1} \alpha_{\tau} \end{aligned}$
- $\beta_{t-1} := \Sigma_{\tau=1}^{t-1} \alpha_{\tau}$ の部分について
- これは論文の方で各ソースワードがこれまでに受けた累積的なattentionと書かれている通り累積なのでどんどん加算されるわけだが、f (=fertility) が定数で存在しているので、1回使うごとに確率が低くなる(=繰り返しが防げるというわけだ)
- ここでの注意としては、条件として$\alpha \ge 0\ \text{and}\ 0 < \alpha^*_j < u_j$ なので、少なくともuは0以上ということだ、uが0以下になったら0に補正する
というわけで、まとめると
- upperbounds $u$ はattentionを計算するたびに算出される
- attention自体については Attention is all you need (Vaswani, 2017) を確認するべし
- 詳細は実装を確認した https://github.com/Unbabel/OpenNMT-py/blob/dev/onmt/modules/constrained_sparsemax.py
どうやって$\alpha$を求めるのか
ここからはアルゴリズムに注釈を入れながら進める
初期化部分
-
入力値
- $a, b, c, \mathcal{W}$ ともにベクター(配列)、$d$はスカラー値
-
入力値の詳細は a.3 Linear-Time Evaluation によれば
$$\begin{aligned} a_j &= \frac{- z_j}{2} \\ b_j &= \frac{u_j-z_j}{2} \\ c_j &= 1 \\ d &= \frac{1-\Sigma_{j=1}^J z_j}{2} \\ \end{aligned}$$ -
$J$は入力の長さ
-
$\mathcal{P} \leftarrow \{ a_j, b_j \}^J_{j=1} \cup \{ \pm \infty \}$ ←ここの意味にちょっと悩んだが、$a, b$のベクターに±∞を入れておくという意味だと思う。競技プログラミング的ですね。だいたい番兵的に使われるやつ。
${\color{blue} \begin{aligned} & {\bf input:} a,b,c,d \\ & Initialize\ working\ set\ \mathcal{W} \leftarrow \{ 1, \cdots, J \} \\ & Initialize\ set\ of\ split\ points: \\ & \mathcal{P} \leftarrow \{ a_j, b_j \}^J_{j=1} \cup \{ \pm \infty \} \end{aligned} }$
ループ前の初期化
- $\xi$はクスィー
${\color{blue} \begin{aligned} & Initialize\ \tau_L \leftarrow -\infty, \tau_R \leftarrow \infty, s_{tight} \leftarrow 0, \xi \leftarrow 0. \end{aligned} }$
ループ
-
ここから怒涛の∑祭りになる
- なんか添字の間違いがありそう、勝手に直した
-
$Median$ってなんだ?
- どうやら集合の中央値を求めているようだ、以下の論文で提案されている手法
- Two papers on the selection problem: Time Bounds for Selection and Expected Time Bounds for Selection
- 中央値の中央値 (median of medians)
-
"Reduce set of split points" の部分
- $\mathcal{P} \leftarrow \mathcal{P}\ \cap [\tau_L, \tau_R]$ ← この角カッコは閉区間なので、実質以下の意味
- $\mathcal{P} \leftarrow \mathcal{P}\ \cap \{ x \in \mathbb{R}\ |\ \tau_L \leq x \leq \tau_R \}$ ハマった…
${\color{blue} \begin{aligned} & {\bf while}\ \mathcal{W} \neq \varnothing\ {\bf do} \\ & \qquad {\sf Compute}\ \tau \leftarrow Median(\mathcal{P}) \\ & \qquad {\sf Set}\ s \leftarrow s_{tight} + \Sigma_{j \in \mathcal{W}\ |\ b_j<\tau} c_j b_j + \Sigma_{j \in \mathcal{W}\ |\ a_j>\tau} c_j a_j + (\xi + \Sigma_{j \in \mathcal{W}\ |\ a_j \leq \tau \leq b_j} c_j ) \tau \\ & \qquad {\sf If}\ s \leq d,\ {\sf set}\ \tau_L \leftarrow \tau; {\sf If}\ s \geq d,\ {\sf set}\ \tau_R \leftarrow \tau \\ & \qquad {\sf Reduce\ set\ of\ split\ points:}\ \mathcal{P} \leftarrow \mathcal{P}\ \cap [\tau_L, \tau_R] \\ & \qquad {\sf Update\ tight\ sum:\ } s_{tight} \leftarrow s_{tight} + \Sigma_{j \in \mathcal{W}\ |\ b_j \lt \tau_L} c_j b_j + \Sigma_{j \in \mathcal{W}\ | a_j \gt \tau_R} c_j a_j \\ & \qquad {\sf Update\ slack\ sum:\ } \xi \leftarrow \xi + \Sigma_{j \in \mathcal{W}\ |\ a_j \leq \tau_L \land b_j \geq \tau_R} c_j \\ & \qquad {\sf Update\ working\ set:\ } \mathcal{W} \leftarrow \{ j \in \mathcal{W}\ |\ \tau_L \lt a_j \lt \tau_R\ \lor\ \tau_L \lt b_j \lt \tau_R \} \\ & {\bf end\ while} \end{aligned} }$
ループ終了後
${\color{blue} \begin{aligned} & {\bf Define\ } y^* \leftarrow (d - s_{tight})\ /\ \xi \\ & {\bf Set\ } x^*_j = max\{a_j, min\{b_j, y\}\}, \forall_j \in [J] \\ & {\bf output\ } x^*. \\ \end{aligned} }$
終わった、長かった。