Abnormally slow Haskell code
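After completing it in F#, I have been trying the Digits-Recognizer Dojo in Haskell. I get results, but for some reason my Haskell code is very slow, and I can't seem to find the mistake.
Here is my code (the .csv files can be found on the Dojo's GitHub):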
import Data.Char
import Data.List
import Data.List.Split
import Data.Ord
import System.IO

type Pixels = [Int]

data Digit = Digit { label :: Int, pixels :: Pixels }

distance :: Pixels -> Pixels -> Float
distance d1 d2 = sqrt . sum $ map pointDistance $ zip d1 d2
    where pointDistance (a, b) = fromIntegral $ (a - b) * (a - b)

parseDigit :: String -> Digit
parseDigit s = Digit label pixels
    where (label:pixels) = map read $ splitOn "," s

identify :: Digit -> [Digit] -> (Digit, Float)
identify digit training = minimumBy (comparing snd) distances
    where distances = map fn training
          fn ref = (ref, distance (pixels digit) (pixels ref))

readDigits :: String -> IO [Digit]
readDigits filename = do
    fileContent <- readFile filename
    return $ map parseDigit $ tail $ lines fileContent

main :: IO ()
main = do
    trainingSample <- readDigits "trainingsample.csv"
    validationSample <- readDigits "validationsample.csv"
    let result = [(d, identify d trainingSample) | d <- validationSample]
        fmt (d, (ref, dist)) = putStrLn $ "Found..."
    mapM_ fmt result
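What is the reason for this poor performance?
[Update] Thanks for the many ideas! I have switched from String to Data.Text and from List to Data.Vector, as suggested. Unfortunately, the result is still far from satisfactory.
My updated code is available here.
To give you a better idea of what I am asking about, here is the output of my Haskell (left) and F# (right) implementations. I am a complete newbie in both languages, so I sincerely believe there is a major mistake in my Haskell version that makes it so much slower.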
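If you are patient, you will notice that the second result is computed much faster than the first. That is because your implementation takes some time to read the csv files.
You might be tempted to stick in a print statement to see when it has finished loading: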
main = do
    trainingSample <- readDigits "trainingsample.csv"
    validationSample <- readDigits "validationsample.csv"
    putStrLn "done loading data"
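But because of laziness, this won't do what you think it does. trainingSample and validationSample have not been fully evaluated at that point, so the print statement fires almost immediately and the first result still takes forever.
However, you can force readDigits to fully evaluate its return value, which gives a better idea of how much time is spent there. You can either switch to non-lazy IO, or just print something derived from the data: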
-- Needs: import Control.Monad (forM)
readDigits :: String -> IO [Digit]
readDigits filename = do
    fileContent <- readFile filename
    putStr' $ filename ++ ": "
    rows <- forM (tail $ lines fileContent) $ \line -> do
        let xs = parseDigit line
        -- printing something derived from the pixels forces the row to be parsed
        putStr' $ case compare (sum $ pixels xs) 0 of
            LT -> "-"
            EQ -> "0"
            GT -> "+"
        return xs
    putStrLn ""
    return rows
  where putStr' s = putStr s >> hFlush stdout
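On my machine, this let me see that it takes about 27 seconds to fully read the digits from trainingsample.csv.
Printing progress like this is printf-style profiling, which is not great (using a real profiler, or using criterion to benchmark individual parts of the code, would be much better), but it is good enough for these purposes.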
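If you would rather not print one character per row, roughly the same measurement can be sketched with evaluate from Control.Exception and getCPUTime from System.CPUTime (both in base). The names checksum and timedLoad are made up for this illustration, and the Digit type is the list-based one from the question; treat it as a rough sketch, not a substitute for a real profiler.

```haskell
import Control.Exception (evaluate)
import System.CPUTime (getCPUTime)

type Pixels = [Int]
data Digit = Digit { label :: Int, pixels :: Pixels }

-- Hypothetical helper: adding up every label and pixel forces all of them.
checksum :: [Digit] -> Int
checksum = sum . map (\d -> label d + sum (pixels d))

-- Hypothetical helper: run a loader, force its result, report CPU time used.
timedLoad :: String -> IO [Digit] -> IO [Digit]
timedLoad name load = do
    start <- getCPUTime                      -- picoseconds
    rows  <- load
    _     <- evaluate (checksum rows)        -- forces parsing of every row
    end   <- getCPUTime
    let secs = fromIntegral (end - start) / 1e12 :: Double
    putStrLn $ name ++ ": " ++ show secs ++ " s CPU"
    return rows

main :: IO ()
main = do
    -- Stand-in data; in the real program the loader would be readDigits.
    rows <- timedLoad "demo" (return [Digit 1 [0, 255], Digit 7 [3, 4]])
    print (length rows)
```

In the actual program the call would presumably look like timedLoad "trainingsample.csv" (readDigits "trainingsample.csv").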
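Reading the file is clearly a major part of the slowdown, so it is worth trying to switch to strict IO. Using the strict Data.Text.IO.readFile cut it down to about 18 seconds.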
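For reference, a strict read with Data.Text.IO might look like the sketch below. rowsOf and readRows are hypothetical names, and demo.csv is a throwaway file written only so the example runs on its own. Note that only the file read is strict here; the map over the rows is still lazy.

```haskell
import qualified Data.Text as T
import qualified Data.Text.IO as TIO

-- Hypothetical helper: drop the header row, split each remaining row on commas.
rowsOf :: T.Text -> [[T.Text]]
rowsOf = map (T.splitOn (T.pack ",")) . drop 1 . T.lines

-- Hypothetical helper: the whole file is in memory once TIO.readFile returns.
readRows :: FilePath -> IO [[T.Text]]
readRows path = fmap rowsOf (TIO.readFile path)

main :: IO ()
main = do
    writeFile "demo.csv" "label,p0,p1\n7,0,255\n"   -- throwaway sample file
    rows <- readRows "demo.csv"
    print rows
```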
UPDATE
Here is how to speed up your updated code:
Use unboxed vectors for Pixels (small win):
import qualified Data.Vector.Unboxed as U

-- ...

type Pixels = U.Vector Int

-- ...

distance :: Pixels -> Pixels -> Float
distance d1 d2 = sqrt . U.sum $ U.zipWith pointDistance d1 d2
    where pointDistance a b = fromIntegral $ (a - b) * (a - b)

parseDigit :: T.Text -> Digit
parseDigit s = Digit label (U.fromList pixels)
    where (label:pixels) = map toDigit $ T.splitOn (T.pack ",") s
          toDigit s      = either (\_ -> 0) fst (T.Read.decimal s)
Force the distance to be evaluated early by using seq (big win):
identify :: Digit -> V.Vector Digit -> (Digit, Float)
identify digit training = V.minimumBy (comparing snd) distances
    where distances = V.map fn training
          fn ref = let d = distance (pixels digit) (pixels ref) in d `seq` (ref, d)
On my machine, the whole program now runs in ~5s:
% ghc --make -O2 Main.hs
[1 of 1] Compiling Main ( Main.hs, Main.o )
Linking Main ...
% time ./Main
./Main 5.00s user 0.11s system 99% cpu 5.115 total
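To see the seq trick in isolation, here is a small list-based sketch. Without seq, the second component of each tuple is a thunk until something demands it; with seq, the score is computed as soon as the tuple itself is demanded. strictPair and nearest are hypothetical names, and the scoring function in main is just a stand-in for distance.

```haskell
import Data.List (minimumBy)
import Data.Ord (comparing)

-- Hypothetical helper: compute the score first, then build the tuple,
-- so the tuple never carries an unevaluated score.
strictPair :: (a -> Float) -> a -> (a, Float)
strictPair score ref = let d = score ref in d `seq` (ref, d)

-- Hypothetical helper: pick the candidate with the smallest score.
nearest :: (a -> Float) -> [a] -> (a, Float)
nearest score = minimumBy (comparing snd) . map (strictPair score)

main :: IO ()
main = print (nearest (\x -> abs (x - 3)) [1, 5, 3, 8 :: Float])
-- prints (3.0,0.0)
```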
Those thunks are killing you.
Your Vector version, partially unboxed, adapted for ByteString and compiled with -O2 -fllvm, runs in 8 seconds on my machine:
import Data.Ord
import Data.Maybe
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as U
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BC

type Pixels = U.Vector Int

data Digit = Digit { label :: !Int, pixels :: !Pixels }

distance :: Pixels -> Pixels -> Float
distance d1 d2 = sqrt . U.sum . U.zipWith pointDistance d1 $ d2
    where pointDistance a b = fromIntegral $ (a - b) * (a - b)

parseDigit :: B.ByteString -> Digit
parseDigit bs =
    let (label:pixels) = toIntegers bs []
    in Digit label (U.fromList pixels)
  where
    toIntegers bs is =
        let Just (i,bs') = BC.readInt bs
        -- keep i when the input is exhausted, otherwise the last pixel is dropped
        in if B.null bs' then reverse (i:is) else toIntegers (BC.tail bs') (i:is)

identify :: Digit -> V.Vector Digit -> (Digit, Float)
identify digit training = V.minimumBy (comparing snd) distances
    where distances = V.map fn training
          fn ref = (ref, distance (pixels digit) (pixels ref))

readDigits :: String -> IO (V.Vector Digit)
readDigits filename = do
    fileContent <- B.readFile filename
    return . V.map parseDigit . V.fromList . tail . BC.lines $ fileContent

main :: IO ()
main = do
    trainingSample <- readDigits "trainingsample.csv"
    validationSample <- readDigits "validationsample.csv"
    let result = V.map (\d -> (d, identify d trainingSample)) validationSample
        fmt (d, (ref, dist)) = putStrLn $ "Found " ++ show (label ref) ++ " for " ++ show (label d) ++ " (distance=" ++ show dist ++ ")"
    V.mapM_ fmt result
Output of +RTS -s:
989,632,984 bytes allocated in the heap
19,875,368 bytes copied during GC
31,016,504 bytes maximum residency (5 sample(s))
22,748,608 bytes maximum slop
78 MB total memory in use (1 MB lost due to fragmentation)
Tot time (elapsed) Avg pause Max pause
Gen 0 1761 colls, 0 par 0.05s 0.05s 0.0000s 0.0008s
Gen 1 5 colls, 0 par 0.00s 0.02s 0.0030s 0.0085s
INIT time 0.00s ( 0.00s elapsed)
MUT time 7.42s ( 7.69s elapsed)
GC time 0.05s ( 0.06s elapsed)
EXIT time 0.00s ( 0.01s elapsed)
Total time 7.47s ( 7.77s elapsed)
%GC time 0.7% (0.8% elapsed)
Alloc rate 133,419,569 bytes per MUT second
Productivity 99.3% of total user, 95.5% of total elapsed
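As a side note on the parsing step, the recursion around BC.readInt can also be written with unfoldr, which simply stops at the first field that is not a number instead of failing a pattern match. This is only a sketch: parseInts is a hypothetical name, and it assumes single-character separators without checking that they are commas.

```haskell
import qualified Data.ByteString.Char8 as BC
import Data.List (unfoldr)

-- Hypothetical helper: read separator-delimited Ints from one line.
parseInts :: BC.ByteString -> [Int]
parseInts = unfoldr step
  where
    step bs = do
        (i, rest) <- BC.readInt bs     -- Nothing ends the unfold
        return (i, BC.drop 1 rest)     -- skip one separator character, if any

main :: IO ()
main = print (parseInts (BC.pack "7,0,255,13"))
-- prints [7,0,255,13]
```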