
Memoizing a function of type [Integer] -> a

My problem is how to efficiently memoize an expensive function f :: [Integer] -> a that is defined for all finite lists of integers and has the property f . sort = f.

My typical use case is that, given a list as of integers, I need to obtain the values f (a:as) for various Integers a. So I'd like to build up, at the same time, a directed labelled graph whose vertices are pairs of an Integer list and its function value. An edge labelled by a from (as, f as) to (bs, f bs) exists if and only if a:as = bs.

Stealing from a brilliant answer by Edward Kmett, I simply copied

{-# LANGUAGE BangPatterns #-}
data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
  fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

index :: Tree a -> Integer -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
  (q,0) -> index l q
  (q,1) -> index r q

nats :: Tree Integer
nats = go 0 1
  where go !n !s = Tree (go l s') n (go r s')
          where l = n + s
                r = l + s
                s' = s * 2

and adapted his idea to my problem as

-- directed graph labelled by Integers
data Graph a = Graph a (Tree (Graph a))
instance Functor Graph where
  fmap f (Graph a t) = Graph (f a) (fmap (fmap f) t)

-- walk the graph following the given labels
walk :: Graph a -> [Integer] -> a
walk (Graph a _) [] = a
walk (Graph _ t) (x:xs) = walk (index t x) xs

-- graph of all finite integer sequences
intSeq :: Graph [Integer]
intSeq = Graph [] (fmap (\n -> fmap (n:) intSeq) nats)

-- could be replaced by Data.Strict.Pair
data StrictPair a b = StrictPair !a !b
  deriving Show

-- f = sum modified according to Edward's idea (the real function is more complicated)
g :: ([Integer] -> StrictPair Integer [Integer]) -> [Integer] -> StrictPair Integer [Integer]
g mf [] = StrictPair 0 []
g mf (a:as) = StrictPair (a+x) (a:as)
  where StrictPair x y = mf as

g_graph :: Graph (StrictPair Integer [Integer])
g_graph = fmap (g g_m) intSeq

g_m :: [Integer] -> StrictPair Integer [Integer]
g_m = walk g_graph

This works OK, but since the function f is independent of the order of the occurring integers (but not of their counts), there should be only one vertex in the graph for all integer lists that are equal up to ordering.

How do I achieve this?

How about just defining g_m' = g_m . sort, i.e. you simply sort the input list first before calling your memoized function?

I have a feeling this is the best you can do: if you want your memoized graph to consist of only sorted paths, someone is going to have to look at all of the elements of the list before constructing the path.

Depending on what your input lists look like, it might be helpful to transform them in a way that makes the trees branch less. For instance, you might try sorting and taking differences:

original input list:   [8,3,14,8,5]
sorted:                [3,3,8,8,14]
diffed:                [3,0,5,0,6] -- use this as the key

On sorted lists the transformation is a bijection, and the trees branch less because smaller numbers are involved.
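
In case it helps to see the transformation spelled out, here is a minimal sketch. The names toKey and fromKey are made up for illustration and do not appear in the question's code. Note that fromKey recovers the sorted list, not the original ordering.

```haskell
import Data.List (sort)

-- Sort the input, then replace each element by its difference to the
-- previous one (the first element is taken relative to 0).
toKey :: [Integer] -> [Integer]
toKey xs = zipWith (-) sorted (0 : sorted)
  where sorted = sort xs

-- Inverse on sorted lists: running sums rebuild the sorted list.
fromKey :: [Integer] -> [Integer]
fromKey = scanl1 (+)

-- toKey   [8,3,14,8,5] == [3,0,5,0,6]
-- fromKey [3,0,5,0,6]  == [3,3,8,8,14]
```

Every permutation of the same list gives the same key, so something like g_m . toKey would only ever walk one path per multiset.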

You can use a slightly different approach. There is a trick from the proof that a finite product of countable sets is countable:

We can map the sequence [a1, ..., an] to Nat by product . zipWith (^) primes: 2 ^ a1 * 3 ^ a2 * 5 ^ a3 * ... * prime_n ^ an.

To avoid problems with sequences that end in zero, we can increment the last element.
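
To see why the trailing zero matters, here is a small sketch. The names naiveEncode, bumpLast and encode are hypothetical, the prime generator is a deliberately simple trial-division one, and all list entries are assumed to be non-negative.

```haskell
-- Simple trial-division primes; fine for a demonstration, slow for real use.
primes :: [Integer]
primes = filter isPrime [2 ..]
  where
    isPrime n = all (\d -> n `mod` d /= 0) (takeWhile (\d -> d * d <= n) [2 ..])

-- Plain prime-power encoding: 2^a1 * 3^a2 * 5^a3 * ...
naiveEncode :: [Integer] -> Integer
naiveEncode = product . zipWith (^) primes

-- Add one to the last element so that it is never zero.
bumpLast :: [Integer] -> [Integer]
bumpLast [] = []
bumpLast xs = init xs ++ [last xs + 1]

-- Encoding that keeps lists with trailing zeros apart.
encode :: [Integer] -> Integer
encode = naiveEncode . bumpLast

-- naiveEncode [1] == 2 and naiveEncode [1,0] == 2   (collision)
-- encode [1] == 4      and encode [1,0] == 6        (distinct)
```

Without the increment, a trailing zero contributes a factor of p^0 = 1, so the length of the list is lost.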

As the sequence is ordered, we can exploit that property as user5402 mentioned.

The benefit of using the tree is that you can increase the branching to speed up traversal. On the other hand, the prime trick can make indexes quite big, but hopefully some tree paths will just stay unexplored (remain as thunks).

{-# LANGUAGE BangPatterns #-}

import Data.Function (fix)
import Data.List (sort, tails)
import Test.QuickCheck

-- Modified from Kmett's answer:
data Tree a = Tree a (Tree a) (Tree a) (Tree a) (Tree a)
instance Functor Tree where
  fmap f (Tree x a b c d) = Tree (f x) (fmap f a) (fmap f b) (fmap f c) (fmap f d)

index :: Tree a -> Integer -> a
index (Tree x _ _ _ _) 0 = x
index (Tree _ a b c d) n = case (n - 1) `divMod` 4 of
  (q,0) -> index a q
  (q,1) -> index b q
  (q,2) -> index c q
  (q,3) -> index d q

nats :: Tree Integer
nats = go 0 1
    where
        go !n !s = Tree n (go a s') (go b s') (go c s') (go d s')
            where
                a = n + s
                b = a + s
                c = b + s
                d = c + s
                s' = s * 4

toList :: Tree a -> [a]
toList as = map (index as) [0..]

-- Primes -- https://www.haskell.org/haskellwiki/Prime_numbers
-- Generation and factorisation could be done much better
minus (x:xs) (y:ys) = case (compare x y) of
           LT -> x : minus  xs  (y:ys)
           EQ ->     minus  xs     ys
           GT ->     minus (x:xs)  ys
minus  xs     _     = xs

primes = 2 : sieve [3..] primes
  where
    sieve xs (p:ps) | q <- p*p , (h,t) <- span (< q) xs =
                   h ++ sieve (t `minus` [q, q+p..]) ps

addToLast :: [Integer] -> [Integer]
addToLast [] = []
addToLast [x] = [x + 1]
addToLast (x:xs) = x : addToLast xs

subFromLast :: [Integer] -> [Integer]
subFromLast [] = []
subFromLast [x] = [x - 1]
subFromLast (x:xs) = x : subFromLast xs

addSubProp :: [NonNegative Integer] -> Property
addSubProp xs = xs' === subFromLast (addToLast xs')
  where xs' = map getNonNegative xs

-- Trick from user5402 answer
toDiffList :: [Integer] -> [Integer]
toDiffList = toDiffList' 0
  where toDiffList' _ [] = []
        toDiffList' p (x:xs) = x - p : toDiffList' x xs

fromDiffList :: [Integer] -> [Integer]
fromDiffList = fromDiffList' 0
  where fromDiffList' _ [] = []
        fromDiffList' p (x:xs) = p + x : fromDiffList' (x + p) xs

diffProp :: [Integer] -> Property
diffProp xs = xs === fromDiffList (toDiffList xs)

listToInteger :: [Integer] -> Integer
listToInteger = product . zipWith (^) primes . addToLast

integerToList :: Integer -> [Integer]
integerToList = subFromLast . impl primes 0
  where impl _      _ 0 = []
        impl _      0 1 = []
        impl _      k 1 = [k]
        impl (p:ps) k n = case n `divMod` p of
                            (n', 0) -> impl (p:ps) (k + 1) n'
                            (_,  _) -> k : impl ps 0 n

listProp :: [NonNegative Integer] -> Property
listProp xs = xs' === integerToList (listToInteger xs')
  where xs' = map getNonNegative xs

toIndex :: [Integer] -> Integer
toIndex = listToInteger . toDiffList

fromIndex :: Integer -> [Integer]
fromIndex = fromDiffList . integerToList

-- [1,0] /= [0]
-- Decreasing sequence!
doesntHold :: [NonNegative Integer] -> Property
doesntHold xs = xs' === fromIndex (toIndex xs')
  where xs' = map getNonNegative xs

holds :: [NonNegative Integer] -> Property
holds xs = xs' === fromIndex (toIndex xs')
  where xs' = sort $ map getNonNegative xs

g :: ([Integer] -> Integer) -> [Integer] -> Integer
g mg = g' . sort
  where g' [] = 0
        g' (x:xs)  = x + sum (map mg $ tails xs)

g_tree :: Tree Integer
g_tree = fmap (g faster_g' . fromIndex) nats

faster_g' :: [Integer] -> Integer
faster_g' = index g_tree . toIndex

faster_g = faster_g' . sort

On my machine fix g [1..22] feels slow, while faster_g [1..40] is still blazing fast.


Addition: if we have a bounded set (with indexes 0..n-1), we can encode it as: a0 * n^0 + a1 * n^1 ... .

We can encode any Integer as a binary list, e.g. 11 is [1, 1, 0, 1] (least significant bit first). Then, if we separate the integers in the list with 2, we get a sequence of bounded values.

As a bonus we can take the sequence of 0, 1, 2 digits and compress it to binary using e.g. Huffman encoding, as 2 is much rarer than 0 or 1. But this might be overkill.

With this trick, indexes stay much smaller and the space is probably better packed.
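
A worked example of the digit encoding, as a minimal sketch. The helper names bits, ternaryDigits, fromBase3 and encodeList are invented for this illustration, and the input integers are assumed to be non-negative (bits does not terminate on negative input).

```haskell
-- Least-significant-bit-first binary digits of a non-negative Integer.
bits :: Integer -> [Integer]
bits 0 = []
bits n = n `mod` 2 : bits (n `div` 2)

-- Write each number in binary and terminate it with the separator digit 2.
ternaryDigits :: [Integer] -> [Integer]
ternaryDigits = concatMap (\n -> bits n ++ [2])

-- Read a least-significant-digit-first base-3 digit list as one Integer.
fromBase3 :: [Integer] -> Integer
fromBase3 = foldr (\d acc -> d + 3 * acc) 0

encodeList :: [Integer] -> Integer
encodeList = fromBase3 . ternaryDigits

-- bits 11              == [1,1,0,1]
-- ternaryDigits [3,5]  == [1,1,2,1,0,1,2]
-- encodeList [3,5]     == 1750
```

Because every number is followed by a 2, a non-empty digit list always ends in a non-zero digit, so the trailing-zero ambiguity of the prime encoding does not come up here.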

{-# LANGUAGE BangPatterns #-}

-- From Kmett's answer:
import Data.Function (fix)
import Data.List (sort, tails)
import Data.List.Split (splitOn)
import Test.QuickCheck

{-- Tree definition as before --}

-- 0, 1, 2
newtype N3 = N3 { unN3 :: Integer }
  deriving (Eq, Show)

instance Arbitrary N3 where
  arbitrary = elements $ map N3 [ 0, 1, 2 ]

-- Integer <-> N3
coeffs3 :: [Integer]
coeffs3 = coeffs' 1
  where coeffs' n = n : coeffs' (n * 3)

listToInteger :: [N3] -> Integer
listToInteger = sum . zipWith f coeffs3
  where f n (N3 m) = n * m

listFromInteger :: Integer -> [N3]
listFromInteger 0 = []
listFromInteger n = case n `divMod` 3 of
  (q, m) -> N3 m : listFromInteger q

listProp :: [N3] -> Property
listProp xs = (null xs || last xs /= N3 0) ==> xs === listFromInteger (listToInteger xs)

-- Integer <-> N2

-- 0, 1
newtype N2 = N2 { unN2 :: Integer }
  deriving (Eq, Show)

coeffs2 :: [Integer]
coeffs2 = coeffs' 1
  where coeffs' n = n : coeffs' (n * 2)

integerToBin :: Integer -> [N2]
integerToBin 0 = []
integerToBin n = case n `divMod` 2 of
  (q, m) -> N2 m : integerToBin q

integerFromBin :: [N2] -> Integer
integerFromBin = sum . zipWith f coeffs2
  where f n (N2 m) = n * m

binProp :: NonNegative Integer -> Property
binProp (NonNegative n) = n === integerFromBin (integerToBin n)

-- unsafe!
n3ton2 :: N3 -> N2
n3ton2 = N2 . unN3

n2ton3 :: N2 -> N3
n2ton3 = N3 . unN2

-- [Integer] <-> [N3]
integerListToN3List :: [Integer] -> [N3]
integerListToN3List = concatMap (++ [N3 2]) . map (map n2ton3 . integerToBin)

integerListFromN3List :: [N3] -> [Integer]
integerListFromN3List = init . map (integerFromBin . map n3ton2) . splitOn [N3 2]

n3ListProp :: [NonNegative Integer] -> Property
n3ListProp xs = xs' === integerListFromN3List (integerListToN3List xs')
  where xs' = map getNonNegative xs

-- Trick from user5402 answer
-- Integer <-> Sorted Integer
toDiffList :: [Integer] -> [Integer]
toDiffList = toDiffList' 0
  where toDiffList' _ [] = []
        toDiffList' p (x:xs) = x - p : toDiffList' x xs

fromDiffList :: [Integer] -> [Integer]
fromDiffList = fromDiffList' 0
  where fromDiffList' _ [] = []
        fromDiffList' p (x:xs) = p + x : fromDiffList' (x + p) xs

diffProp :: [Integer] -> Property
diffProp xs = xs === fromDiffList (toDiffList xs)

---

toIndex :: [Integer] -> Integer
toIndex = listToInteger . integerListToN3List . toDiffList

fromIndex :: Integer -> [Integer]
fromIndex = fromDiffList . integerListFromN3List . listFromInteger

-- [1,0] /= [0]
-- Decreasing sequence! doesn't terminate in this case
doesntHold :: [NonNegative Integer] -> Property
doesntHold xs = xs' === fromIndex (toIndex xs')
  where xs' = map getNonNegative xs

holds :: [NonNegative Integer] -> Property
holds xs = xs' === fromIndex (toIndex xs')
  where xs' = sort $ map getNonNegative xs

g :: ([Integer] -> Integer) -> [Integer] -> Integer
g mg = g' . sort
  where g' [] = 0
        g' (x:xs)  = x + sum (map mg $ tails xs)

g_tree :: Tree Integer
g_tree = fmap (g faster_g' . fromIndex) nats

faster_g' :: [Integer] -> Integer
faster_g' = index g_tree . toIndex

faster_g = faster_g' . sort

Second addition:

I quickly benchmarked the graph and binary sequence approaches for my g with:

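import System.Environment (getArgs)

main :: IO ()
main = do
  -- the first command-line argument is the upper bound of the test list
  n <- read . head <$> getArgs
  print $ faster_g [100, 110..n]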

And the results are:

% time ./IntegerMemo 1000
1225560638892526472150132981770
./IntegerMemo 1000  0.19s user 0.01s system 98% cpu 0.200 total
% time ./IntegerMemo 2000
3122858113354873680008305238045814042010921833620857170165770
./IntegerMemo 2000  1.83s user 0.05s system 99% cpu 1.888 total
% time ./IntegerMemo 2500
4399449191298176980662410776849867104410434903220291205722799441218623242250
./IntegerMemo 2500  3.74s user 0.09s system 99% cpu 3.852 total
% time ./IntegerMemo 3000    
5947985907461048240178371687835977247601455563536278700587949163642187584269899171375349770
./IntegerMemo 3000  6.66s user 0.13s system 99% cpu 6.830 total

% time ./IntegerMemoGrap 1000 
1225560638892526472150132981770
./IntegerMemoGrap 1000  0.10s user 0.01s system 97% cpu 0.113 total
% time ./IntegerMemoGrap 2000
3122858113354873680008305238045814042010921833620857170165770
./IntegerMemoGrap 2000  0.97s user 0.04s system 98% cpu 1.028 total
% time ./IntegerMemoGrap 2500
4399449191298176980662410776849867104410434903220291205722799441218623242250
./IntegerMemoGrap 2500  2.11s user 0.08s system 99% cpu 2.202 total
% time ./IntegerMemoGrap 3000 
5947985907461048240178371687835977247601455563536278700587949163642187584269899171375349770
./IntegerMemoGrap 3000  3.33s user 0.09s system 99% cpu 3.452 total

It looks like the graph version is faster by a roughly constant factor of about 2 (the total times above differ by a factor of about 1.7 to 2.0). But they seem to have the same time complexity :)

Looks like my problem is solved by simply replacing intSeq in the definition of g_graph by a monotone version:

-- replace vertices for non-monotone integer lists by the corresponding monotone ones
monoIntSeq :: Graph [Integer]
monoIntSeq = f intSeq
  where f (Graph as t) | as == sort as = Graph as $ fmap f t
                       | otherwise     = fetch monoIntSeq $ sort as

-- extract the subgraph after following the given labels
fetch :: Graph a -> [Integer] -> Graph a
fetch g [] = g
fetch (Graph _ t) (x:xs) = fetch (index t x) xs

g_graph :: Graph (StrictPair Integer [Integer])
g_graph = fmap (g g_m) monoIntSeq

Many thanks to all (especially user5402 and Oleg) for the help!


Edit: I still have the problem that the memory consumption is too high for my typical use case, which can be described by following a path like this:

p :: [Integer]
p = map f [1..]
  where f n | n `mod` 6 == 0 = n `div` 6
            | n `mod` 3 == 0 = n `div` 3
            | n `mod` 2 == 0 = n `div` 2
            | otherwise      = n

A slight improvement is to define the monotone integer sequences directly like this:

import Data.List (foldl', insertBy)
import Data.Ord (comparing, Down(..))

-- extract the subgraph after following the given labels (right to left)
fetch :: Graph a -> [Integer] -> Graph a
fetch = foldl' step
  where step (Graph _ t) n = index t n

-- walk the graph following the given labels (right to left)
walk :: Graph a -> [Integer] -> a
walk g ns = a
  where Graph a _ = fetch g ns

-- all monotonically non-increasing integer sequences
monoIntSeqs :: Graph [Integer]
monoIntSeqs = Graph [] $ fmap (flip f monoIntSeqs) nats
  where f n (Graph ns t) | null ns      = Graph (n:ns) $ fmap (f n) t
                         | n >= head ns = Graph (n:ns) $ fmap (f n) t
                         | otherwise    = fetch monoIntSeqs (insert' n ns)
        insert' = insertBy (comparing Down)

But in the end I might just use the original integer sequences without identification, identify nodes explicitly now and then, and avoid keeping a reference to g_graph etc. so that the garbage collector can clean up as the program proceeds.

Reading the functional pearl Trouble Shared is Trouble Halved by Richard Bird and Ralf Hinze, I understood how to implement what I was looking for two years ago (again based on Edward Kmett's trick):

{-# LANGUAGE BangPatterns #-}
import Data.Function (fix)

data Tree a = Tree (Tree a) a (Tree a)
  deriving Show

instance Functor Tree where
  fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

index :: Tree a -> Integer -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
  (q,0) -> index l q
  (q,1) -> index r q

nats :: Tree Integer
nats = go 0 1
  where go !n !s = Tree (go l s') n (go r s')
          where l = n + s
                r = l + s
                s' = s * 2

data IntSeqTree a = IntSeqTree a (Tree (IntSeqTree a))

val :: IntSeqTree a -> a
val (IntSeqTree a _) = a

step :: Integer -> IntSeqTree t -> IntSeqTree t
step n (IntSeqTree _ ts) = index ts n

intSeqTree :: IntSeqTree [Integer]
intSeqTree = fix $ create []
  where create p x = IntSeqTree p $ fmap (extend x) nats
        extend x n = case span (>n) (val x) of
                       ([], p) -> fix $ create (n:p)
                       (m, p)  -> foldr step intSeqTree (m ++ n:p)

instance Functor IntSeqTree where
  fmap f (IntSeqTree a t) = IntSeqTree (f a) (fmap (fmap f) t)

In my use case I have hundreds or thousands of similar integer sequences (each a few hundred entries long) that are generated incrementally. So for me this way is cheaper than sorting the sequences before looking up the function value (which I will access by using fmap on intSeqTree).
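
One point that may be worth checking in such a setup: the Functor instance rebuilds the tree node by node, so two paths that reach the same node in intSeqTree become separate thunks after fmap. The sketch below is a hypothetical variant, not the code above, that stores the function value in each node at construction time so that redirected paths reuse the canonical node. The names memoTree and lookupMemo are invented here, and the labels are assumed to be non-negative.

```haskell
{-# LANGUAGE BangPatterns #-}
import Data.Function (fix)

-- Infinite binary tree indexed by the naturals, as in the code above.
data Tree a = Tree (Tree a) a (Tree a)

instance Functor Tree where
  fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

index :: Tree a -> Integer -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
  (q, 0) -> index l q
  (q, _) -> index r q

nats :: Tree Integer
nats = go 0 1
  where
    go !n !s = Tree (go l s') n (go r s')
      where
        l  = n + s
        r  = l + s
        s' = s * 2

data IntSeqTree a = IntSeqTree a (Tree (IntSeqTree a))

val :: IntSeqTree a -> a
val (IntSeqTree a _) = a

step :: Integer -> IntSeqTree a -> IntSeqTree a
step n (IntSeqTree _ ts) = index ts n

-- Assumption: each node stores its descending key together with a lazy
-- thunk for f applied to that key, built once when the node is created.
memoTree :: ([Integer] -> b) -> IntSeqTree ([Integer], b)
memoTree f = root
  where
    root = fix (create [])
    create p x = IntSeqTree (p, f p) (fmap (extend x) nats)
    extend x n = case span (> n) (fst (val x)) of
      ([], p) -> fix (create (n : p))           -- n fits in front: new canonical node
      (m, p)  -> foldr step root (m ++ n : p)   -- otherwise redirect to the canonical path

-- Follow the labels (right to left) and return the stored value.
lookupMemo :: IntSeqTree ([Integer], b) -> [Integer] -> b
lookupMemo t = snd . val . foldr step t
```

With t = memoTree sum, both lookupMemo t [3,1,2] and lookupMemo t [1,3,2] end at the node keyed [3,2,1], so the stored thunk should be forced at most once. That sharing claim follows from how lazy evaluation updates thunks and is not something the small example measures.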

