最近なかなか統計モデルに取り組む時間が十分捻出できていないが、「ノンパラメトリックベイズ 点過程と統計的機械学習の数理」に入門を開始した。

ノンパラベイズに到達する前のデータ点をクラスに分類するクラスタリングの段階で出てきた、\(K\)-平均アルゴリズム(Wikipedia) を実装したので記録。使っているのはいつものようにJulia。

\(K\)-平均アルゴリズムではすでにクラス数 \(K \) は given として計算を進めていく。実装したアルゴリズムは本を参考にして下のようにした。


データ点を \( \mathbf{x}_i \in \R^d, i = 1, \ldots, n \), クラスタリングにより決定される、 各クラスを代表する点を \(\boldsymbol{\mu}_k \in \R^d, k = 1, 2, \ldots, K \) とする。

各データ点 \( x_i \) が属するクラスを \( z_i \in \{ 1, 2, \ldots, K \}\) とすると、 \( K \)-平均法では、 各データ点は \( (\boldsymbol{\mu}_k)_{k=1}^K \) のなかで一番(平方ユークリッド)距離が近い点が代表するクラスに分類されるので、 $$ z_i = \argmin_k || \mathbf{x}_i - \boldsymbol{\mu}_k||^2 $$ となる。

アルゴリズムは、目的関数

$$ f((z_i)_{i=1}^n, (\boldsymbol{\mu}_k)_{k=1}^K) = \sum_{i=1}^n || \mathbf{x}_i - \boldsymbol{\mu}_{z_i} ||^2 $$

が最小となるように \( (z_i)_{i=1}^n, (\boldsymbol{\mu}_k)_{k=1}^K \) を更新していく。

アルゴリズム

  1. \( (z_i)_{i=1}^n \) を乱数を用いて初期化する。集合 \( \{ z_i \mid i = 1, \ldots, n \} \) の個数が \( K \) ではない場合、個数が \( K \)になるまで \( (z_i)_{i=1}^n \) を乱数を用いて再初期化する
  2. $$ \boldsymbol{\mu}_k = \frac{1}{\sharp \{i \mid z_i = k \}} \sum_{i, z_i=k} \mathbf{x}_i $$ と代表点の更新を行う
  3. 以下のアルゴリズムを、 \( (z_i)_{i=1}^n, (\boldsymbol{\mu}_k)_{k=1}^K \) を更新することによる目的関数 \( f \) の減少が閾値以下になるまで繰り返す
    1. \( z_i = \argmin_k || \mathbf{x}_i - \boldsymbol{\mu}_k||^2, i = 1, \ldots, n \) とラベリングの更新を行う
    2. 集合 \( \{ z_i \mid i = 1, \ldots, n \} \) の個数が \( K \) ではない場合(=分類されるクラスが減少してしまった場合)、最初に戻って \( (z_i)_{i=1}^n \) の初期化からやり直す

代表点の位置によっては分類されたクラスの数が当初設定した \( K \)より小さくなってしまうことがあり、 そのまま計算を行うと、\( \boldsymbol{\mu}_k \) のいずれかが NaN になって計算できないので、 \( K \) を置き換えるか、ラベルを振り直すというのがすぐに思いつく方法だが、今回は後者の方法を採用した。


使ったデータは ScikitLearn.jl で作成。 下の図のような三種類のパターン (noisy_circles, noisy_moons, blobs) を用意して、サンプル点は1500個。 このページ が参考になる。

using ScikitLearn
using Statistics
using Plots
using Printf
using Random

@sk_import datasets: (make_circles, make_moons, make_blobs)

n_samples = 1500
noisy_circles = make_circles(n_samples=n_samples, 
                             factor=.5, noise=.05)
noisy_moons = make_moons(n_samples=n_samples, noise=.05)
blobs = make_blobs(n_samples=n_samples, random_state=8)

plts = []

for data in [noisy_circles, noisy_moons, blobs]
    points, label = data
    push!(plts, scatter(points[:, 1], points[:, 2], 
                        label="", mc=label))
end

Plots.plot(plts..., layout = (1, 3), size = [800, 400])

Dataset

本体の実装はこんな感じで、dist で \( f \) の値、 dist_change で \( f \) の変化を保持。

mutable struct KMean
    points::Array{Float64}
    K::Int64
    zs::Vector{Int64}
    μs::Array{Float64}
    dist::Float64
    dist_change::Float64
end

function KMean(points::Array{Float64}, K::Int64)
    n_points = size(points, 1)
    init_zs = nothing
    while true
        init_zs = rand(1:K, n_points)
        if Base.length(Set(init_zs)) == K
            break
        end
    end
    zs = init_zs
    kmean = KMean(points, K, zs, 
                  zeros(1, size(points, 2), K), 0, typemax(Float64))
    update_μs!(kmean)
    kmean
end

function update_μs!(km::KMean)
    km.μs = cat([mean(km.points[km.zs .== i, :], dims=1) for i in 1:km.K]..., dims=3)
    new_dist = sum((km.points - reshape(km.μs[:, :, km.zs], 2, :)').^2)
    km.dist_change = km.dist - new_dist
    km.dist = new_dist
    km
end

function update_zs!(km::KMean)
    norms = sum((km.points .- km.μs).^2, dims=2)
    argmin_norms = [x[3] for x in argmin(norms, dims=3)][:]
    
    while true
        if Base.length(Set(argmin_norms)) == km.K
            break
        end
        argmin_norms = rand(1:km.K, size(km.points, 1))
    end
    
    km.zs = argmin_norms
    km
end

function update!(km::KMean)
    update_zs!(km)
    update_μs!(km)
    km
end

それぞれのサンプルデータに対し、\( f \) の減少幅が0.1以下になるまでイテレーションを行う(最高100回)様子をアニメーションして、gifファイルに出力させると下のようになる。 Juliaはgifのアニメーションを作るのも楽でいいですね。

gifs = []

for (i, (data, K)) in enumerate(zip([noisy_circles, noisy_moons, blobs], [2, 2, 3]))
    points, label = data

    km = KMean(points, K)

    anim = @animate for n=1:100
        if abs(km.dist_change) < 0.01
            break
        end
        scatter(km.points[:, 1], km.points[:, 2], mc=km.zs, label="")
        plt = scatter!(km.μs[1, 1, :], km.μs[1, 2, :], ms=8, mc=:yellow, msw=3, label="", 
                       title=@sprintf("Iterations=%d", n))
        update!(km)
        plt
    end

    push!(gifs, gif(anim, @sprintf("clustering-kmean_%d.gif", i), fps = 1))
end

display.(gifs);

結果は下の通り。黄色の点が、\( (\boldsymbol{\mu}_k)_{k=1}^K \) に該当。

noisy_circlesnoisy_moons では次のような結果で、 代表点からの距離でクラスタリングするならばまあそうなるだろう、という結果。

noisy circles

noisy moons

blobs では一見うまくいくように見えるが、

blob 成功例

乱数による初期化次第で、下のように期待とは違う場所に収束することも起こり得る。

blob 失敗例

次は混合ガウスモデルのギブスサンプリングかな。無限次元への扉は遠い…

Jupyter Notebook (アニメーションは表示されない):
https://github.com/matsueushi/notebook_blog/blob/master/clustering.ipynb