![](/img/trans.png)
[英]How to exploit symmetry in outer product in Numpy (or other Python solutions)?
[英]How to exploit permutational symmetry in this loop?
我有一個標量函數f(a,b,c,d)
,它具有以下排列對稱性
f(a,b,c,d) = f(c,d,a,b) = -f(b,a,d,c) = -f(d,c,b,a)
我用它來完全填充4D陣列。 下面的代碼(使用python / NumPy)有效:
A = np.zeros((N,N,N,N))
for a in range(N):
for b in range(N):
for c in range(N):
for d in range(N):
A[a,b,c,d] = f(a,b,c,d)
但顯然我想利用對稱性來減少這部分代碼的執行時間。 我試過了:
A = np.zeros((N,N,N,N))
ab = 0
for a in range(N):
for b in range(N):
ab += 1
cd = 0
for c in range(N):
for d in range(N):
cd += 1
if ab >= cd:
A[a,b,c,d] = A[c,d,a,b] = f(a,b,c,d)
這將執行時間縮短了一半。 但是為了最后的對稱,我嘗試了:
A = np.zeros((N,N,N,N))
ab = 0
for a in range(N):
for b in range(N):
ab += 1
cd = 0
for c in range(N):
for d in range(N):
cd += 1
if ab >= cd:
if ((a >= b) or (c >= d)):
A[a,b,c,d] = A[c,d,a,b] = f(a,b,c,d)
A[b,a,d,c] = A[d,c,b,a] = -A[a,b,c,d]
哪個有效,但不會讓我接近另外兩個加速的因素。 我不認為這是正確的理由,但不明白為什么。
我怎樣才能更好地利用這種特定的排列對稱?
有趣的問題!
對於N=3
,應該有81個具有4個元素的組合。 使用循環,您可以創建156。
看起來重復的主要來源是or
在(a >= b) or (c >= d)
,它太寬容了。 (a >= b) and (c >= d)
限制性太強。
但是你可以比較a + c >= b + d
。 要獲得幾毫秒(如果有的話),你可以在第三個循環中保存a + c
作為ac
:
A = np.zeros((N,N,N,N))
ab = 0
for a in range(N):
for b in range(N):
ab += 1
cd = 0
for c in range(N):
ac = a + c
for d in range(N):
cd += 1
if (ab >= cd and ac >= b+d):
A[a,b,c,d] = A[c,d,a,b] = f(a,b,c,d)
A[b,a,d,c] = A[d,c,b,a] = -A[a,b,c,d]
使用此代碼,我們創建了112個組合,因此與您的方法相比,重復次數更少,但可能仍會有一些優化。
這是我用來計算創建組合數的代碼:
from itertools import product
N = 3
ab = 0
all_combinations = set(product(range(N), repeat=4))
zeroes = ((x, x, y, y) for x, y in product(range(N), repeat=2))
calculated = list()
for a in range(N):
for b in range(N):
ab += 1
cd = 0
for c in range(N):
ac = a + c
for d in range(N):
cd += 1
if (ab >= cd and ac >= b + d) and not (a == b and c == d):
calculated.append((a, b, c, d))
calculated.append((c, d, a, b))
calculated.append((b, a, d, c))
calculated.append((d, c, b, a))
missing = all_combinations - set(calculated) - set(zeroes)
if missing:
print "Some sets weren't calculated :"
for s in missing:
print s
else:
print "All cases were covered"
print len(calculated)
有and not (a==b and c==d)
,數字下降到88。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.