Gaussian Splattingを試す

概要理解

 流石に中身を全く理解しないままに動かして結果だけ得るのもつまらないと思って、簡単なガウス球3つをレンダリングする部分だけ簡単に実装して原理を確認した。

実装と結果

サンプル実装

from dataclasses import dataclass
import torch
import roma


@dataclass
class GaussianPoint:
    position = torch.zeros((3))
    rotation = torch.tensor([0.0, 0.0, 0.0, 1.0])
    scale = torch.ones((3))
    color = torch.zeros((3))
    opacity = torch.tensor(0.5)

    # when projected to the screen
    i = torch.tensor(0.0)
    j = torch.tensor(0.0)
    z = torch.tensor(0.0)
    sigma_2d = torch.zeros((2, 2))


def render_image(T: torch.Tensor, # extrinsic translation (3,)
                 W: torch.Tensor, # extrinsic orientation (3, 3)
                 K: torch.Tensor, # intrinsic (3, 3)
                 height: int,
                 width: int,
                 gaussian_points: list[GaussianPoint]
                 ) -> torch.Tensor:
    for gaussian in gaussian_points:
        # convert to camera coordinate
        p_in_world = gaussian.position
        p_from_camera = W.T @ (p_in_world - T)
        p_in_screen = K @ p_from_camera
        i = p_in_screen[1] / p_in_screen[2]
        j = p_in_screen[0] / p_in_screen[2]
        if p_from_camera[2] <= 0:
            gaussian.z = float("inf")
            continue

        R = roma.unitquat_to_rotmat(gaussian.rotation)
        S = torch.diag(gaussian.scale)
        sigma_3d = R @ S @ S.T @ R.T

        # ref. EWA Volume Splatting p.6
        # https://www.cs.umd.edu/~zwicker/publications/EWAVolumeSplatting-VIS01.pdf
        # (with scale of K)
        l = torch.norm(p_from_camera)
        u0 = p_from_camera[0]
        u1 = p_from_camera[1]
        u2 = p_from_camera[2]
        J = torch.zeros((3, 3))
        J[0, 0] = 1 / u2
        J[1, 1] = 1 / u2
        J[0, 2] = -u0 / (u2 ** 2)
        J[1, 2] = -u1 / (u2 ** 2)
        J[2, 0] = u0 / l
        J[2, 1] = u1 / l
        J[2, 2] = u2 / l

        sigma_2d = J @ W @ sigma_3d @ W.T @ J.T
        sigma_2d = sigma_2d[0:2, 0:2]
        sigma_2d = K[0:2, 0:2] @ sigma_2d

        # update
        gaussian.i = i
        gaussian.j = j
        gaussian.z = u2
        gaussian.sigma_2d = sigma_2d

    # sort by z
    gaussian_points = sorted(gaussian_points, key=lambda x: x.z)

    rem_alpha = torch.ones((height, width), dtype=torch.float32)
    pred_image = torch.zeros((height, width, 3), dtype=torch.float32)
    I, J = torch.meshgrid(torch.arange(
        height), torch.arange(width), indexing="ij")
    for gaussian in gaussian_points:
        if gaussian.z == float("inf"):
            continue
        i = gaussian.i
        j = gaussian.j
        u2 = gaussian.z
        sigma_2d = gaussian.sigma_2d

        # calc inv and det
        sigma_2d_inv = torch.inverse(sigma_2d)
        sigma_2d_det = torch.det(sigma_2d)

        # calc diff vec
        diff_i = I - i
        diff_j = J - j
        diff_vec = torch.stack([diff_j, diff_i], dim=2)  # [h, w, 2]
        diff_vec = diff_vec.unsqueeze(-2)  # [h, w, 2] -> [h, w, 1, 2]

        # calc 2d gauss
        # [h, w, 1, 2] @ [2, 2] @ [h, w, 2, 1] -> [h, w, 1, 1]
        exp = -0.5 * diff_vec @ sigma_2d_inv @ diff_vec.transpose(-1, -2)
        exp = exp.squeeze()  # [h, w, 1, 1] -> [h, w]
        factor = 1 / (2 * torch.pi * sigma_2d_det)
        gauss = factor * torch.exp(exp)

        # calc color
        curr_weight = rem_alpha * gaussian.opacity * gauss
        add_color = curr_weight.unsqueeze(-1) * gaussian.color * 50000
        pred_image += add_color
        rem_alpha *= 1.0 - curr_weight

    pred_image = torch.clamp(pred_image, 0, 1)
    pred_image *= 255
    pred_image = pred_image.to(torch.uint8)
    return pred_image

 簡単に3つのGaussianを浮かべて実行した。

  • 青:X方向(画像左右方向)だけ分散大
  • 緑:Y方向(画像上下方向)だけ分散大
  • 赤:Z方向(画像奥行き方向)だけ分散大

 赤いのが横に伸びて見えるのは、左前方にあるものを斜めから眺めるような形になっているからだと思うが、もしかしたらなにか間違えているかもしれない。自信はない。

知見

 射影変換が非線形なのでテイラー展開で近似するというのはなるほど納得だった。この実装で、大本のEWA Volume Splattingの論文だと、カメラ座標系から光線座標系への変換において、Z座標として座標までの距離としている(なので射影変換というか純粋に座標変換?)が、Gaussian Splattingのcuda_rasterizerの実装だとよくある射影変換を偏微分した行列になっており、それで良いらしい。自分の実装だとカメラの内部パラメータをいつかければ良いのかがわからなくて2x2行列にした後でかけてみていた(そうでないと2回かかってしまう気がした)。

 色の計算がよくわからず、そのままだとGaussianの値が小さすぎるので色が付かなかった。仕方なく50000などとんでもない係数をかけているが、これは妙だ。まだ間違っているところはあるのだろうが、あまりデバッグに時間かける気力が出なかったのでここで切り上げた。

疑問点

  • ガウス球なので、不透明度を1にしても、すでに中央は確率密度としては1にならない。なので、完全に不透明な赤い球を表現するのに、ガウス球が一個では足りないし、薄くなっていくことも考えるとなんだかとても不揃いなガウス球をたくさん浮かべないといけなさそう。多分実際は完全な球とかは少ないので問題ない? というか確率密度関数を使うので良いのだっけ? ピクセルが1x1の面積を持つので、確率密度が確率の積分値と一致するということ?

実装しなかった部分

 ガウス球のレンダリングというところまでは20年前にもあるような話ではあるので、Gaussian Splatting論文の貢献部分とは言い難いかもしれない。Gaussian Splatting論文だと、その他ガウス球の生成削除アルゴリズム部分と、レンダリングの高速化パートという概ね2つの貢献があると見なせると思う。

 特に高速化パートの実装が大変そうだ。必要な処理として

  1. 画像を小さなタイルに分割し、そのタイルごとにガウス球がその範囲に十分被っているかチェックする
  2. ガウス球がどのタイルに属するかという情報と奥行き情報を1つの変数にパッキングする
  3. パッキングした変数についてGPU基数ソートを行う
  4. ソートされた結果から、タイルごとに分割して前からαブレンディングをする

などがあり、さらに勾配を計算することなどを考えると実装がとんでもないことになりそう。論文だけ見せられて「再現実装をしなさい」と言われたら、まず半年専念することは覚悟して、それでも全然できないかもしれない。ところで公式実装はライセンスが商用利用不可なので、頑張る人は再現実装をするのだろうか……。

公式を動かす

 AWSIMで取得したデータで試した。PCDとして西新宿の点群マップデータがあるので、それを読み込んで初期値として利用するようにした。scene/dataset_readers.pyに追加で読み込むようなコードを書けばとりあえず動いた。

使用したコード片

def readAutowareSceneInfo(path):
    import pandas as pd
    from scipy.spatial.transform import Rotation
    import yaml
    from copy import deepcopy
    camera_info = yaml.load(open(os.path.join(path, "camera_info.yaml"), 'r'), Loader=yaml.FullLoader)
    width = camera_info["width"]
    height = camera_info["height"]
    K = np.array(camera_info["K"]).reshape(3, 3)
    fx = K[0, 0]
    fy = K[1, 1]
    FovX = 2 * np.arctan(width / (2 * fx))
    FovY = 2 * np.arctan(height / (2 * fy))
    df_cam_pose = pd.read_csv(os.path.join(path, "pose.tsv"), sep="\t")
    train_cam_infos = []
    T0 = None
    min_pos = np.min(df_cam_pose[["x", "y", "z"]].values, axis=0)
    max_pos = np.max(df_cam_pose[["x", "y", "z"]].values, axis=0)
    for i in range(0, len(df_cam_pose), 3):
        row = df_cam_pose.iloc[i]
        R = Rotation.from_quat([row.qx, row.qy, row.qz, row.qw]).as_matrix()
        T = np.array([row.x, row.y, row.z])
        if i == 0:
            T0 = deepcopy(T)
        T -= T0
        c2w = np.eye(4)
        c2w[0:3, 0:3] = R
        c2w[0:3, 3] = T
        w2c = np.linalg.inv(c2w)
        R = w2c[0:3, 0:3].T
        T = w2c[0:3, 3]
        image_path = os.path.join(path, "images", f"{i:08d}.png")
        image = Image.open(image_path)
        train_cam_infos.append(
            CameraInfo(uid=i,
                       R=R,
                       T=T,
                       FovY=FovY,
                       FovX=FovX,
                       image=image,
                       image_path=image_path,
                       image_name=f"{i:08d}",
                       width=width,
                       height=height,
                       )
        )
    test_cam_infos = train_cam_infos

    nerf_normalization = getNerfppNorm(train_cam_infos)

    ply_path = os.path.join(path, "points3D.ply")
    if not os.path.exists(ply_path):
        print("Converting point3d.bin to .ply, will happen only the first time you open the scene.")
        import open3d as o3d
        loaded_pcd = o3d.io.read_point_cloud(f"{path}/pointcloud_map.pcd")

        print(f"before downsample: {len(loaded_pcd.points)}")
        margin = 100  # [m]
        min_bound = min_pos - margin
        max_bound = max_pos + margin
        print(f"{min_bound=}, {max_bound=}")
        bbox = o3d.geometry.AxisAlignedBoundingBox(min_bound, max_bound)
        loaded_pcd = loaded_pcd.crop(bbox)
        print(f"after  downsample: {len(loaded_pcd.points)}")

        xyz = np.asarray(loaded_pcd.points)
        xyz -= T0
        num_pts = len(xyz)

        shs = np.random.random((num_pts, 3)) / 255.0
        storePly(ply_path, xyz, SH2RGB(shs) * 255)
    try:
        pcd = fetchPly(ply_path)
    except:
        print("Error reading ply file")
        exit(1)

    scene_info = SceneInfo(point_cloud=pcd,
                           train_cameras=train_cam_infos,
                           test_cameras=test_cam_infos,
                           nerf_normalization=nerf_normalization,
                           ply_path=ply_path)
    return scene_info

結果

 あまり良い結果にならなかった。特に直進している場合に、なにも存在していないはずの道路上に浮いているものが多い。学習時の表示ではPSNR25あたり出ているということだったが、途中の崩壊しているところではそこまであるとは思えない。序盤の静止中などはきれいなので、平均値としてはそこで底上げされてしまっているのかもしれない。

 流石に単眼カメラの、視野に重複が多くはないシーンでは難しかったか。6方向カメラとかがあればまだ違うかもしれないが、これ以上試すのは大変なので、とりあえずここで切り上げる。