[英]What is the correct way to perform constant-space nested loops in Haskell?
在Haskell中有两种明显的“惯用”方法来执行嵌套循环:使用list monad或使用forM_
来替换传统的fors
。 我已经设置了一个基准来确定它们是否被编译为紧密循环:
import Control.Monad.Loop
import Control.Monad.Primitive
import Control.Monad
import Control.Monad.IO.Class
import qualified Data.Vector.Unboxed.Mutable as MV
import qualified Data.Vector.Unboxed as V
times = 100000
side = 100
-- Using `forM_` to replace traditional fors
test_a mvec =
forM_ [0..times-1] $ \ n -> do
forM_ [0..side-1] $ \ y -> do
forM_ [0..side-1] $ \ x -> do
MV.write mvec (y*side+x) 1
-- Using the list monad to replace traditional forms
test_b mvec = sequence_ $ do
n <- [0..times-1]
y <- [0..side-1]
x <- [0..side-1]
return $ MV.write mvec (y*side+x) 1
main = do
let vec = V.generate (side*side) (const 0)
mvec <- V.unsafeThaw vec :: IO (MV.MVector (PrimState IO) Int)
-- test_a mvec
-- test_b mvec
vec' <- V.unsafeFreeze mvec :: IO (V.Vector Int)
print $ V.sum vec'
此测试创建一个100x100向量,使用嵌套循环向每个索引写入1并重复100k次。 编译那些刚刚ghc -O2 test.hs -o test
(GHC版本7.8.4),结果是: 3.853s
的forM_
版本和10.460s
的list monad
。 为了提供参考,我还用JavaScript编写了这个测试:
var side = 100;
var times = 100000;
var vec = [];
for (var i=0; i<side*side; ++i)
vec.push(0);
for (var n=0; n<times; ++n)
for (var y=0; y<side; ++y)
for (var x=0; x<side; ++x)
vec[x+y*side] = 1;
var s = 0;
for (var i=0; i<side*side; ++i)
s += vec[i];
console.log(s);
这个等效的JavaScript程序需要1s
来完成,击败Haskell的未装箱的向量,这是不寻常的,这表明Haskell没有在恒定空间中运行循环,而是进行分配。 然后我发现了一个声称提供类型保证紧密循环的库Control.Monad.Loop
:
-- Using `for` from Control.Monad.Loop
test_c mvec = exec_ $ do
n <- for 0 (< times) (+ 1)
x <- for 0 (< side) (+ 1)
y <- for 0 (< side) (+ 1)
liftIO (MV.write mvec (y*side+x) 1)
哪个在1s
运行。 但是,这个库并不是很常用,而且远非惯用,因此, 获得快速恒定空间二维计算的惯用方法是什么? (注意这不是REPA的情况,因为我想在网格上执行任意IO操作。)
用GHC编写严格的变异代码有时会很棘手。 我打算写几件不同的东西,可能是一种比我更喜欢的漫无边际的方式。
对于初学者,我们应该在任何情况下使用GHC 7.10, 否则 forM_
和list monad解决方案永远不会融合。
另外,我用MV.write
替换了MV.unsafeWrite
,部分原因是因为它更快,但更重要的是它减少了生成的Core中的一些混乱。 从现在开始,运行时统计信息引用带有unsafeWrite
代码。
即使使用GHC 7.10,我们也应该首先注意所有那些[0..times-1]
和[0..side-1]
表达式,因为如果我们不采取必要的步骤,它们每次都会破坏性能。 问题是它们是常量范围,并且-ffull-laziness
(默认情况下在-O
上启用) -ffull-laziness
它们浮动到顶层。 这可以防止列表融合,并且迭代Int#
范围比迭代盒装Int
-s列表便宜,所以这是一个非常糟糕的优化。
让我们在几秒钟内看到一些运行时间(不使用unsafeWrite
)代码。 ghc -O2 -fllvm
,我使用+RTS -s
进行计时。
test_a: 1.6
test_b: 6.2
test_c: 0.6
对于GHC Core查看,我使用了ghc -O2 -ddump-simpl -dsuppress-all -dno-suppress-type-signatures
。
在test_a
的情况下, [0..99]
范围被取消:
main4 :: [Int]
main4 = eftInt 0 99 -- means "enumFromTo" for Int.
虽然最外面的[0..9999]
循环被融合成一个尾递归帮助器:
letrec {
a3_s7xL :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
a3_s7xL =
\ (x_X5zl :: Int#) (s1_X4QY :: State# RealWorld) ->
case a2_s7xF 0 s1_X4QY of _ { (# ipv2_a4NA, ipv3_a4NB #) ->
case x_X5zl of wild_X1S {
__DEFAULT -> a3_s7xL (+# wild_X1S 1) ipv2_a4NA;
99999 -> (# ipv2_a4NA, () #)
}
}; }
在test_b
的情况下,再次仅提升[0..99]
。 但是, test_b
要慢得多,因为它必须构建和排序实际的[IO ()]
列表。 至少GHC足够明智,只能为两个内部循环构建单个[IO ()]
,然后对其进行10000
次排序。
let {
lvl7_s4M5 :: [IO ()]
lvl7_s4M5 = -- omitted
letrec {
a2_s7Av :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
a2_s7Av =
\ (x_a5xi :: Int#) (eta_B1 :: State# RealWorld) ->
letrec {
a3_s7Au
:: [IO ()] -> State# RealWorld -> (# State# RealWorld, () #)
a3_s7Au =
\ (ds_a4Nu :: [IO ()]) (eta1_X1c :: State# RealWorld) ->
case ds_a4Nu of _ {
[] ->
case x_a5xi of wild1_X1y {
__DEFAULT -> a2_s7Av (+# wild1_X1y 1) eta1_X1c;
99999 -> (# eta1_X1c, () #)
};
: y_a4Nz ys_a4NA ->
case (y_a4Nz `cast` ...) eta1_X1c
of _ { (# ipv2_a4Nf, ipv3_a4Ng #) ->
a3_s7Au ys_a4NA ipv2_a4Nf
}
}; } in
a3_s7Au lvl7_s4M5 eta_B1; } in
-- omitted
我们该如何解决这个问题? 我们可以用{-# OPTIONS_GHC -fno-full-laziness #-}
解决这个问题。 在我们的案例中,这确实有很大帮助:
test_a: 0.5
test_b: 0.48
test_c: 0.5
或者,我们可以摆弄INLINE
编曲马。 在浮动完成后显然内联函数可以保持良好的性能。 我发现即使没有编译指示,GHC也会内联我们的测试函数,但是显式编译指示会导致它仅在浮动后才能内联。 例如,如果没有-fno-full-laziness
,这会产生良好的性能:
test_a mvec =
forM_ [0..times-1] $ \ n ->
forM_ [0..side-1] $ \ y ->
forM_ [0..side-1] $ \ x ->
MV.unsafeWrite mvec (y*side+x) 1
{-# INLINE test_a #-}
但过早地内联会导致性能不佳:
test_a mvec =
forM_ [0..times-1] $ \ n ->
forM_ [0..side-1] $ \ y ->
forM_ [0..side-1] $ \ x ->
MV.unsafeWrite mvec (y*side+x) 1
{-# INLINE [~2] test_a #-} -- "inline before the first phase please"
这种INLINE
解决方案的问题在于,面对GHC的浮动冲击,它相当脆弱。 例如,手动内联不会保留性能。 下面的代码很慢,因为类似于INLINE [~2]
它给GHC一个浮出的机会:
main = do
let vec = V.generate (side*side) (const 0)
mvec <- V.unsafeThaw vec :: IO (MV.MVector (PrimState IO) Int)
forM_ [0..times-1] $ \ n ->
forM_ [0..side-1] $ \ y ->
forM_ [0..side-1] $ \ x ->
MV.unsafeWrite mvec (y*side+x) 1
那我们该怎么办?
首先,我认为对于那些想要编写高性能代码并且知道自己在做什么的人来说,使用-fno-full-laziness
是一个完全可行的,甚至是更好的选择。 例如,它用于unordered-containers
。 有了它,我们可以更精确地控制共享,我们可以随时手动浮出或内联。
对于更常规的代码,我相信使用Control.Monad.Loop
或任何其他提供该功能的包没有任何问题。 许多Haskell用户对依赖小型“边缘”库并不是一丝不苟。 我们也可以只是重新实现for
,在所需的通用性。 例如,以下表现与其他解决方案一样好:
for :: Monad m => a -> (a -> Bool) -> (a -> a) -> (a -> m ()) -> m ()
for init while step body = go init where
go !i | while i = body i >> go (step i)
go i = return ()
{-# INLINE for #-}
我最初对堆分配的+RTS -s
数据感到非常困惑。 test_a
分配的非平凡与-fno-full-laziness
,也test_c
没有完全懒惰,而这些拨款与数量线性缩放times
迭代,但test_b
全懒惰只为矢量分配:
-- with -fno-full-laziness, no INLINE pragmas
test_a: 242,521,008 bytes
test_b: 121,008 bytes
test_c: 121,008 bytes -- but 240,120,984 with full laziness!
此外,在这种情况下, test_c
INLINE
编译指示根本没有帮助。
我花了一些时间试图在Core中为相关程序找到堆分配的迹象,但没有成功,直到实现让我感到震惊:GHC堆栈帧在堆上,包括主线程的帧,以及正在执行的函数堆分配基本上是在最多三个堆栈帧中运行三次嵌套循环。 由+RTS -s
注册的堆分配只是堆栈帧的持续弹出和推送。
对于以下代码,这在Core中非常明显:
{-# OPTIONS_GHC -fno-full-laziness #-}
-- ...
test_a mvec =
forM_ [0..times-1] $ \ n ->
forM_ [0..side-1] $ \ y ->
forM_ [0..side-1] $ \ x ->
MV.unsafeWrite mvec (y*side+x) 1
main = do
let vec = V.generate (side*side) (const 0)
mvec <- V.unsafeThaw vec :: IO (MV.MVector (PrimState IO) Int)
test_a mvec
我在这里荣耀归于此。 随意跳过。
main1 :: State# RealWorld -> (# State# RealWorld, () #)
main1 =
\ (s_a5HK :: State# RealWorld) ->
case divInt# 9223372036854775807 8 of ww4_a5vr { __DEFAULT ->
-- start of vector creation ----------------------
case tagToEnum# (># 10000 ww4_a5vr) of _ {
False ->
case newByteArray# 80000 (s_a5HK `cast` ...)
of _ { (# ipv_a5fv, ipv1_a5fw #) ->
letrec {
$s$wa_s8jS
:: Int#
-> Int#
-> State# (PrimState IO)
-> (# State# (PrimState IO), Int #)
$s$wa_s8jS =
\ (sc_s8jO :: Int#)
(sc1_s8jP :: Int#)
(sc2_s8jR :: State# (PrimState IO)) ->
case tagToEnum# (<# sc1_s8jP 10000) of _ {
False -> (# sc2_s8jR, I# sc_s8jO #);
True ->
case writeIntArray# ipv1_a5fw sc_s8jO 0 (sc2_s8jR `cast` ...)
of s'#_a5Gn { __DEFAULT ->
$s$wa_s8jS (+# sc_s8jO 1) (+# sc1_s8jP 1) (s'#_a5Gn `cast` ...)
}
}; } in
case $s$wa_s8jS 0 0 (ipv_a5fv `cast` ...)
-- end of vector creation -------------------
of _ { (# ipv6_a4Hv, ipv7_a4Hw #) ->
letrec {
a2_s7MJ :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
a2_s7MJ =
\ (x_a5Ho :: Int#) (eta_B1 :: State# RealWorld) ->
letrec {
a3_s7ME :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
a3_s7ME =
\ (x1_X5Id :: Int#) (eta1_XR :: State# RealWorld) ->
case ipv7_a4Hw of _ { I# dt4_a5x6 ->
case writeIntArray#
(ipv1_a5fw `cast` ...) (*# x1_X5Id 100) 1 (eta1_XR `cast` ...)
of s'#_a5Gn { __DEFAULT ->
letrec {
a4_s7Mz :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
a4_s7Mz =
\ (x2_X5J8 :: Int#) (eta2_X1U :: State# RealWorld) ->
case writeIntArray#
(ipv1_a5fw `cast` ...)
(+# (*# x1_X5Id 100) x2_X5J8)
1
(eta2_X1U `cast` ...)
of s'#1_X5Hf { __DEFAULT ->
case x2_X5J8 of wild_X2o {
__DEFAULT -> a4_s7Mz (+# wild_X2o 1) (s'#1_X5Hf `cast` ...);
99 -> (# s'#1_X5Hf `cast` ..., () #)
}
}; } in
case a4_s7Mz 1 (s'#_a5Gn `cast` ...)
of _ { (# ipv2_a4QH, ipv3_a4QI #) ->
case x1_X5Id of wild_X1e {
__DEFAULT -> a3_s7ME (+# wild_X1e 1) ipv2_a4QH;
99 -> (# ipv2_a4QH, () #)
}
}
}
}; } in
case a3_s7ME 0 eta_B1 of _ { (# ipv2_a4QH, ipv3_a4QI #) ->
case x_a5Ho of wild_X1a {
__DEFAULT -> a2_s7MJ (+# wild_X1a 1) ipv2_a4QH;
99999 -> (# ipv2_a4QH, () #)
}
}; } in
a2_s7MJ 0 (ipv6_a4Hv `cast` ...)
}
};
True ->
case error
(unpackAppendCString#
"Primitive.basicUnsafeNew: length to large: "#
(case $wshowSignedInt 0 10000 ([])
of _ { (# ww5_a5wm, ww6_a5wn #) ->
: ww5_a5wm ww6_a5wn
}))
of wild_00 {
}
}
}
main :: IO ()
main = main1 `cast` ...
main2 :: State# RealWorld -> (# State# RealWorld, () #)
main2 = runMainIO1 (main1 `cast` ...)
main :: IO ()
main = main2 `cast` ...
我们还可以通过以下方式很好地演示帧的分配。 让我们改变test_a
:
test_a mvec =
forM_ [0..times-1] $ \ n ->
forM_ [0..side-1] $ \ y ->
forM_ [0..side-50] $ \ x -> -- change here
MV.unsafeWrite mvec (y*side+x) 1
现在堆分配保持完全相同,因为最内层的循环是尾递归并使用单个帧。 通过以下更改,堆分配减半(到124,921,008字节),因为我们推送和弹出一半的帧:
test_a mvec =
forM_ [0..times-1] $ \ n ->
forM_ [0..side-50] $ \ y -> -- change here
forM_ [0..side-1] $ \ x ->
MV.unsafeWrite mvec (y*side+x) 1
test_b
和test_c
(没有完全懒惰)而是编译为在单个堆栈帧内使用嵌套case构造的代码,并遍历索引以查看哪个应该递增。 有关以下main
请参阅Core:
{-# LANGUAGE BangPatterns #-} -- later I'll talk about this
{-# OPTIONS_GHC -fno-full-laziness #-}
main = do
let vec = V.generate (side*side) (const 0)
!mvec <- V.unsafeThaw vec :: IO (MV.MVector (PrimState IO) Int)
test_c mvec
瞧:
main1 :: State# RealWorld -> (# State# RealWorld, () #)
main1 =
\ (s_a5Iw :: State# RealWorld) ->
case divInt# 9223372036854775807 8 of ww4_a5vT { __DEFAULT ->
-- start of vector creation ----------------------
case tagToEnum# (># 10000 ww4_a5vT) of _ {
False ->
case newByteArray# 80000 (s_a5Iw `cast` ...)
of _ { (# ipv_a5g3, ipv1_a5g4 #) ->
letrec {
$s$wa_s8ji
:: Int#
-> Int#
-> State# (PrimState IO)
-> (# State# (PrimState IO), Int #)
$s$wa_s8ji =
\ (sc_s8je :: Int#)
(sc1_s8jf :: Int#)
(sc2_s8jh :: State# (PrimState IO)) ->
case tagToEnum# (<# sc1_s8jf 10000) of _ {
False -> (# sc2_s8jh, I# sc_s8je #);
True ->
case writeIntArray# ipv1_a5g4 sc_s8je 0 (sc2_s8jh `cast` ...)
of s'#_a5GP { __DEFAULT ->
$s$wa_s8ji (+# sc_s8je 1) (+# sc1_s8jf 1) (s'#_a5GP `cast` ...)
}
}; } in
case $s$wa_s8ji 0 0 (ipv_a5g3 `cast` ...)
of _ { (# ipv6_a4MX, ipv7_a4MY #) ->
case ipv7_a4MY of _ { I# dt4_a5xy ->
-- end of vector creation
letrec {
a2_s7Q6 :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
a2_s7Q6 =
\ (x_a5HT :: Int#) (eta_B1 :: State# RealWorld) ->
letrec {
a3_s7Q5 :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
a3_s7Q5 =
\ (x1_X5J9 :: Int#) (eta1_XP :: State# RealWorld) ->
letrec {
a4_s7MZ :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
a4_s7MZ =
\ (x2_X5Jl :: Int#) (s1_X4Xb :: State# RealWorld) ->
case writeIntArray#
(ipv1_a5g4 `cast` ...)
(+# (*# x1_X5J9 100) x2_X5Jl)
1
(s1_X4Xb `cast` ...)
of s'#_a5GP { __DEFAULT ->
-- the interesting part! ------------------
case x2_X5Jl of wild_X1y {
__DEFAULT -> a4_s7MZ (+# wild_X1y 1) (s'#_a5GP `cast` ...);
99 ->
case x1_X5J9 of wild1_X1o {
__DEFAULT -> a3_s7Q5 (+# wild1_X1o 1) (s'#_a5GP `cast` ...);
99 ->
case x_a5HT of wild2_X1c {
__DEFAULT -> a2_s7Q6 (+# wild2_X1c 1) (s'#_a5GP `cast` ...);
99999 -> (# s'#_a5GP `cast` ..., () #)
}
}
}
}; } in
a4_s7MZ 0 eta1_XP; } in
a3_s7Q5 0 eta_B1; } in
a2_s7Q6 0 (ipv6_a4MX `cast` ...)
}
}
};
True ->
case error
(unpackAppendCString#
"Primitive.basicUnsafeNew: length to large: "#
(case $wshowSignedInt 0 10000 ([])
of _ { (# ww5_a5wO, ww6_a5wP #) ->
: ww5_a5wO ww6_a5wP
}))
of wild_00 {
}
}
}
main :: IO ()
main = main1 `cast` ...
main2 :: State# RealWorld -> (# State# RealWorld, () #)
main2 = runMainIO1 (main1 `cast` ...)
main :: IO ()
main = main2 `cast` ...
我不得不承认,我基本上不知道为什么有些代码会避免堆栈帧的创建而有些代码却没有。 我怀疑从“内部”输出内联有帮助,并且快速检查告诉我Control.Monad.Loop
使用CPS编码,这可能与此相关,尽管Monad.Loop
解决方案对于让浮动很敏感,我不能在核心的短时间内确定为什么test_c
with let floating无法在单个堆栈帧中运行。
现在,在单个堆栈帧中运行的性能优势很小。 我们已经看到test_b
只比test_a
略快。 我把这个绕道包括在答案中因为我发现它有启发性。
所谓的国家黑客行为使GHC积极参与IO和ST行动。 我想我应该在这里提一下,因为除了让浮动这是另一件可以彻底破坏性能的事情。
状态hack启用了优化-O
,并且可能会渐进地减慢程序的速度。 Reid Barton的一个简单例子:
import Control.Monad
import Debug.Trace
expensive :: String -> String
expensive x = trace "$$$" x
main :: IO ()
main = do
str <- fmap expensive getLine
replicateM_ 3 $ print str
使用GHC-7.10.2,这将打印"$$$"
一次,不进行优化,但使用-O2
三次。 似乎用GHC-7.10,我们无法摆脱-fno-state-hack
(这是来自Reid Barton的链接票证的主题)的这种行为。
严格的monadic绑定可靠地摆脱了这个问题:
main :: IO ()
main = do
!str <- fmap expensive getLine
replicateM_ 3 $ print str
我认为在IO和ST中进行严格绑定是个好习惯。 我有一些经验(虽然不是明确的;我不是GHC专家),如果我们使用-fno-full-laziness
则特别需要严格的绑定。 显然,完全懒惰可以帮助摆脱由国家黑客引起的内联引入的一些工作重复; 使用test_b
并且没有完全懒惰,省略严格绑定!mvec <- V.unsafeThaw vec
导致轻微的减速和非常难看的Core输出。
根据我的经验, forM_ [0..n-1]
可以表现良好,但遗憾的是它并不可靠。 只需将一个INLINE
编译指示添加到test_a
并使用-O2
使其运行速度更快(对我来说为4s到1s),但手动内联它(复制粘贴)会再次降低速度。
更可靠的功能是for
从statistics
其作为实施
-- | Simple for loop. Counts from /start/ to /end/-1.
for :: Monad m => Int -> Int -> (Int -> m ()) -> m ()
for n0 !n f = loop n0
where
loop i | i == n = return ()
| otherwise = f i >> loop (i+1)
{-# INLINE for #-}
使用它看起来类似于forM_
with lists:
test_d :: MV.IOVector Int -> IO ()
test_d mv =
for 0 times $ \_ ->
for 0 side $ \i ->
for 0 side $ \j ->
MV.unsafeWrite mv (i*side + j) 1
但是执行得非常好(对我来说是0.85秒)而没有任何分配列表的风险。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.