[英]Python improving function speed
我正在編寫自己的腳本來計算兩個信號之間的關系。 因此,我使用mlab.csd和mlab.psd函數來計算信號的CSD和PSD。 我的陣列x的形狀為(120,68,68,815)。 我的腳本運行幾分鍾,這個功能是這個時間很長的熱點。
任何人都知道我應該做什么? 我不熟悉腳本性能的提高。 謝謝!
# to read the list of stcs for all the epochs
with open('/home/daniel/Dropbox/F[...]', 'rb') as f:
label_ts = pickle.load(f)
x = np.asarray(label_ts)
nfft = 512
n_freqs = nfft/2+1
n_epochs = len(x) # in this case there are 120 epochs
channels = 68
sfreq = 1017.25
def compute_mean_psd_csd(x, n_epochs, nfft, sfreq):
'''Computes mean of PSD and CSD for signals.'''
Rxy = np.zeros((n_epochs, channels, channels, n_freqs), dtype=complex)
Rxx = np.zeros((n_epochs, channels, channels, n_freqs))
Ryy = np.zeros((n_epochs, channels, channels, n_freqs))
for i in xrange(0, n_epochs):
print('computing connectivity for epoch %s'%(i+1))
for j in xrange(0, channels):
for k in xrange(0, channels):
Rxy[i,j,k], freqs = mlab.csd(x[j], x[k], NFFT=nfft, Fs=sfreq)
Rxx[i,j,k], _____ = mlab.psd(x[j], NFFT=nfft, Fs=sfreq)
Ryy[i,j,k], _____ = mlab.psd(x[k], NFFT=nfft, Fs=sfreq)
Rxy_mean = np.mean(Rxy, axis=0, dtype=np.float32)
Rxx_mean = np.mean(Rxx, axis=0, dtype=np.float32)
Ryy_mean = np.mean(Ryy, axis=0, dtype=np.float32)
return freqs, Rxy, Rxy_mean, np.real(Rxx_mean), np.real(Ryy_mean)
如果csd
和psd
方法是計算密集型的,那么可能會有所幫助。 有可能您可以簡單地緩存先前調用的結果並獲取它而不是多次計算。
看起來,你將有120 * 68 * 68 = 591872
個周期。
在psd計算的情況下,應該可以緩存值而沒有問題,方法只依賴於一個參數。
如果值存在,將值存儲在dict中以進行x[j]
或x[k]
檢查。 如果該值不存在,請對其進行計算並存儲。 如果值存在,則只需跳過該值並重新使用該值。
if x[j] not in cache_psd:
cache_psd[x[j]], ____ = mlab.psd(x[j], NFFT=nfft, Fs=sfreq)
Rxx[i,j,k] = cache_psd[x[j]]
if x[k] not in cache_psd:
cache_psd[x[k]], ____ = mlab.psd(x[k], NFFT=nfft, Fs=sfreq)
Ryy[i,j,k] = cache_psd[x[k]]
您可以使用csd
方法執行相同的操作。 我不太了解它可以說更多。 如果參數的順序無關緊要,您可以按排序順序存儲這兩個參數,以防止重復,如2, 1
和1, 2
。
僅當存儲器訪問時間低於計算時間和存儲時間時,使用高速緩存才能使代碼更快。 可以使用執行memoization
的模塊輕松添加此修復程序。
這是一篇關於進一步閱讀的備忘錄的文章:
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.