簡體   English   中英

高效的通用Python memoize

[英]Efficient generic Python memoize

我有一個通用的Python memoizer:

cache = {}

def memoize(f): 
    """Memoize any function."""

    def decorated(*args):
        key = (f, str(args))
        result = cache.get(key, None)
        if result is None:
            result = f(*args)
            cache[key] = result
        return result

    return decorated

它有效,但我對它不滿意,因為有時效率不高。 最近,我使用了一個將列表作為參數的函數,並且顯然使用整個列表創建鍵會減慢所有內容。 最好的方法是什么? (即,有效地計算密鑰,無論args是什么,無論它們是多長還是復雜)

我想這個問題實際上是關於如何從args和泛型memoizer的函數有效地生成密鑰 - 我在一個程序中觀察到,糟糕的密鑰(生成成本太高)對運行時產生了重大影響。 我的編程用'str(args)'拍攝了45秒,但我可以用手工制作的鍵將其減少到3秒。 不幸的是,手工制作的密鑰是特定於這個編程,但我想要一個快速的記事本,我不必每次都為緩存推出特定的,手工制作的密鑰。

首先,如果您非常確定O(N)散列在這里是合理且必要的,並且您只想使用比hash(str(x))更快的算法來加快速度,請嘗試以下方法:

def hash_seq(iterable):
    result = hash(type(iterable))
    for element in iterable:
        result ^= hash(element)
    return result

當然,這對於可能很深的序列不起作用,但有一個顯而易見的方法:

def hash_seq(iterable):
    result = hash(type(iterable))
    for element in iterable:
        try:
            result ^= hash(element)
        except TypeError:
            result ^= hash_seq(element)
    return result

我不認為這是一個足夠好的哈希算法,因為它將為同一個列表的不同排列返回相同的值。 但我很確定沒有足夠好的哈希算法會快得多。 至少如果它是用C或Cython編寫的,如果這是你要去的方向,你最終可能會想做。

此外,值得注意的是,在str (或marshal )不會的情況下,這將是正確的 - 例如,如果您的list可能有一些可變元素,其repr涉及其id而不是其值。 但是,它在所有情況下仍然不正確。 特別是,它假定“迭代相同的元素”對於任何可迭代類型意味着“相等”,這顯然不能保證是真的。 假陰性並不是一個大問題,但誤報是(例如,兩個dict s具有相同的鍵但不同的值可能虛假地比較相等並共享備忘錄)。

此外,它不使用額外的空間,而是使用相當大的乘數O(N)。

無論如何,值得首先嘗試這一點,然后才決定是否值得分析是否足夠好和微調優化。

這是淺實現的一個簡單的Cython版本:

def test_cy_xor(iterable):
    cdef int result = hash(type(iterable))
    cdef int h
    for element in iterable:
        h = hash(element)
        result ^= h
    return result

從快速測試來看,純Python實現非常慢(正如您所期望的那樣,所有Python循環,與strmarshal的C循環相比),但Cython版本很容易獲勝:

    test_str(    3):  0.015475
test_marshal(    3):  0.008852
    test_xor(    3):  0.016770
 test_cy_xor(    3):  0.004613
    test_str(10000):  8.633486
test_marshal(10000):  2.735319
    test_xor(10000): 24.895457
 test_cy_xor(10000):  0.716340

只是在Cython中迭代序列並且什么都不做(實際上只是N調用PyIter_Next和一些引用計數,所以你在本機C中不會做得更好)與test_cy_xor同時是70%。 你可以通過要求一個實際的序列而不是一個可迭代來使它更快,甚至更需要一個list ,盡管它可能需要編寫顯式C而不是Cython才能獲得好處。

無論如何,我們如何解決訂購問題? 顯而易見的Python解決方案是散列(i, element)而不是element ,但所有這些元組操作都會使Cython版本減慢到12倍。 標准解決方案是在每個xor之間乘以一些數字。 但是,當你在它的時候,值得嘗試讓值很好地分散為短序列,小int元素和其他非常常見的邊緣情況。 選擇正確的數字很棘手,所以......我只是從tuple借來的。 這是完整的測試。

_hashtest.pyx:

cdef _test_xor(seq):
    cdef long result = 0x345678
    cdef long mult = 1000003
    cdef long h
    cdef long l = 0
    try:
        l = len(seq)
    except TypeError:
        # NOTE: This probably means very short non-len-able sequences
        # will not be spread as well as they should, but I'm not
        # sure what else to do.
        l = 100
    for element in seq:
        try:
            h = hash(element)
        except TypeError:
            h = _test_xor(element)
        result ^= h
        result *= mult
        mult += 82520 + l + l
    result += 97531
    return result

def test_xor(seq):
    return _test_xor(seq) ^ hash(type(seq))

hashtest.py:

import marshal
import random
import timeit
import pyximport
pyximport.install()
import _hashtest

def test_str(seq):
    return hash(str(seq))

def test_marshal(seq):
    return hash(marshal.dumps(seq))

def test_cy_xor(seq):
    return _hashtest.test_xor(seq)

# This one is so slow that I don't bother to test it...
def test_xor(seq):
    result = hash(type(seq))
    for i, element in enumerate(seq):
        try:
            result ^= hash((i, element))
        except TypeError:
            result ^= hash(i, hash_seq(element))
    return result

smalltest = [1,2,3]
bigtest = [random.randint(10000, 20000) for _ in range(10000)]

def run():
    for seq in smalltest, bigtest:
        for f in test_str, test_marshal, test_cy_xor:
            print('%16s(%5d): %9f' % (f.func_name, len(seq),
                                      timeit.timeit(lambda: f(seq), number=10000)))

if __name__ == '__main__':
    run()

輸出:

    test_str(    3):  0.014489
test_marshal(    3):  0.008746
 test_cy_xor(    3):  0.004686
    test_str(10000):  8.563252
test_marshal(10000):  2.744564
 test_cy_xor(10000):  0.904398

以下是一些提高速度的潛在方法:

  • 如果你有很多深層序列,而不是使用try around hash ,請調用PyObject_Hash並檢查-1。
  • 如果你知道你有一個序列(或者更好,特別是一個list ),而不僅僅是一個可迭代的, PySequence_ITEM (或PyList_GET_ITEM )可能會比PyIter_Next隱式使用的PyIter_Next更快。

在任何一種情況下,一旦你開始調用C API調用,通常更容易刪除Cython並只用C編寫函數。(你仍然可以使用Cython編寫一個關於該C函數的簡單包裝,而不是手動編寫擴展模塊的代碼。)那時,只需直接借用tuplehash代碼,而不是重新實現相同的算法。

如果你正在尋找一種避免O(N)的方法,那是不可能的。 如果你看一下tuple.__hash__frozenset.__hash__ImmutableSet.__hash__工作原理(最后一個是純Python而且非常易讀,順便說一句),它們都是O(N) 但是,它們都緩存哈希值。 因此,如果您經常散列相同的 tuple (而不是非相同但相等的tuple ),它會接近恆定時間。 (它是O(N/M) ,其中M是你用每個tuple調用的次數。)

如果你可以假設你的list對象永遠不會在調用之間發生變化,你顯然可以做同樣的事情,例如,使用dict映射id作為外部緩存進行hash 但總的來說,這顯然不是一個合理的假設。 (如果你的list對象永遠不會變異,那么只需切換到tuple對象就更容易了,而且不會為這些復雜性而煩惱。)

但是您可以將list對象包裝在一個子類中,該子類添加緩存的哈希值成員(或槽),並在緩存調用時調用緩存( append__setitem____delitem__等)。 然后你的hash_seq可以檢查它。

最終結果是與tuple s相同的正確性和性能:分攤O(N/M) ,除了tuple M是您使用每個相同tuple調用的次數,而對於list它是您調用的次數每個相同的list之間沒有變異。

你可以試試幾件事:

使用marshal.dumps而不是str可能會稍快一些(至少在我的機器上):

>>> timeit.timeit("marshal.dumps([1,2,3])","import marshal", number=10000)
0.008287056301007567
>>> timeit.timeit("str([1,2,3])",number=10000)
0.01709315717356219

另外,如果你的函數計算成本很高,並且可能自己都返回None,那么你的memoizing函數每次都會重新計算它們(我可能會到達這里,但我不知道更多,我只能猜測)。 結合這兩件事給出了:

import marshal
cache = {}

def memoize(f): 
    """Memoize any function."""

    def decorated(*args):
        key = (f, marshal.dumps(args))
        if key in cache:
            return cache[key]

        cache[key] = f(*args)
        return cache[key]

    return decorated

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM