ガウス過程の実装練習

 ベイズ最適化に興味がありガウス過程を学びたいと思ったので以下の本を読んだ。

 一周目なので細かい式変形は追っておらず、まず大枠の導出の流れを理解することで妥協したが、結局実装部分がわからないとピンとこないので、そこだけPythonで実装した。といっても本に疑似コードが載っており、それをほぼそのまま写すだけ。

 まず、1次元の関数 f(x) = x ^ 3 - 2 xが未知のものとしてあり、 x = -5, -4, \dots, 4, 5の11点についてノイズありで値が与えられるので関数を回帰する問題設定でやってみた。

main_1d.py

"""Gaussian Processの実装練習."""

import matplotlib.pyplot as plt
import numpy as np


def f(x: float) -> float:
    return x**3 - 2 * x


def radial_basis_function(x1: np.ndarray, x2: np.ndarray) -> float:
    diff = x1 - x2
    norm = np.linalg.norm(diff)
    theta1 = 100
    theta2 = 1
    return theta1 * np.exp(-norm / theta2)


def predict(
    x_test: np.ndarray, y_train: np.ndarray, x_train: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    assert len(x_train) == len(y_train)
    n = len(x_train)
    m = len(x_test)
    K = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            K[i][j] = radial_basis_function(x_train[i], x_train[j])

    K += 25 * np.eye(n)

    K_inv = np.linalg.inv(K)
    yy = K_inv @ y_train
    mu = []
    var = []
    for i in range(m):
        k = np.zeros((n, 1))
        for j in range(n):
            k[j, 0] = radial_basis_function(x_train[j], x_test[i])
        curr_mu = k.T @ yy
        curr_var = radial_basis_function(x_test[i], x_test[i]) - k.T @ K_inv @ k
        mu.append(curr_mu[0])
        var.append(curr_var[0, 0])
    return np.array(mu), np.array(var)


if __name__ == "__main__":
    MIN_X = -5
    MAX_X = 5
    x = np.arange(MIN_X, MAX_X, 1, dtype=np.float32)
    n = len(x)
    y = f(x)

    # ノイズ付与
    y += np.random.normal(0, 20, n)

    WIDTH = (MAX_X - MIN_X) * 0.25
    x_list = np.arange(MIN_X - WIDTH, MAX_X + WIDTH, 0.1, dtype=np.float32)
    mean_list, variance_list = predict(x_list, y, x)
    gt_list = f(x_list)

    plt.plot(x_list, mean_list, label="predict_mean")
    plt.fill_between(
        x_list, mean_list - variance_list, mean_list + variance_list, alpha=0.3
    )
    plt.plot(x_list, gt_list, linestyle=":", label="GT(y=x**3 - 2 * x)")
    plt.scatter(x, y, marker="x", label="data", color="black")
    plt.legend()
    save_path = "./plot_1d.png"
    plt.savefig(save_path, bbox_inches="tight", pad_inches=0.05)
    print(f"Saved to {save_path}")
    plt.close()

 やってみるとカーネル関数(今回だとRadial Basis Function)のパラメータを上手く設定しないと想定していたようなグラフにならず、そこに気づくのが難しかった。上の図だと手動でいくつか試して \theta _ 1 = 100, \theta _ 2 = 1という値になっている。これがどの程度それっぽいのかはよくわからない。もちろん本でもハイパーパラメータの決め方についてはきちんとフォローされており、パラメータの推定方法まできちんと載っている。微分は大変そうに見えたので、今回は適当な山登り法である程度調整した。

 xがサンプリングされる区間が一定だと見た目が面白くないのでそれもちょっと変えて、パラメータは \theta _ 1 = 1079.53, \theta _ 2 = 1.25となった。 \theta _ 1はまだまだ大きいほうが良さそうなのでちょっと怖い。

 GTのオレンジ点線がある程度は 1\sigma区間に入っているのが正当と考えると、これくらい大きくするのが当たり前か。多分これで間違ってない、と思う。まぁ真剣に使うときになったらもっとよく調べるということで。

 自分の考えている問題だと2Dまでは最低限欲しいと思われるので、xyの2Dでも試した。

main_2d.py

"""Gaussian Processの実装練習."""

import matplotlib.pyplot as plt
import numpy as np


def f(x: np.ndarray) -> float:
    mu = np.array([[0, 0]])
    sigma = [
        [0.2, 0.0],
        [0.0, 10.0],
    ]
    det = np.linalg.det(sigma)
    inv = np.linalg.inv(sigma)
    const = 1.0 / (2.0 * np.pi * np.sqrt(det))
    diff = x - mu
    exp_term = np.exp(-0.5 * diff @ inv @ diff.T)
    exp_term = np.diag(exp_term)
    return const * exp_term


def radial_basis_function(x1: np.ndarray, x2: np.ndarray) -> float:
    diff = x1 - x2
    norm = np.linalg.norm(diff)
    theta1 = 100
    theta2 = 10
    return theta1 * np.exp(-norm / theta2)


def predict(
    x_test: np.ndarray, y_train: np.ndarray, x_train: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    assert len(x_train) == len(y_train)
    n = len(x_train)
    m = len(x_test)
    K = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            K[i][j] = radial_basis_function(x_train[i], x_train[j])

    K += 25 * np.eye(n)

    K_inv = np.linalg.inv(K)
    yy = K_inv @ y_train
    mu = []
    var = []
    for i in range(m):
        k = np.zeros((n, 1))
        for j in range(n):
            k[j, 0] = radial_basis_function(x_train[j], x_test[i])
        curr_mu = k.T @ yy
        curr_var = radial_basis_function(x_test[i], x_test[i]) - k.T @ K_inv @ k
        mu.append(curr_mu[0])
        var.append(curr_var[0, 0])
    return np.array(mu), np.array(var)


if __name__ == "__main__":
    MIN_X = -5
    MAX_X = 5
    # [-5, 5]の範囲で2次元のランダムな点
    x = np.random.uniform(MIN_X, MAX_X, (20, 2))
    n = len(x)
    y = f(x)

    print(f"{x.shape=}")
    print(f"{y.shape=}")

    # ノイズ付与
    # y += np.random.normal(0, 20, n)

    fig = plt.figure(figsize=(20, 10))

    # GTをプロット
    x_list = np.arange(MIN_X, MAX_X, 0.05, dtype=np.float32)
    x_list = np.array(np.meshgrid(x_list, x_list)).T.reshape(-1, 2)
    gt_list = f(x_list)
    ax = fig.add_subplot(311, projection="3d")
    ax.scatter(x_list[:, 0], x_list[:, 1], gt_list, c=gt_list)
    ax.set_xlim(MIN_X, MAX_X)
    ax.set_ylim(MIN_X, MAX_X)
    ax.set_title("GT")

    # サンプル点をプロット
    ax = fig.add_subplot(312, projection="3d")
    ax.scatter(x[:, 0], x[:, 1], y, c=y)
    ax.set_xlim(MIN_X, MAX_X)
    ax.set_ylim(MIN_X, MAX_X)
    ax.set_title("Sample points")

    WIDTH = (MAX_X - MIN_X) * 0.25
    # [-5, 5]の範囲で0.1刻みのxの2dリスト
    x_list = np.arange(MIN_X, MAX_X, 0.1, dtype=np.float32)
    x_list = np.array(np.meshgrid(x_list, x_list)).T.reshape(-1, 2)
    mean_list, variance_list = predict(x_list, y, x)
    gt_list = f(x_list)

    print(f"{x_list.shape=}")
    print(f"{mean_list.shape=}")
    print(f"{variance_list.shape=}")
    print(f"{gt_list.shape=}")

    # 3dプロット
    ax = fig.add_subplot(322, projection="3d")
    ax.scatter(x_list[:, 0], x_list[:, 1], mean_list, c=mean_list)
    ax.set_xlim(MIN_X, MAX_X)
    ax.set_ylim(MIN_X, MAX_X)
    ax.set_title("Predict mean")

    ax = fig.add_subplot(324, projection="3d")
    ax.scatter(
        x_list[:, 0],
        x_list[:, 1],
        mean_list + variance_list,
        c=mean_list + variance_list,
    )
    ax.set_xlim(MIN_X, MAX_X)
    ax.set_ylim(MIN_X, MAX_X)
    ax.set_title("Predict mean + var")

    plt.tight_layout()

    save_path = "./plot_2d.png"
    plt.savefig(save_path, bbox_inches="tight", pad_inches=0.05)
    print(f"Saved to {save_path}")
    plt.close()

 GTとして、y方向にだけ分散値が大きい正規分布のようなものを考えて、それに対して20点分評価できるサンプル点がもらえるので、そこから関数を回帰する。

 また、ベイズ最適化をちょっと意識して、予測した平均と分散を足したもの(簡易UCBみたいな気持ち)もプロットした。

 3Dでのプロットにしてしまったのでちょっと見づらいところもあるが、まぁだいたい近似はできていそうで、mean + varだと手薄になっている奥のところがちゃんと優先されそうというのも見えたのでだいたい満足。

 しかし、ベイズ最適化の文脈になると、ここから次に調べるべき点をUCBだったりEIだったりを基準にして決めていくわけだが、結局そういう代理関数の最適化パートが挟まることになり、これの極大点自体は解析的には求まらない(と理解している)ので、勾配法やサンプリングで見つけていかないといけない。そこでまた計算コストがかかってしまうので、やっぱり大元の対象がとても高コストのときにこそ有効なのかとは感じた。

 簡単な適用はできないかもしれないということがわかっただけでも収穫なので、ガウス過程は手段の一つとして取り出せる程度に頭の一部に置いておくこととして、いったん切り上げか。