今までは平方完成とか特定の分布の関数形への「当てはめ」をしてパラメータの事後分布を求めていたが,変分推論ではパラメータの分布をそれぞれの積として表し,平方完成とは違う形で関数形を求める.全分布 $p(X, Z)$ の形から関数形が自動的に決まる.
概要
観測データ $\mathbf{X}$ が与えられたときに,潜在変数の事後分布 $p(\mathbf{Z} \mid \mathbf{X})$ を求めたい.そこでまず我々人間が主観的に同時分布 $p(\mathbf{X}, \mathbf{Z})$ の関数形を勝手に決める.そこから事後分布 $p(\mathbf{Z} \mid \mathbf{X})$ とモデルエビデンス $p(\mathbf{X})$ の近似を求めたい.まず
$$\begin{align*} \ln p(\mathbf{X})=\mathcal{L}(q)+\mathrm{KL}(q | p) \end{align*}$$
と分解する.ここで
$$\begin{align*} \mathcal{L}(q) &= \int q(\mathbf{Z}) \ln \left\{ \frac{p(\mathbf{X}, \mathbf{Z})}{q(\mathbf{Z})} \right\} d \mathbf{Z} \\ \mathrm{KL}(q | p) &= -\int q(\mathbf{Z}) \ln \left\{ \frac{p(\mathbf{Z} \mid \mathbf{X})}{q(\mathbf{Z})} \right\} d \mathbf{Z} \end{align*}$$
である.そして変分下界 $\mathcal{L}(q)$ を最大化することでKL情報量を最小化する.
公式
ここで $q(\mathbf{Z})$ が
$$\begin{align*} q(\mathbf{Z})=\prod_{i=1}^{M} q_{i}(\mathbf{Z}_{i}) \end{align*}$$
のように分解できると仮定する.結論から言うと,この時変分下界を最大化するのは,各パラメータの密度関数が以下を満たす場合である.
$$\begin{align*} \ln q_{j}^{\star}(\mathbf{Z}_{j})=\mathbb{E}_{i \neq j}[\ln p(\mathbf{X}, \mathbf{Z})] + \text{ const } \end{align*}$$
多変数ガウス分布
確率変数 $z = (z_1, z_2)$ の同時分布を $p(z) = \mathcal{N}(z \mid \mu, \Lambda^{-1})$ とする.そして以下のように分割表示する($\mu = [\mu_1; \mu_2]$).
$$\begin{align*} \Lambda=\begin{bmatrix} \Lambda_{11} & \Lambda_{12} \\ \Lambda_{21} & \Lambda_{22} \end{bmatrix} \end{align*}$$
すると上の公式 より
$$\begin{align*} \ln q_{1}^{\star} (z_{1})=\mathbb{E}_{z_{2}}[\ln p(\mathbf{z})]+\text { const } \end{align*}$$
を得る.ここで求めたいのは $z_1$ の関数であるから, $z_1$ が現れない項は全て定数項に含めてしまってよい(これは他のケースでも毎回使う).よって右辺は
だけで十分.これより $q^{\star}(z_1)$ はガウス分布であることが分かる.同様に $z_2$ についても対称な式が得られて,近似関数として以下が得られる.
ここでそれぞれの平均の中に他方の平均値が入っているため,片方を更新した後もう片方を更新することを収束するまで繰 り返すことになる.
1変数ガウス分布
ガウス分布 $\mathcal{N}(x \mid \mu, \tau)$ から発生したと考えられるデータ集合 $\mathcal{D} = \{ x_1, x_2, \cdots, x_N \}$ の尤度関数は
$$\begin{align*} p(\mathcal{D} \mid \mu, \tau)=\left(\dfrac{\tau}{2 \pi}\right)^{N / 2} \exp \left\{-\dfrac{\tau}{2} \sum_{n=1}^{N} (x_{n}-\mu)^{2}\right\} \end{align*}$$
である.さらに $\mu, \tau$ の共役事前分布として
$$\begin{align*} p(\mu \mid \tau) &=\mathcal{N} (\mu \mid \mu_{0}, (\lambda_{0} \tau)^{-1}) \ p(\tau) &=\operatorname{Gam} (\tau \mid a_{0}, b_{0}) \end{align*}$$
を導入する.この3つの積が $p(\mathbf{X}, \mathbf{Z})$ である.これを使って $q(\mu), q(\tau)$ を求めてみる.
公式 より
となる.ここで1行目の右辺で $p(\tau)$ を省略しているのは, $\mu$ の関数を求めているので $\tau$ の関数は定数項にくくられるためである.よって $q(z_1)$ は以下の平均と分散を持つガウス分布となる.
$$\begin{align*} \mu_{N} &= \dfrac{\lambda_{0} \mu_{0}+N \bar{x}}{\lambda_{0}+N} \\ \lambda_{N} &= (\lambda_{0}+N) \mathbb{E}[\tau] \end{align*}$$
同様に $\tau$ については
となる($\tau$の関数を求めているので上とは違って $p(\tau)$ は省略できない).よって $q(z_2)$ はガンマ分布 $\mathrm{Gam}(a_N, b_N)$
$$\begin{align*} a_{N} &= a_{0}+\dfrac{N+1}{2} \\ b_{N} &= b_{0}+\dfrac{1}{2} \mathbb{E}_{\mu}\left[\sum_{n=1}^{N} (x_{n}-\mu)^{2}+\lambda_{0} (\mu-\mu_{0})^{2}\right] \end{align*}$$
となる.
混合ガウス分布
数式が非常にややこしくなる.観測データ $\mathbf{Z} = \{ x_1, \cdots , x_N \}$ それぞれに対してどのクラスから発生したかを表す潜在変数 $\mathbf{Z} = \{ z_1, \cdots, z_N \}$ がある.混合係数を $\boldsymbol{\pi}$ とすると
$$\begin{align*} p(\mathbf{Z} \mid \boldsymbol{\pi})=\prod_{n=1}^{N} \prod_{k=1}^{K} \pi_{k}^{z_{n k}} \end{align*}$$
これはそれぞれのデータ $z_n$ についてそれが属するクラスの確率を求めている.よって属しているクラスが既知である場合の観測データの分布は
$$\begin{align*} p(\mathbf{X} \mid \mathbf{Z}, \boldsymbol{\mu}, \boldsymbol{\Lambda})=\prod_{n=1}^{N} \prod_{k=1}^{K} \mathcal{N} (\mathbf{x}_{n} \mid \boldsymbol{\mu}_{k}, \boldsymbol{\Lambda}_{k}^{-1})^{z_{n k}} \end{align*}$$
である.次にパイパーパラメーターの事前分布として共役な事前分布を使っていく.$\boldsymbol{\pi}$ に対してはディレクレ分布
$$\begin{align*} p(\boldsymbol{\pi})=\operatorname{Dir} (\boldsymbol{\pi} \mid \boldsymbol{\alpha}_{0})=C(\boldsymbol{\alpha}_{0}) \prod_{k=1}^{K} \pi_{k}^{\alpha_{0}-1} \end{align*}$$
$\boldsymbol{\mu}, \boldsymbol{\Lambda}$ についてはガウス-ウィシャート分布
$$\begin{align*} p(\boldsymbol{\mu}, \boldsymbol{\Lambda}) = \prod_{k=1}^{K} \mathcal{N} (\boldsymbol{\mu}_{k} \mid \mathbf{m}_{0}, (\beta_{0} \boldsymbol{\Lambda}_{k})^{-1}) \mathcal{W} (\boldsymbol{\Lambda}_{k} \mid \mathcal{W}_{0}, \nu_{0}) \end{align*}$$
を用いる.以下のようにここまで出てきた分布をかけることで全分布が求められる.
パラメータの事後分布として
$$\begin{align*} q(\mathbf{Z}, \boldsymbol{\pi}, \boldsymbol{\mu}, \boldsymbol{\Lambda})=q(\mathbf{Z}) q(\boldsymbol{\pi}, \boldsymbol{\mu}, \boldsymbol{\Lambda}) \end{align*}$$
の形を仮定する.公式 より
$$\begin{align*} \ln q^{\star}(\mathbf{Z})=\mathbb{E}_{\pi, \mu, \Lambda}[\ln p(\mathbf{X}, \mathbf{Z}, \boldsymbol{\pi}, \boldsymbol{\mu}, \boldsymbol{\Lambda})]+\text { const. } \end{align*}$$
により $q(\mathbf{Z})$ を求めていく.ここで左辺は $\mathbf{Z}$ の関数であるから,右辺のうちこれが関係する項を抜き出して残りを定数項に含めると
$$\begin{align*} \ln q^{\star}(\mathbf{Z})=\mathbb{E}_{\boldsymbol{\pi}}[\ln p(\mathbf{Z} \mid \boldsymbol{\pi})]+\mathbb{E}_{\boldsymbol{\mu}, \boldsymbol{\Lambda}}[\ln p(\mathbf{X} \mid \mathbf{Z}, \boldsymbol{\mu}, \boldsymbol{\Lambda})]+\text { const. } \end{align*}$$
となる.式を簡単にすると以下のようになる.
正規化すると
$$\begin{align*} q^{\star}(\mathbf{Z}) &= \prod_{n=1}^{N} \prod_{k=1}^{K} r_{n k}^{z_{n k}} \\ r_{nk} &= \dfrac{\rho_{nk}}{\sum_j \rho_{nj}} \end{align*}$$
が得られる.これまでと同様にこの負荷率 $\rho_{nk}$ は他のパラメータの期待値に依存するので,片方を更新した後もう片方を更新することを繰り返す.同様に $q(\boldsymbol{\pi}, \boldsymbol{\mu}, \boldsymbol{\Lambda})$ は
となるが,変数が $\boldsymbol{\pi}$ とそれ以外で分かれているので
$$\begin{align*} \ln q^{\star}(\boldsymbol{\pi}, \boldsymbol{\mu}, \boldsymbol{\Lambda}) = q(\boldsymbol{\pi}) \prod_{k=1}^{K} q(\boldsymbol{\mu}_k, \boldsymbol{\Lambda}_k) \end{align*}$$
という形になる.ひたすら平方完成することで最終的に以下を得られる.
$\rho_{nk}$ を求める際に必要な値は以下の通りである.
変分下界
変分ベイズ法はKL情報量 $\mathrm{KL}(p | q)$ を最小化するために変分下界 $\mathcal{L}(q)$ を最大化している.この下界の値は各繰り返しのステップにおいて減少しないはずであるため,この値を計算することで実装のチェックを行うことができる.混合ガウス分布ではこの変分下界は以下のように与えられる.
ここで各項の値は以下の通りである.
予測分布
与えられたデータ集合の下で上述のように各パラメータを求めると,入力 $\boldsymbol{x}$ に対する予測分布を求めることができる.$\boldsymbol{x}$ には対応する潜在変数 $\hat{\boldsymbol{z}}$ が存在しているため,予測分布はそれらの和をとることで計算できる.
これは最終的にスチューデントのt分布の重ね合わせになる.
$$\begin{align*} p(\widehat{\mathrm{x}} \mid \mathbf{X}) &\simeq \frac{1}{\widehat{\alpha}} \sum_{k=1}^{K} \alpha_{k} \mathrm{St}\left(\widehat{\mathrm{x}} \mid \mathrm{m}_{k}, \mathrm{~L}_{k}, \nu_{k}+1-D\right) \\ \mathbf{L}_{k} &= \frac{(\nu_{k}+1-D) \beta_{k}}{(1+\beta_{k})}{\mathbf{W}_{k}} \end{align*}$$
ここで $\nu_k$ は上の変分事後分布で得られた値である.
線形回帰
入力 $\boldsymbol{x}_n$ に対して出力 $t_n$ が得られているとし,これを線形回帰モデル
$$\begin{align*} y \sim \mathcal{N}(\boldsymbol{w}^{\text{T}}, \beta^{-1}) \end{align*}$$
でモデル化する(φは適当な基底関数).wに対する尤度関数と事前分布を
$$\begin{align*} p(\mathbf{t} \mid \mathbf{w}) &= \prod_{n=1}^{N} \mathcal{N} (t_{n} \mid \mathbf{w}^{\mathrm{T}} \boldsymbol{\phi}_{n}, \beta^{-1}) \\ p(\mathbf{w} \mid \alpha) &= \mathcal{N} (\mathbf{w} \mid \mathbf{0}, \alpha^{-1} \mathbf{I}) \end{align*}$$
とする.ここでαに共役事事前分布
$$\begin{align*} p(\alpha) = \mathrm{Gam}(\alpha \mid a_0, b_0) \end{align*}$$
を導入する.すると全分布はこれらの積となる.
$$\begin{align*} p(\mathbf{t}, \mathrm{w}, \alpha) = p(\mathbf{t} \mid \mathbf{w}) p(\mathbf{w} \mid \alpha) p(\alpha) \end{align*}$$
変分事後分布
今までと同様にハイパーパラメータの分布 $q(\boldsymbol{w}, \alpha)$ を $q(\boldsymbol{w})q(\alpha)$ と変数同士の積に分解して公式に沿ってそれぞれを求めてみる.
αに依存する部分だけを取り出すと
$$\begin{align*} \ln q^{\star}(\alpha) &=\ln p(\alpha)+\mathbb{E}_{\mathbf{w}}[\ln p(\mathbf{w} \mid \alpha)]+\text { const } \\ &= (a_{0}-1) \ln \alpha-b_{0} \alpha+\dfrac{M}{2} \ln \alpha-\dfrac{\alpha}{2} \mathbb{E} [\mathbf{w}^{T} \mathbf{w}]+\text { const } \end{align*}$$
となるので
$$\begin{align*} a_{N} &=a_{0}+\dfrac{M}{2} \\ b_{N} &=b_{0}+\dfrac{1}{2} \mathbb{E} [\mathrm{w}^{\mathrm{T}} \mathrm{w}] \end{align*}$$
なるガンマ分布 $q^{\star}(\alpha) = \mathrm{Gam}(\alpha \mid a_N, b_N)$ となる.
同様に
$$\begin{align*} \ln q^{\star}(\mathbf{w}) &=\ln p(\mathbf{t} \mid \mathbf{w})+\mathbb{E}_{\alpha}[\ln p(\mathbf{w} \mid \alpha)]+\text { const } \\ &-\dfrac{1}{2} \mathbf{w}^{\mathrm{T}} (\mathbb{E}[\alpha] \mathbf{I}+\beta \Phi^{\mathrm{T}} \boldsymbol{\Phi}) \mathbf{w}+\beta \mathbf{w}^{\mathrm{T}} \boldsymbol{\Phi}^{\mathrm{T}} \mathbf{t} \\ \Phi^{\text{T}} &= [\boldsymbol{\phi}(\boldsymbol{x}_1), \cdots, \boldsymbol{\phi}(\boldsymbol{x}_M)] \end{align*}$$
であり,これは平方完成すると
$$\begin{align*} \mathrm{m}_{N} &= \beta \mathbf{S}_{N} \boldsymbol{\Phi}^{\text{T}} \mathbf{t} \\ \mathbf{S}_{N} &= (\mathbb{E}[\alpha] \mathbf{I} + \beta \mathbf{\Phi}^{\text{T}} \boldsymbol{\Phi})^{-1} \end{align*}$$
をパラメータにもつガウス分布 $q^{\star}(\boldsymbol{w}) = \mathcal{N}(\boldsymbol{m}_N, \mathbf{S}_N)$ となる.