2020/10/08 08:36 更新
sparsemaxでsparseな活性関数
129 いいね ブックマーク
目次

From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification
に載っている数式がわからなかったのでまとめた。

前提知識

sparsemaxの定義

  • n次元ベクトル$z$の次元数を$K$とする(つまりただの長さ$K$の配列)

  • $\Delta^{K-1} := \left\{ p \in \mathbb{R}^K | 1^{\top} p = 1, p \geq 0 \right\}$

    • $z$の1行を足し込むと1.0になるようにするということ?
    • ちょっとよくわからなかった、誰か教えてください

としたとき、sparsemaxの定義は以下になる

${{\bf sparsemax}({\bf z}) := \underset{ p \in \Delta^{K-1} }{argmin} \|p-z\|^2 }$

定義の式の意味

  • arg min f(x)の意味 で書いたように、これは $\|p-z\|^2$ を最小にするような$p$の集合(というかベクトル)を求めるということを意味する。
  • $\|p-z\|^2$ のようなパイプが2本あるやつは絶対値ではなくユークリッドノルム($L^2$ノルム)を意味するらしく、展開すると以下のようになるはず
$$\begin{aligned} \|p-z\|^2 &= (\sqrt{(p_1-z_1)^2+(p_2-z_2)^2+\cdots +(p_n-z_n)^2})^2 \\ &= (p_1-z_1)^2+(p_2-z_2)^2+\cdots +(p_n-z_n)^2 \end{aligned}$$
  • こういうことをやるのをEuclidean projection(ユークリッド射影)というらしいが、資料が見つからない
    • ユークリッド射影じゃない射影が出てきたからユークリッド射影というのだろうか?誰か教えてください

どうやって$p$を求めるのか

アルゴリズムは以下になる、処理速度はインプットのn次元ベクトル$z$の次元数に依存し、$O(KlogK)$になるらしい。naiveとか書かれてるけど対数時間ならいいやん…

$$\begin{aligned} & \textbf{Input: \textit{z}} \\ & Sort\ \textbf{\textit{z}}\ as\ z_{(1)} \geq \cdots \geq z_{(K)} \\ & Find\ k(\textbf{\textit{z}}) := max \{ k \in [K]\ |\ 1 + kz_{(k)} > \Sigma_{j \leq k} z_{(j)} \} \\ & Define\ \tau(\textbf{\textit{z}}) = \frac{(\Sigma_{j \leq k} z_{(j)})-1}{k(z)} \\ & \textbf{Output: \textit{p}}\ \ s.t.\ p_i = [z_i - \tau(z)]_+. \\ \end{aligned}$$

上の疑似コードをRubyで書いてみた

#!/usr/bin/ruby

# Input: z
z = [23, 20, 5, 0, 8]

# Sort z as...最初のほうが大きくなるように
z = z.sort().reverse()

# Find k(z)
# z というベクターを受け取って
# (1) それぞれに対して1 + k*z[k] > z.take[k].inject(&:+) なものだけ集める
# (2) そのindexの最大値をとる
def k(z)
  k_vec = z.map.with_index(1) do |z_elem, k|
    sigma_z = z.take(k).inject(&:+)

    if 1 + k*z_elem > sigma_z
      k
    else
      nil
    end
  end.compact

  return k_vec.max
end

# Define τ(z)
def tau(z)  
  upper_bound = k(z)
  upper_bound_idx = upper_bound-1
  sigma_z = z.take(upper_bound_idx).inject(&:+)
  return (sigma_z -1).quo(upper_bound)
end

# [t]+ := max{0,t}
def t_plus(t)
  return [0, t].max
end

# Output p
p = z.map{|zi| t_plus(zi - tau(z))}
p p