[英]Efficient computation of a scale matrix (covariance matrix) in Python
The function below computes a scale matrix (covariance matrix) for a time interval (t0,t1)
for a multivariate time series. 下面的函数为多元时间序列的时间间隔(t0,t1)
计算比例矩阵(协方差矩阵)。 I would like to rewrite this function so that it does not need a list or a for loop. 我想重写此函数,以便它不需要列表或for循环。 Is there a way to do the following using only numpy array operations? 有没有办法仅使用numpy数组操作执行以下操作? It seems that I need a version of numpy.outer
that accepts 2d arrays as input and then takes the outer product along a specified axis. 似乎我需要一个numpy.outer
的版本,该版本接受2d数组作为输入,然后沿指定的轴获取外部乘积。 But I couldn't find such a function in numpy. 但是我在numpy中找不到这样的功能。
import numpy as np
def scale_matrix(multivariate_time_series, t0=0, t1=0):
# multivariate_time_series is a 2d array.
if t1==0:
t1 = len(multivariate_time_series)
a = np.mean([np.outer(multivariate_time_series[t,:],multivariate_time_series[t,:])
for t in range(t0,t1)], axis=0)
return a
You can use matrix multiplication or einsum
: 您可以使用矩阵乘法或einsum
:
>>> data = np.random.random((20, 5))
>>> t0 = t1 = 0
>>> data_r = data[t0:t1 or len(data)]
>>>
>>> data_r.T@data_r/data_r.shape[0]
array([[0.31445868, 0.15057765, 0.25087819, 0.26003647, 0.24403643],
[0.15057765, 0.32387482, 0.25741824, 0.27916451, 0.26457779],
[0.25087819, 0.25741824, 0.38244811, 0.31093482, 0.30124948],
[0.26003647, 0.27916451, 0.31093482, 0.39589237, 0.30220028],
[0.24403643, 0.26457779, 0.30124948, 0.30220028, 0.3548833 ]])
>>>
>>> np.einsum('ij,ik->jk', data_r, data_r)/data_r.shape[0]
array([[0.31445868, 0.15057765, 0.25087819, 0.26003647, 0.24403643],
[0.15057765, 0.32387482, 0.25741824, 0.27916451, 0.26457779],
[0.25087819, 0.25741824, 0.38244811, 0.31093482, 0.30124948],
[0.26003647, 0.27916451, 0.31093482, 0.39589237, 0.30220028],
[0.24403643, 0.26457779, 0.30124948, 0.30220028, 0.3548833 ]])
>>>
>>> scale_matrix(data, t0, t1)
array([[0.31445868, 0.15057765, 0.25087819, 0.26003647, 0.24403643],
[0.15057765, 0.32387482, 0.25741824, 0.27916451, 0.26457779],
[0.25087819, 0.25741824, 0.38244811, 0.31093482, 0.30124948],
[0.26003647, 0.27916451, 0.31093482, 0.39589237, 0.30220028],
[0.24403643, 0.26457779, 0.30124948, 0.30220028, 0.3548833 ]])
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.