Gaussian Process Regression

I am coding a Gaussian Process regression algorithm. Here is the code:

% Data generating function

fh = @(x)(2*cos(2*pi*x/10).*x);

% range

x = -5:0.01:5;
N = length(x);

% Sampled data points from the generating function

M = 50;
selection = boolean(zeros(N,1));
j = randsample(N, M);

% mark them

selection(j) = 1;
Xa = x(j);

% compute the function and extract mean

f = fh(Xa) - mean(fh(Xa));
sigma2 = 1;

% computing the interpolation using all x's
% It is expected that for points used to build the GP cov. matrix, the
% uncertainty is reduced...

K = squareform(pdist(x'));
K = exp(-(0.5*K.^2)/sigma2);

% upper left corner of K

Kaa = K(selection,selection);

% lower right corner of K

Kbb = K(~selection,~selection);

% upper right corner of K

Kab = K(selection,~selection);

% mean of posterior

m = Kab'*inv(Kaa+0.001*eye(M))*f';

% cov. matrix of posterior

D = Kbb - Kab'*inv(Kaa + 0.001*eye(M))*Kab;

% sampling M functions from the GP

[A,B,C] = svd(Kaa);
F0 = A*sqrt(B)*randn(M,M);
% mean from GP using sampled points

F0m = mean(F0,2);
F0d = std(F0,0,2);

%%
% put together data and estimation

F = zeros(N,1);
S = zeros(N,1);
F(selection) = f' + F0m;
S(selection) = F0d;

% sampling M function from posterior

[A,B,C] = svd(D);
a = A*sqrt(B)*randn(N-M,M);
% mean from posterior GPs

Fm = m + mean(a,2);
Fmd = std(a,0,2);
F(~selection) = Fm;
S(~selection) = Fmd;

%%

figure;
% show what we got...

plot(x, F, ':r', x, F-2*S, ':b', x, F+2*S, ':b'), grid on;
hold on;
% show points we got

plot(Xa, f, 'Ok');
% show the whole curve

plot(x, fh(x)-mean(fh(x)), 'k');
grid on;

I expected a figure in which the uncertainty is large at the unobserved points and small around the sampled data points. Instead I get an odd figure and, odder still, the uncertainty around the sampled data points is larger than elsewhere. Can someone explain what I am doing wrong? Thanks!

There are a few things wrong with your code. Here are the most important points:

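  • The major mistake that makes everything go wrong is the indexing of f. You are defining Xa = x(j), but you should actually do Xa = x(selection), so that the indexing is consistent with the indexing you use on the kernel matrix K. randsample returns the indices in j in random order, whereas the logical mask selection picks elements in ascending order, so with x(j) the entries of f do not line up with the rows of Kaa and Kab.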

  • Subtracting the sample mean, f = fh(Xa) - mean(fh(Xa)), does not serve any purpose, and it shifts the circles in your plot away from the actual function. (If you choose to subtract something, it should be a fixed number or function, and not depend on the randomly sampled observations.)

  • You should compute the posterior mean and variance directly from m and D; there is no need to sample from the posterior and then obtain sample estimates of them.
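To make the first point concrete, here is a minimal sketch of the ordering mismatch. The toy grid and the fixed vector perm (standing in for the output of randsample) are assumptions chosen for illustration and are not part of the original script.

```matlab
% Minimal sketch: order of x(j) versus x(selection).
% perm is a hypothetical stand-in for randsample(N, M); its values are
% illustrative and deliberately unsorted.
x = 10:10:100;               % toy grid, N = 10
N = length(x);
perm = [7 2 9];              % pretend these indices were drawn at random
selection = false(N, 1);
selection(perm) = true;

Xa_by_index = x(perm);       % follows the order of perm:  70 20 90
Xa_by_mask  = x(selection);  % ascending grid order:       20 70 90

% K(selection, selection) lists the observed points in ascending grid
% order, so only Xa_by_mask lines up with its rows and columns.
disp(Xa_by_index);
disp(Xa_by_mask);
```

With x(perm), the function values end up paired with the wrong rows of the kernel matrix unless the random indices happen to come out sorted.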

Here is a modified version of the script with the above points fixed.

%% Init
% Data generating function
fh = @(x)(2*cos(2*pi*x/10).*x);
% range
x = -5:0.01:5;
N = length(x);
% Sampled data points from the generating function
M = 5;
selection = false(N,1);   % logical mask; false() avoids depending on boolean()
j = randsample(N, M);
% mark them
selection(j) = 1;
Xa = x(selection);

%% GP computations
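% evaluate the generating function at the observed inputs (no mean subtraction)
f = fh(Xa);
sigma2 = 2;          % squared length scale of the kernel
sigma_noise = 0.01;  % noise variance added to the diagonal of Kaa
var_kernel = 10;     % kernel (signal) variance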
% computing the interpolation using all x's
% It is expected that for points used to build the GP cov. matrix, the
% uncertainty is reduced...
K = squareform(pdist(x'));
K = var_kernel*exp(-(0.5*K.^2)/sigma2);
% upper left corner of K
Kaa = K(selection,selection);
% lower right corner of K
Kbb = K(~selection,~selection);
% upper right corner of K
Kab = K(selection,~selection);
% mean of posterior
m = Kab'/(Kaa + sigma_noise*eye(M))*f';
% cov. matrix of posterior
D = Kbb - Kab'/(Kaa + sigma_noise*eye(M))*Kab;

%% Plot
figure;
grid on;
hold on;
% GP estimates
plot(x(~selection), m);
plot(x(~selection), m + 2*sqrt(diag(D)), 'g-');
plot(x(~selection), m - 2*sqrt(diag(D)), 'g-');
% Observations
plot(Xa, f, 'Ok');
% True function
plot(x, fh(x), 'k');

Here is a resulting plot with 5 randomly chosen observations: the true function is shown in black, the posterior mean in blue, and the confidence intervals (mean ± 2 standard deviations) in green.

[Figure: GP estimate]
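For reference, the quantities m and D in the script appear to correspond to the standard zero-mean GP conditioning formulas. The symbols below are notation assumed here, not taken from the post: \sigma_f^2 stands for var_kernel, \ell^2 for sigma2, \sigma_n^2 for sigma_noise, and \mathbf{f} for the column vector of observed values f'.

```latex
% Kernel used to build K (squared exponential)
k(x, x') = \sigma_f^2 \exp\!\left( -\frac{(x - x')^2}{2\,\ell^2} \right)

% Posterior mean at the unobserved points (m in the script)
\mathbf{m} = K_{ab}^{\top} \left( K_{aa} + \sigma_n^2 I \right)^{-1} \mathbf{f}

% Posterior covariance at the unobserved points (D in the script)
D = K_{bb} - K_{ab}^{\top} \left( K_{aa} + \sigma_n^2 I \right)^{-1} K_{ab}

% Plotted band: mean plus/minus two posterior standard deviations
\mathbf{m} \pm 2 \sqrt{\operatorname{diag}(D)}
```

Since \sigma_n^2 enters as a term added to the diagonal of K_{aa}, sigma_noise acts as a noise variance rather than a standard deviation in the script.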
