
How to specify the prior for scikit-learn's Gaussian process regression?

As mentioned here, scikit-learn's Gaussian process regression (GPR) permits "prediction without prior fitting (based on the GP prior)". But I have an idea of what my prior should be: rather than simply having a mean of zero, perhaps my output, y, scales linearly with my input, X (i.e. y = X). How can I encode this information into GPR?

Below is a working example, but it assumes a zero mean for my prior. I read that "The GaussianProcessRegressor does not allow for the specification of the mean function, always assuming it to be the zero function, highlighting the diminished role of the mean function in calculating the posterior." I believe this is the motivation behind custom kernels (e.g. heteroscedastic ones) with variable scales at different X, although I'm still trying to better understand what capability they provide. Is there a way to get around the zero-mean prior so that an arbitrary prior can be specified in scikit-learn?

import numpy as np
from matplotlib import pyplot as plt
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C

def f(x):
    """The function to predict."""
    return 1.5*(1. - np.tanh(100.*(x-0.96))) + 1.5*x*(x-0.95) + 0.4 + 1.5*(1.-x)* np.random.random(x.shape)

# Define the kernel: a constant amplitude times an RBF
kernel = C(10.0, (1e-5, 1e5)) * RBF(10.0, (1e-5, 1e5))

X = np.array([0.803,0.827,0.861,0.875,0.892,0.905,
                0.91,0.92,0.925,0.935,0.941,0.947,0.96,
                0.974,0.985,0.995,1.0])
X = np.atleast_2d(X).T

# Observations and noise
y = f(X).ravel() 
noise = np.linspace(0.4,0.3,len(X))
y += noise

# Instantiate a Gaussian Process model
gp = GaussianProcessRegressor(kernel=kernel, alpha=noise ** 2,
                              n_restarts_optimizer=10)
# Fit to data using Maximum Likelihood Estimation of the parameters
gp.fit(X, y)

# Make the prediction on the meshed x-axis (ask for the standard deviation as well)
x = np.atleast_2d(np.linspace(0.8, 1.02, 1000)).T
y_pred, sigma = gp.predict(x, return_std=True)

plt.figure() 
plt.errorbar(X.ravel(), y, noise, fmt='k.', markersize=10, label=u'Observations')
plt.plot(x, y_pred, 'k-', label=u'Prediction')
plt.fill(np.concatenate([x, x[::-1]]),
         np.concatenate([y_pred - 1.9600 * sigma,
                        (y_pred + 1.9600 * sigma)[::-1]]),
         alpha=.1, fc='k', ec='None', label='95% confidence interval')
plt.xlabel('x')
plt.ylabel('y')
plt.xlim(0.8, 1.02)
plt.ylim(0, 5)
plt.legend(loc='lower left')
plt.show()
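
To make the zero-mean prior concrete, here is a small check of the "prediction without prior fitting" behaviour quoted above. It is only a sketch: the kernel values (amplitude 4, length scale 2) and the names gp_unfitted and x_query are arbitrary choices for illustration, not part of the example above.

```python
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C

# Illustrative kernel; the numbers are assumptions, not tuned values
kernel = C(4.0) * RBF(2.0)

# Deliberately NOT fitted: predict() then falls back to the GP prior
gp_unfitted = GaussianProcessRegressor(kernel=kernel)

x_query = np.linspace(0.8, 1.02, 5).reshape(-1, 1)
prior_mean, prior_std = gp_unfitted.predict(x_query, return_std=True)

print(prior_mean)  # the prior mean at every query point
print(prior_std)   # the prior standard deviation, set by the kernel amplitude
```

The mean comes back as zero everywhere no matter which kernel is used, which is exactly the limitation I would like to get around.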

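Here is an example of how to use a prior mean function with the sklearn GPR model. GaussianProcessRegressor itself always assumes a zero mean, so the trick is to fit the GP to the residuals (targets minus the prior mean) and then add the prior mean back to the predictions.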

import numpy as np
from matplotlib import pyplot as plt
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel

A=np.linspace(5,25,num=100)
# prior mean function
prior_beta=12-0.3*A
# true function
true_beta=20-0.7*A

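np.random.seed(44)  # seed() returns None, so there is nothing to assign
# Training data
size=15
ind=np.random.randint(0,100,size=size)
# noise variance of each training observation (passed to the GP as alpha)
var_=np.random.uniform(0.1,10.0,size=size)
A_=A[ind][:, np.newaxis]
# residuals (truth minus prior mean): this is what the zero-mean GP is fitted to
beta_=true_beta[ind]-prior_beta[ind]
# raw targets, used only for plotting
beta_1=true_beta[ind]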

plt.figure()

kernel = ConstantKernel(4) * RBF(length_scale=2, length_scale_bounds=(1e-3, 1e2))
gp = GaussianProcessRegressor(kernel=kernel,
                              alpha=var_,optimizer=None).fit(A_, beta_)
X_ = np.linspace(5, 25, 100)
y_mean, y_cov = gp.predict(X_[:, np.newaxis], return_cov=True)
# Now you add the prior mean function back
y_mean=y_mean+12-0.3*X_
plt.plot(X_, y_mean, 'k', lw=3, zorder=9, label='predicted')
plt.fill_between(X_, y_mean - 3*np.sqrt(np.diag(y_cov)),
                 y_mean + 3*np.sqrt(np.diag(y_cov)),
                 alpha=0.5, color='k', label='+-3sigma')
plt.plot(A,true_beta, 'r', lw=3, zorder=9,label='truth')
plt.plot(A,prior_beta, 'blue', lw=3, zorder=9,label='prior')
# fmt='s' sets the square marker; passing both fmt='x' and marker='s' is redundant
plt.errorbar(A_[:,0], beta_1, yerr=3*np.sqrt(var_), fmt='s', ecolor='g',
             mfc='g', ms=10, capsize=6, label='training set')

plt.title("Initial: %s\n"% (kernel))
plt.legend()
plt.show()

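
If you need this more than once, the same subtract-then-add-back trick can be wrapped in a small helper. This is only a sketch under stated assumptions: PriorMeanGPR and prior_mean are hypothetical names I made up (scikit-learn has no such class), the prior line 12 - 0.3*x is the one from the example above, and the three training points and alpha=1.0 are arbitrary.

```python
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel


def prior_mean(x):
    """Hypothetical prior mean function: the line 12 - 0.3*x."""
    return 12.0 - 0.3 * np.asarray(x).ravel()


class PriorMeanGPR:
    """Illustrative wrapper (not a scikit-learn class): a zero-mean GP
    fitted to residuals, with the prior mean added back on predict."""

    def __init__(self, mean_fn, **gpr_kwargs):
        self.mean_fn = mean_fn
        self.gp = GaussianProcessRegressor(**gpr_kwargs)

    def fit(self, X, y):
        # Fit the zero-mean GP to what the prior mean does not explain
        self.gp.fit(X, y - self.mean_fn(X))
        return self

    def predict(self, X, return_std=False):
        if return_std:
            resid, std = self.gp.predict(X, return_std=True)
            # The uncertainty is unchanged by a deterministic mean shift
            return resid + self.mean_fn(X), std
        return self.gp.predict(X) + self.mean_fn(X)


kernel = ConstantKernel(4.0) * RBF(length_scale=2.0)
model = PriorMeanGPR(prior_mean, kernel=kernel, alpha=1.0, optimizer=None)

# Made-up training points lying on the "true" line 20 - 0.7*x
X_train = np.array([[6.0], [10.0], [14.0]])
y_train = 20.0 - 0.7 * X_train.ravel()
model.fit(X_train, y_train)

# Far away from the data the prediction should revert to the prior mean
X_far = np.array([[60.0]])
y_far, std_far = model.predict(X_far, return_std=True)
print(y_far, prior_mean(X_far), std_far)
```

Far from the training points the kernel correlations vanish, so the prediction falls back to the prior line instead of to zero, and the standard deviation returns to the prior value set by the kernel amplitude. Near the data the GP correction pulls the prediction away from the prior toward the observations.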


 