めっくろぐ

mechlog - メモ帳

単純なRANSACをMATLABで試す

原理のお勉強用に,MATLABでとても単純なRANSACアルゴリズムを作ってみました.

下記のQiita記事にあるPython実装を参考にさせて頂き,作ってみました.
qiita.com

clear; close all;

%% Ground truth and synthetic data
% True line y = a*x + b used to generate the samples.
a = 0.5;
b = 0.3;

% Samples on x = 0:0.01:10. Every point gets unit-variance Gaussian noise;
% about 1 in 100 points (randi(100) == 1) additionally gets a positive
% spike of rand * 1000, acting as an outlier.
% The random draws happen in the same order as before: randn, randi, rand.
x = 0:0.01:10;
points = a * x + b ...
    + randn(size(x)) ...
    + (randi(100, 1, length(x)) == 1) .* rand(size(x)) * 1000;

%% Least-squares fit, for comparison
% Reference: https://jp.mathworks.com/help/matlab/data_analysis/programmatic-fitting.html
model_poly = polyfit(x, points, 1);
y_poly = polyval(model_poly, x);

%% RANSAC fit
model_ransac = ransac2D(x, points);
y_ransac = applyModel(model_ransac, x);

%% Plot raw samples, least-squares line and RANSAC line (one row per series)
plot(x, [points; y_poly; y_ransac])
ylim([0 10])
xlabel('x'); ylabel('y');
legend({sprintf('raw: a = %0.3f, b = %0.3f', a, b), ...
        sprintf('polyfit: a = %0.3f, b = %0.3f', model_poly(1), model_poly(2)), ...
        sprintf('RANSAC: a = %0.3f, b = %0.3f', model_ransac(1), model_ransac(2))}, ...
       'Location', 'best');

%%
function out = ransac2D(x, points, max_iter, t, inlier_ratio)
%RANSAC2D Fit a line y = a*x + b to 2-D points with a simple RANSAC loop.
%   OUT = RANSAC2D(X, POINTS) repeatedly draws two distinct sample indices
%   and fits the line through them. It then counts how many of the OTHER
%   points have an absolute vertical error <= T against that line.
%   Candidates whose inlier count exceeds INLIER_RATIO * numel(POINTS) are
%   scored by their mean absolute vertical error over ALL points. The
%   lowest-scoring candidate is returned as OUT = [a b].
%
%   OUT = RANSAC2D(X, POINTS, MAX_ITER, T, INLIER_RATIO) overrides the
%   defaults MAX_ITER = 100, T = 2 and INLIER_RATIO = 0.8. Pass [] to keep
%   a default.
%
%   If no candidate reaches the required inlier count, the warning
%   'ransac2D:noConsensus' is issued. The candidate with the most inliers
%   is returned instead, or [NaN NaN] if every draw was degenerate.
%   Degenerate draws are pairs with equal x values, whose line is vertical
%   and cannot be written as y = a*x + b; they are skipped.

if nargin < 3 || isempty(max_iter),     max_iter = 100;     end % number of iterations
if nargin < 4 || isempty(t),            t = 2;              end % error threshold
if nargin < 5 || isempty(inlier_ratio), inlier_ratio = 0.8; end % required inlier fraction

sample_num = numel(points);
if sample_num < 2
    error('ransac2D:tooFewPoints', 'At least two points are required.');
end
d = sample_num * inlier_ratio;

good_models = zeros(0, 2);
good_model_errors = zeros(0, 1);
% Best candidate by inlier count, used only if nothing exceeds d.
fallback_model = [NaN NaN];
fallback_inliers = -1;

for iter = 1:max_iter
    % randperm(n, k) draws k distinct indices using base MATLAB only.
    sample.idx = sort(randperm(sample_num, 2));
    sample.x = x(sample.idx);
    sample.y = points(sample.idx);
    param = getParamWithSamples(sample);
    if ~all(isfinite(param))
        continue; % equal x values -> vertical line, skip this draw
    end

    % Absolute vertical error of every point against the candidate line.
    residuals = abs(applyModel(param, x) - points);

    % The two points that defined the line are excluded from the count.
    is_other = true(size(residuals));
    is_other(sample.idx) = false;
    inlier_num = sum(residuals(is_other) <= t);

    if inlier_num > fallback_inliers
        fallback_inliers = inlier_num;
        fallback_model = param;
    end

    if inlier_num > d
        good_models(end + 1, :) = param;
        good_model_errors(end + 1, 1) = mean(residuals(:));
    end
end

if isempty(good_models)
    warning('ransac2D:noConsensus', ...
        ['No candidate exceeded %g inliers; returning the candidate with ' ...
         'the most inliers (%d).'], d, fallback_inliers);
    out = fallback_model;
    return;
end

[~, best_index] = min(good_model_errors);
out = good_models(best_index, :);
end

function param = getParamWithSamples(samples)
%GETPARAMWITHSAMPLES Slope and intercept of the line through two samples.
%   PARAM = GETPARAMWITHSAMPLES(SAMPLES) reads SAMPLES.x and SAMPLES.y and
%   returns the 1x2 row PARAM = [a b] of the line y = a*x + b through
%   (SAMPLES.x(1), SAMPLES.y(1)) and (SAMPLES.x(2), SAMPLES.y(2)).
%   If the two x values are equal, the slope is non-finite (Inf or NaN).
dx = samples.x(2) - samples.x(1);
dy = samples.y(2) - samples.y(1);
slope = dy / dx;
param = [slope, samples.y(1) - slope * samples.x(1)];
end

function out = getError(model, p)
%GETERROR Absolute vertical error of point P = [x y] against line MODEL.
%   OUT = |applyModel(MODEL, x) - y|, i.e. |a*x + b - y| for MODEL = [a b].
residual = applyModel(model, p(1)) - p(2);
out = abs(residual);
end

function out = applyModel(model, x)
%APPLYMODEL Evaluate the line y = a*x + b, MODEL = [a b], at every element of X.
out = model(1) * x + model(2);
end


緑の線で示したRANSACの推定結果は,外れ値に引っ張られた最小二乗法(polyfit)の結果と比べて,真値に近くなっています.