Composing Stateful functions in Haskell

What is the simplest Haskell library that allows composition of stateful functions?

We can use the State monad to compute a stock's exponentially-weighted moving average as follows:

import Control.Monad.State.Lazy
import Data.Functor.Identity

type StockPrice = Double
type EWMAState = Double
type EWMAResult = Double

computeEWMA :: Double -> StockPrice -> State EWMAState EWMAResult
computeEWMA α price = do oldEWMA <- get
                         let newEWMA = α * oldEWMA + (1.0 - α) * price
                         put newEWMA
                         return newEWMA
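Each call to computeEWMA handles one price. When the calls are sequenced, the State monad threads the running average from one call to the next. The sketch below assumes you want the EWMA after every price in a list. ewmaOverSeries is a hypothetical helper name, and the numbers are made up.

```haskell
import Control.Monad.State (State, evalState, get, put)

type StockPrice = Double
type EWMAState = Double
type EWMAResult = Double

-- Same update rule as computeEWMA above; alpha weights the previous average.
computeEWMA :: Double -> StockPrice -> State EWMAState EWMAResult
computeEWMA alpha price = do
  oldEWMA <- get
  let newEWMA = alpha * oldEWMA + (1.0 - alpha) * price
  put newEWMA
  return newEWMA

-- Hypothetical helper: the EWMA after each price, starting from an initial average.
-- mapM sequences the stateful steps; evalState supplies the start and drops the final state.
ewmaOverSeries :: Double -> EWMAState -> [StockPrice] -> [EWMAResult]
ewmaOverSeries alpha start prices = evalState (mapM (computeEWMA alpha) prices) start
```

With alpha = 0.5 and a starting average of 0, ewmaOverSeries 0.5 0 [2, 4] gives [1.0, 2.5]. The first value is 0.5 * 0 + 0.5 * 2, and the second is 0.5 * 1 + 0.5 * 4.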

However, it's complicated to write a function that calls other stateful functions. For example, to find all data points where the stock's short-term average crosses its long-term average, we could write:

computeShortTermEWMA = computeEWMA 0.2
computeLongTermEWMA  = computeEWMA 0.8

type CrossingState = Bool
type GoldenCrossState = (CrossingState, EWMAState, EWMAState)
checkIfGoldenCross :: StockPrice -> State GoldenCrossState String
checkIfGoldenCross price = do (oldCrossingState, oldShortState, oldLongState) <- get
                              let (shortEWMA, newShortState) = runState (computeShortTermEWMA price) oldShortState
                              let (longEWMA, newLongState) = runState (computeLongTermEWMA price) oldLongState
                              let newCrossingState = (shortEWMA < longEWMA)
                              put (newCrossingState, newShortState, newLongState)
                              return (if newCrossingState == oldCrossingState then
                                    "no cross"
                                else
                                    "golden cross!")

Since checkIfGoldenCross calls computeShortTermEWMA and computeLongTermEWMA, we must manually wrap/unwrap their states.

Is there a more elegant way?

If I understood your code correctly, you don't share state between the calls to computeShortTermEWMA and computeLongTermEWMA. They're just two entirely independent functions that happen to use state internally. In this case, the elegant thing to do would be to encapsulate runState in the definitions of computeShortTermEWMA and computeLongTermEWMA, since they're separate, self-contained entities:

computeShortTermEWMA start price = runState (computeEWMA 0.2 price) start

All this does is make the call site a bit neater, though; I just moved the runState into the definition. This marks the state as a local implementation detail of computing the EWMA, which is what it really is. The fact that GoldenCrossState is a different type from EWMAState underscores this.

In other words, you're not really composing stateful functions; rather, you're composing functions that happen to use state inside. You can just hide that detail.
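Here is a minimal sketch of the resulting call site. It assumes that computeLongTermEWMA mirrors the short-term definition above with 0.2 replaced by 0.8. It also uses plain Double and tuple types in place of the question's aliases. Both helpers now return (result, new state) pairs, so checkIfGoldenCross only has to thread its own tuple.

```haskell
import Control.Monad.State (State, evalState, get, put, runState)

type StockPrice = Double

computeEWMA :: Double -> StockPrice -> State Double Double
computeEWMA alpha price = do
  old <- get
  let new = alpha * old + (1.0 - alpha) * price
  put new
  return new

-- runState is hidden here: callers pass the old average and get (result, new average).
computeShortTermEWMA, computeLongTermEWMA :: Double -> StockPrice -> (Double, Double)
computeShortTermEWMA start price = runState (computeEWMA 0.2 price) start
computeLongTermEWMA start price = runState (computeEWMA 0.8 price) start

checkIfGoldenCross :: StockPrice -> State (Bool, Double, Double) String
checkIfGoldenCross price = do
  (oldCrossing, oldShort, oldLong) <- get
  let (shortEWMA, newShort) = computeShortTermEWMA oldShort price
      (longEWMA, newLong) = computeLongTermEWMA oldLong price
      newCrossing = shortEWMA < longEWMA
  put (newCrossing, newShort, newLong)
  return (if newCrossing == oldCrossing then "no cross" else "golden cross!")
```

Starting from (True, 456, 789), evalState (checkIfGoldenCross 123) returns "no cross". Starting from (True, 123, 123), it returns "golden cross!".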

More generally, I don't really see what you're using the state for at all. I suppose you would use it to iterate through the stock prices, maintaining the EWMA. However, I don't think this is necessarily the best way to do it. Instead, I would consider writing your EWMA function over a list of stock prices, using something like a scan. This should make your other analysis functions easier to implement, since they'll just be list functions as well. (In the future, if you need to deal with IO, you can always switch over to something like Pipes, which presents an interface very similar to lists.)
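Here is a rough sketch of that scan-based approach. It assumes both averages start from the same seed value. ewmaSeries and shortBelowLong are illustrative names and do not come from any library. scanl carries the running average forward, and drop 1 removes the seed value that scanl emits first.

```haskell
-- EWMA of a price series via scanl; alpha weights the previous average,
-- matching computeEWMA in the question.
ewmaSeries :: Double -> Double -> [Double] -> [Double]
ewmaSeries alpha start = drop 1 . scanl step start
  where step old price = alpha * old + (1 - alpha) * price

-- Once both averages are plain lists, comparing them is an ordinary list function:
-- True wherever the short-term EWMA is below the long-term one.
shortBelowLong :: Double -> [Double] -> [Bool]
shortBelowLong start prices =
  zipWith (<) (ewmaSeries 0.2 start prices) (ewmaSeries 0.8 start prices)
```

For instance, ewmaSeries 0.5 0 [2, 4] is [1.0, 2.5], and shortBelowLong 10 [8, 12, 13] is [True, False, False]. The short-term average dips below the long-term one after the drop to 8 and climbs back above it at 12.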

In this particular case, you have a y -> (a, y) and a z -> (b, z) that you want to use to compose an (x, y, z) -> (c, (x, y, z)). I've never used lens before, but this seems like a perfect opportunity.

In general, we can promote a stateful operation on a sub-state to operate on the whole state like this:

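{-# LANGUAGE RankNTypes #-}
import Control.Lens (Lens', view, set)
import Control.Monad.State (StateT, get, put, runStateT)
import Control.Monad.Trans.Class (lift)

promote :: Monad m => Lens' s s' -> StateT s' m a -> StateT s m a
promote lens act = do
    big <- get
    let little = view lens big
    -- run the sub-state action in the underlying monad m
    (res, little') <- lift (runStateT act little)
    let big' = set lens little' big
    put big'
    return res
-- Feel free to golf and optimize, but this is pretty readable.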

Our lens is a witness that s' is a sub-state of s.

I don't know if "promote" is a good name, and I don't recall seeing this function defined elsewhere (but it's probably already in lens; its zoom combinator, from Control.Lens.Zoom, covers this case for StateT).

The witnesses you need are named _2 and _3 in lens, so you could change a couple of lines of code to look like:

shortEWMA <- promote _2 (computeShortTermEWMA price)
longEWMA <- promote _3 (computeLongTermEWMA price)

If a Lens allows you to focus on inner values, maybe this combinator should be called blurredBy (for prefix application) or obscures (for infix application).
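If you'd rather not pull in lens just for this, the same idea can be sketched with a hand-rolled van Laarhoven lens. Everything below is an assumption for illustration. Lens', view and set are minimal local stand-ins for the lens versions, and mid3 and last3 are hypothetical replacements for _2 and _3. computeEWMA is also generalised to any underlying monad so that promote can run it with runStateT.

```haskell
{-# LANGUAGE RankNTypes #-}
import Control.Monad.State (State, StateT, evalState, get, put, runStateT)
import Control.Monad.Trans.Class (lift)
import Data.Functor.Const (Const (..))
import Data.Functor.Identity (Identity (..))

-- A van Laarhoven lens, the same shape as lens's Lens' (defined locally to avoid the dependency).
type Lens' s a = forall f. Functor f => (a -> f a) -> s -> f s

view :: Lens' s a -> s -> a
view l = getConst . l Const

set :: Lens' s a -> a -> s -> s
set l x = runIdentity . l (const (Identity x))

-- Hypothetical stand-ins for lens's _2 and _3 on triples.
mid3 :: Lens' (x, y, z) y
mid3 f (x, y, z) = fmap (\y' -> (x, y', z)) (f y)

last3 :: Lens' (x, y, z) z
last3 f (x, y, z) = fmap (\z' -> (x, y, z')) (f z)

-- Run an action on the focused sub-state, then write the updated sub-state back.
promote :: Monad m => Lens' s s' -> StateT s' m a -> StateT s m a
promote l act = do
  big <- get
  (res, little') <- lift (runStateT act (view l big))
  put (set l little' big)
  return res

computeEWMA :: Monad m => Double -> Double -> StateT Double m Double
computeEWMA alpha price = do
  old <- get
  let new = alpha * old + (1.0 - alpha) * price
  put new
  return new

checkIfGoldenCross :: Double -> State (Bool, Double, Double) String
checkIfGoldenCross price = do
  shortEWMA <- promote mid3 (computeEWMA 0.2 price)
  longEWMA <- promote last3 (computeEWMA 0.8 price)
  -- both averages are already updated; only the crossing flag is left
  (oldCrossing, newShort, newLong) <- get
  let newCrossing = shortEWMA < longEWMA
  put (newCrossing, newShort, newLong)
  return (if newCrossing == oldCrossing then "no cross" else "golden cross!")
```

With this design, promote rather than the EWMA code knows where each average lives in the tuple. As a result, computeEWMA stays a plain StateT Double action.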

There is really no need to use any monad at all for these simple functions. You're (ab)using the State monad to calculate a one-off result in computeEWMA when there is no state involved. The only line that actually matters is the formula for EWMA, so let's pull that into its own function.

ewma :: Double -> Double -> Double -> Double
ewma a price t = a * t + (1 - a) * price

If you inline the definition of State and ignore the String values, this next function has almost exactly the same signature as your original checkIfGoldenCross!

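type EWMAState = (Bool, Double, Double)

-- The first component is the current "short < long" flag, playing the role of CrossingState.
ewmaStep :: Double -> EWMAState -> EWMAState
ewmaStep price (_, short, long) =
    (newCrossing, newShort, newLong)
    where newCrossing = newShort < newLong
          newShort = ewma 0.2 price short
          newLong  = ewma 0.8 price long

Although it doesn't use the State monad, we're certainly dealing with state here. ewmaStep takes a stock price and the old EWMAState and returns a new EWMAState. A cross has happened wherever two consecutive flags differ.

Now, putting it all together with scanl :: (b -> a -> b) -> b -> [a] -> [b], which walks the prices in chronological order (scanr would start from the last price):

-- a list of stock prices
prices = [1.2, 3.7, 2.8, 4.3]

_1 (a, _, _) = a

-- flip ewmaStep has the (state -> price -> state) shape that scanl expects
main = print . map _1 $ scanl (flip ewmaStep) (False, 0, 0) prices
-- [False,False,False,False,False]  (the short-term EWMA stays above the long-term one here)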

Because fold* and scan* use the cumulative result of previous values to compute each successive one, they are "stateful" enough that they can often be used in cases like this.
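Here is a sketch of reading crosses off the scan. It assumes the first component of each state holds the current short < long flag. ewmaStep is restated so the block stands alone, and crossings is a hypothetical helper. Comparing each flag with its predecessor via zipWith (/=) gives one Bool per price.

```haskell
type EWMAState = (Bool, Double, Double)

ewma :: Double -> Double -> Double -> Double
ewma a price t = a * t + (1 - a) * price

-- Assumption: the first component stores the current "short < long" flag.
ewmaStep :: Double -> EWMAState -> EWMAState
ewmaStep price (_, short, long) = (newShort < newLong, newShort, newLong)
  where
    newShort = ewma 0.2 price short
    newLong  = ewma 0.8 price long

-- Hypothetical helper: for each price, did the flag flip compared with the previous state?
crossings :: EWMAState -> [Double] -> [Bool]
crossings start prices = zipWith (/=) flags (drop 1 flags)
  where
    flags = [flag | (flag, _, _) <- scanl (flip ewmaStep) start prices]
```

For example, crossings (False, 10, 10) [8, 12, 13] is [True, True, False]. The short-term average falls below the long-term one at 8 and rises back above it at 12.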

With a little type class magic, you can work with nested transformers of the same type (here, several StateT layers) without lifting by hand. First, you will need a new instance for MonadState:

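{-# LANGUAGE
    UndecidableInstances
  , FlexibleInstances
  , FlexibleContexts
  , MultiParamTypeClasses
  , DataKinds
  , KindSignatures
  #-}

import Control.Monad.State
import Control.Monad.Trans.Class (MonadTrans (..))

-- OVERLAPPABLE lets mtl's more specific instances (e.g. for StateT s m) take precedence
instance {-# OVERLAPPABLE #-} (MonadState s m, MonadTrans t, Monad (t m)) => MonadState s (t m) where
  state f = lift (state f)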

Then you must define your EWMAState as a newtype tagged with the type of term (alternatively, it could be two different types, but using a phantom type as a tag has its advantages):

data Term = ShortTerm | LongTerm
type StockPrice = Double
newtype EWMAState (t :: Term) = EWMAState Double deriving Show
type EWMAResult = Double
type CrossingState = Bool

Now, computeEWMA works on an EWMAState that is polymorphic in the term (the aforementioned example of tagging with phantom types) and in the monad:

computeEWMA :: (MonadState (EWMAState t) m) => Double -> StockPrice -> m EWMAResult
computeEWMA a price = do 
  EWMAState old <- get
  let new = a * old + (1.0 - a) * price
  put $ EWMAState new
  return new

For the specific terms, you fix the phantom tag in their type signatures:

computeShortTermEWMA :: (MonadState (EWMAState ShortTerm) m) => StockPrice -> m EWMAResult
computeShortTermEWMA = computeEWMA 0.2

computeLongTermEWMA :: (MonadState (EWMAState LongTerm) m) => StockPrice -> m EWMAResult
computeLongTermEWMA  = computeEWMA 0.8

Finally, your function:

checkIfGoldenCross :: 
  ( MonadState (EWMAState ShortTerm) m
  , MonadState (EWMAState LongTerm) m
  , MonadState CrossingState m) => 
  StockPrice -> m String 

checkIfGoldenCross price = do 
  oldCrossingState <- get
  shortEWMA <- computeShortTermEWMA price 
  longEWMA <- computeLongTermEWMA price 
  let newCrossingState = shortEWMA < longEWMA
  put newCrossingState
  return (if newCrossingState == oldCrossingState then "no cross" else "golden cross!")

The only downside is that you have to give explicit type signatures. In fact, the instance we introduced at the beginning ruins any hope of good type errors and type inference when you have multiple copies of the same transformer in a stack.

Then a small helper function:

runState3 :: StateT a (StateT b (State c)) x -> a -> b -> c -> ((a , b , c) , x)
runState3 sa a b c = ((a' , b', c'), x) where 
  (((x, a'), b'), c') = runState (runStateT (runStateT sa a) b) c 

and, assuming shortTerm and longTerm are simply EWMAState at the ShortTerm and LongTerm tags:

>runState3 (checkIfGoldenCross 123) (shortTerm 123) (longTerm 123) True
((EWMAState 123.0,EWMAState 123.0,False),"golden cross!")

>runState3 (checkIfGoldenCross 123) (shortTerm 456) (longTerm 789) True
((EWMAState 189.60000000000002,EWMAState 655.8000000000001,True),"no cross")
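For comparison, here is a sketch of the same three-layer stack without the overlapping instance. Each layer is reached with explicit lift calls, so no language extensions are needed and the types follow from the stack. The sketch uses plain Double states instead of the tagged EWMAState newtype, an assumption that sidesteps the phantom tags. It also repeats runState3 so the block compiles on its own.

```haskell
import Control.Monad.State (State, StateT, get, put, runState, runStateT)
import Control.Monad.Trans.Class (lift)

-- Outer layer: short-term EWMA; middle layer: long-term EWMA; inner layer: crossing flag.
type GoldenCrossM = StateT Double (StateT Double (State Bool))

-- Works on whichever StateT Double layer it is run in.
computeEWMA :: Monad m => Double -> Double -> StateT Double m Double
computeEWMA a price = do
  old <- get
  let new = a * old + (1.0 - a) * price
  put new
  return new

checkIfGoldenCross :: Double -> GoldenCrossM String
checkIfGoldenCross price = do
  oldCrossingState <- lift (lift get)        -- two layers down: the Bool flag
  shortEWMA <- computeEWMA 0.2 price         -- outer layer
  longEWMA <- lift (computeEWMA 0.8 price)   -- one layer down
  let newCrossingState = shortEWMA < longEWMA
  lift (lift (put newCrossingState))
  return (if newCrossingState == oldCrossingState then "no cross" else "golden cross!")

-- Same helper as runState3 above, repeated so this block compiles on its own.
runState3 :: StateT a (StateT b (State c)) x -> a -> b -> c -> ((a, b, c), x)
runState3 sa a b c = ((a', b', c'), x)
  where (((x, a'), b'), c') = runState (runStateT (runStateT sa a) b) c
```

Running runState3 (checkIfGoldenCross 123) 456 789 True should give "no cross" with averages of roughly 189.6 and 655.8, in line with the session above. The cost of this approach is that every access to a lower layer spells out its depth.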
