[英]Faster alternative to for loops for einstein summation?
我有一段代碼從 Python 移植到 R。 原 Python 版本使用 np.einsum。 由於我在 R 中找不到等效的 np.einsum,並且我想確定我理解它,所以我直接使用 for 循環對其進行了編碼。 現在,我想知道是否有更快的選擇。
示例代碼:
n = 2 ; d = 3 ; nx = 4 ; v = 5
array4d <- array( runif(n*nx*v*d ,-1,0),
dim = c(n, nx, v, d) )
array3d <- array( runif(n*v*d ,-1,0),
dim = c(n, v, d) )
einsum_result <- array( rep(0, n*nx*d),
dim = c(n, nx, d))
# original Python version: np.einsum('ikl,ijkl->ijl', array3d, array4d, optimize=False)
# R version
for (i in 1: n) {
for( j in 1: nx) {
for( l in 1: d ) {
einsum_result[i, j, l] <- einsum_result[i, j, l] +
sum( array3d[i, , l] * array4d[i, j, , l])
}}}
我嘗試使用矩陣乘法刪除j
循環(因為j
/ nx
通常是最大的數字),但無法正確解決。 任何建議表示贊賞!
好的,我在使用sweep
和rowSums
消除 j 循環方面取得了一些進展,因此將其發布為答案。 我增加了 arrays 的尺寸,因為原始問題中的小數組的好處並不明顯。
n = 5 ; d = 10; nx = 400; v = 15
array4d <- array( runif(n*nx*v*d ,-1,0),
dim = c(n, nx, v, d) )
array3d <- array( runif(n*v*d ,-1,0),
dim = c(n, v, d) )
einsum_result1 <- array( rep(0, n*nx*d),
dim = c(n, nx, d))
einsum_result2 <- einsum_result1
microbenchmark({
for (i in 1: n) {
for( j in 1: nx) {
for( l in 1: d ) {
einsum_result1[i, j, l] <- einsum_result1[i, j, l] +
sum( array3d[i, , l] * array4d[i, j, , l])
}}} },
{
for (i in 1: n) {
#for( j in 1: nx) {
for( l in 1: d ) {
einsum_result2[i, , l] <- einsum_result2[i, , l] +
rowSums( sweep(array4d[i, , , l], MARGIN=2, array3d[i, , l], `*`) )
}}})
min lq mean median uq max neval cld
40.53148 42.64406 48.30703 47.65992 50.87022 179.93073 100 b
10.86891 11.09062 12.02787 11.22610 11.56781 27.29123 100 a
> identical(einsum_result1, einsum_result2)
[1] TRUE
有人看到刪除任何剩余循環並用矢量化代碼替換的方法嗎? 我已經把這個問題留了一段時間,但如果沒有建議,我想我會接受我自己的答案
看看受einsum
function 啟發的 einsum package。 它是einsum
function 並不是特別快,但是einsum_generator
function 將返回一個 cpp function (使用代碼編譯得很快)。
library(einsum)
set.seed(6)
n = 5 ; d = 10; nx = 400; v = 15
array4d <- array(runif(n*nx*v*d ,-1,0),
dim = c(n, nx, v, d))
array3d <- array(runif(n*v*d ,-1,0),
dim = c(n, v, d))
einsum_result <- array(rep(0, n*nx*d),
dim = c(n, nx, d))
einsum1 <- function(einsum_result) {
for (i in 1: n) {
for( j in 1: nx) {
for( l in 1: d ) {
einsum_result[i, j, l] <- einsum_result[i, j, l] +
sum( array3d[i, , l] * array4d[i, j, , l])
}
}
}
einsum_result
}
einsum2 <- function(einsum_result) {
for (i in 1: n) {
for( l in 1: d ) {
einsum_result[i, , l] <- einsum_result[i, , l] +
rowSums(sweep(array4d[i, , , l], MARGIN=2, array3d[i, , l], `*`) )
}
}
einsum_result
}
system.time(einsumCpp <- einsum_generator("ikl, ijkl -> ijl"))
#> user system elapsed
#> 0.07 0.07 7.83
microbenchmark::microbenchmark(einsum1 = einsum1(einsum_result),
einsum2 = einsum2(einsum_result),
einsum = einsum("ikl, ijkl -> ijl", array3d, array4d),
einsumCpp = einsumCpp(array3d, array4d),
check = "equal")
#> Unit: microseconds
#> expr min lq mean median uq max neval
#> einsum1 48950.100 52293.800 56474.49 55209.801 59354.95 76777.201 100
#> einsum2 9493.801 10342.201 11381.02 11093.451 12233.45 15554.401 100
#> einsum 168312.101 175950.002 187269.22 183263.302 192122.10 250520.301 100
#> einsumCpp 892.401 1019.101 1338.10 1356.951 1521.45 2599.701 100
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.