简体   繁体   English

在 SciPy 中生成 B 样条基,如 R 中的 bs()

[英]Generate a B-Spline basis in SciPy, like bs() in R

With N 1-dimensional data X, I would like to evaluate each point at K cubic B-splines.对于 N 个一维数据 X,我想在 K 三次 B 样条上评估每个点。 In R, there is a simple function with an intuitive API, called bs .在 R 中有一个简单的 function 和一个直观的 API,称为bs There is actually a python package patsy which replicates this , but I can't use that package -- only scipy and such. There is actually a python package patsy which replicates this , but I can't use that package -- only scipy and such.

Having looked through the scipy.interpolate documentation on spline-related functions, the closest I can find is BSpline, or BSpline.basis_element, but how to get just the K basis functions is totally mysterious to me.浏览了 scipy.interpolate 关于样条相关函数的文档后,我能找到的最接近的是 BSpline 或 BSpline.basis_element,但如何仅获得 K 基函数对我来说完全是神秘的。 I tried the following:我尝试了以下方法:

import numpy as np
import scipy.interpolate as intrp
import matplotlib.pyplot as plt
import patsy # for comparison

# in Patsy/R: nice and sensible
x = np.linspace(0., 1., 100)
y = patsy.bs(x, knots=np.linspace(0,1,4), degree=3)
plt.subplot(1,2,1)
plt.plot(x,y)
plt.title('B-spline basis')

# in scipy: ?????
y_py = np.zeros((x.shape[0], 6))
for i in range(6):
    y_py[:,i] = intrp.BSpline(np.linspace(0,1,10),(np.arange(6)==i).astype(float), 3, extrapolate=False)(x)


plt.subplot(1,2,2)
plt.plot(x,y_py)
plt.title('Something else')


在此处输入图像描述

It doesn't work, and makes me realise I don't actually know what this function is doing.它不起作用,让我意识到我实际上并不知道这个 function 在做什么。 First of all, it will not accept fewer than 8 interior knots, which I don't understand why.首先,它不会接受少于 8 个内部结,我不明白为什么。 Secondly, it only thinks that the splines are defined within (1/3, 2/3)ish range, which maybe means that it is ignoring the first 3 and last 3 knot values for some reason?其次,它只认为样条线定义在 (1/3, 2/3)ish 范围内,这可能意味着它出于某种原因忽略了前 3 个和后 3 个节点值? Do I need to pad the knots?我需要垫结吗?

Any help would be appreciated!任何帮助,将不胜感激!

EDIT: I have solved this discrepancy, indeed it seems like BSpline ignore the first 3 and last 3 values of knots.编辑:我已经解决了这个差异,实际上似乎 BSpline 忽略了结的前 3 个和后 3 个值。 I'm still interested in knowing why there is this discrepancy, so that I feel less bad for the odd hour spent debugging a strange interface.我仍然很想知道为什么会出现这种差异,这样我就不会因为花几个小时调试一个奇怪的界面而感到难过。

For posterity, here is the code that does produce the basis functions对于后代,这里是产生基函数的代码

import numpy as np
import scipy.interpolate as intrp
import matplotlib.pyplot as plt
import patsy # for comparison

these_knots = np.linspace(0,1,5)

# in Patsy/R: nice and sensible
x = np.linspace(0., 1., 100)
y = patsy.bs(x, knots=these_knots, degree=3)
plt.subplot(1,2,1)
plt.plot(x,y)
plt.title('B-spline basis')

# in scipy: ?????
numpyknots = np.concatenate(([0,0,0],these_knots,[1,1,1])) # because??
y_py = np.zeros((x.shape[0], len(these_knots)+2))
for i in range(len(these_knots)+2):
    y_py[:,i] = intrp.BSpline(numpyknots, (np.arange(len(these_knots)+2)==i).astype(float), 3, extrapolate=False)(x)

plt.subplot(1,2,2)
plt.plot(x,y_py)
plt.title('In SciPy')


Looks like you already found the answer, but to clarify why these you need to define the multiple knots at the edges, you can read the scipy docs .看起来您已经找到了答案,但要阐明为什么需要在边缘定义多个结,您可以阅读scipy 文档 They are defined using the Cox-de Boor recursive formula.它们是使用 Cox-de Boor 递归公式定义的。 This formula starts with defining neighbouring support domains between the given knot points with a constant value of 1 (zeroth order).该公式从定义给定节点之间的相邻支持域开始,常数值为 1(零阶)。 These are convoluted to acquire the higher order basis functions.这些被卷积以获得更高阶的基函数。 Hence two domains make one first order basis function, three domains make one second order basis function and four domains (= 5 knot points) make one third order basis function that is supported within the range of these 5 knot points.因此,两个域构成一个一阶基 function,三个域构成一个二阶基 function,四个域(= 5 个结点)构成一个三阶基 ZC1C425268E68385D1AB5074C17A94 支持这些 5 个结点的范围内。 If you want n basis functions of degree k = 3, you will need to have (n+k+1) knot points.如果你想要 n 个 k = 3 的基函数,你需要有 (n+k+1) 个节点。

The minimum of 8 knots is such that n >= k + 1, which gives 2 * (k+1). 8 节的最小值使得 n >= k + 1,得到 2 * (k+1)。 The base interval t[k]... t[n] in scipy is the only range where you can define full degree basis functions. scipy 中的基本区间 t[k]...t[n] 是唯一可以定义全度基函数的范围。 To make sure that this base interval reaches the outer knot points, the two end knots are usually given a multiplicity of (k+1).为了确保这个基本间隔到达外部结点,通常给两个末端结点的重数为 (k+1)。 Probably scipy only showed this base interval in your 'Something else' result.可能 scipy 仅在您的“其他”结果中显示此基本间隔。

Note that you can also get the basis functions using请注意,您还可以使用

y_py[:,i] = intrp.BSpline.basis_element(numpyknots[i:i+5], extrapolate=False)(x)

this also removes the difference at x = 1.这也消除了 x = 1 处的差异。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM