
Scala equivalent to Haskell pattern matching on constructors

In Haskell, I have constructors defined like this:

data Fonction =
      Const Float
    | Param String
    | Var String
    | Add Fonction Fonction

infix 3 @+

(@+) :: Fonction -> Fonction -> Fonction
(@+) (Const k) (Const k') = Const (k + k')

In Scala, which is quite new to me, I tried:

trait Function {
  case class Const(value: Double) extends Function
  case class Param(value: String) extends Function
  case class Var(value: String) extends Function
  case class Add(f1: Function, f2: Function) extends Function
}

object Op extends Function {
  def som(f1: Function, f2: Function): Function =
    (f1, f2) match {
      case (Const(a), Const(b)) => Const(a + b)
      case _ => Add(f1, f2)
    }
}

object HelloWorld extends Function {
  def main(args: Array[String]): Unit = {
    val s = Op.som(Const(3), Const(5))
    println(s)
  }
}

But the pattern (Const(3), Const(5)) isn't matched: I get Add(Const(3.0),Const(5.0)).

Although @nicodp's code is technically correct (although IMHO not entirely), it provides no explanation. The reason your code doesn't work as expected is that you define your case classes inside trait Function. In the Java world, inner classes can be static or non-static, and Scala copies this part of the behavior: classes defined inside a trait are non-static. In particular, this means that an instance of HelloWorld.Const(42) is not equal to an instance of Op.Const(42), because they capture different "outer" values of different types (the Op object and the HelloWorld object). This is also why your pattern matching doesn't work as you expect. If you write your pattern matching as:

object Op extends Function {
  def som(f1: Function, f2: Function): Function =
    (f1, f2) match {
      // this doesn't work
      //case (Const(a), Const(b)) => Const(a + b) 
      // this works
      case (HelloWorld.Const(a), HelloWorld.Const(b)) => Const(a + b)
      case _ => Add(f1, f2)
    }
}

it will work, but this is probably not what you want. The way to work around this is to define your case classes in a static context. You can do it at the top level, as @nicodp suggested. Another option is to put them into the companion object of Function, like this:

sealed trait Function 

object Function {
  case class Const(value: Double) extends Function
  case class Param(value: String) extends Function
  case class Var(value: String) extends Function
  case class Add(f1: Function, f2: Function) extends Function
}

object Op {

  import Function._

  def som(f1: Function, f2: Function): Function =
    (f1, f2) match {
      //          case (HelloWorld.Const(a), HelloWorld.Const(b)) => Const(a + b)
      case (Const(a), Const(b)) => Const(a + b)
      case _ => Add(f1, f2)
    }
}

object HelloWorld {
  import Function._

  def main(args: Array[String]): Unit = {
    val s = Op.som(Const(3), Const(5))
    println(s)
  }
}

Note that extends Function has been replaced with import Function._

Also note that it is recommended to mark such traits as sealed, so that the compiler can check that your pattern matching is exhaustive.
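To illustrate what sealed gives you, here is a small sketch that reuses the companion-object layout above. The Render object and its render function are hypothetical additions, not part of the original question: render matches on all four cases, and because Function is sealed, deleting any one of them should make the compiler emit a "match may not be exhaustive" warning.

```scala
// Same sealed trait + companion object layout as in the answer above.
sealed trait Function

object Function {
  case class Const(value: Double) extends Function
  case class Param(value: String) extends Function
  case class Var(value: String) extends Function
  case class Add(f1: Function, f2: Function) extends Function
}

// Hypothetical helper used only to demonstrate an exhaustive match.
object Render {
  import Function._

  // Function is sealed, so the compiler knows these four case classes are
  // its only subtypes; removing any case below triggers an
  // exhaustiveness warning at compile time.
  def render(f: Function): String = f match {
    case Const(v)    => v.toString
    case Param(p)    => p
    case Var(x)      => x
    case Add(f1, f2) => s"(${render(f1)} + ${render(f2)})"
  }

  def main(args: Array[String]): Unit =
    println(render(Add(Const(3), Var("x")))) // (3.0 + x)
}
```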

The types you have created are path-dependent. Change the type declaration to this:

trait Function

object Function {
  case class Const(value: Double) extends Function
  case class Param(value: String) extends Function
  case class Var(value: String) extends Function
  case class Add(f1: Function, f2: Function) extends Function
}

import Function._

With your original definition, a Const must be created inside the same Function object in order to match. You are matching inside Op, but the Const values you are matching were created inside HelloWorld, so you are matching HelloWorld.Const against Op.Const.
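The path dependence can be reproduced in isolation. In the sketch below, Outer, a and b are hypothetical names used only for illustration: a Const created by b does not match the Const pattern written inside a, which is the same effect as HelloWorld.Const versus Op.Const.

```scala
// Hypothetical class used only to show path-dependent case classes.
class Outer {
  // Non-static inner case class: each Outer instance has its own Const type.
  case class Const(value: Double)

  // Here the pattern Const(_) means Outer.this.Const, so besides the class
  // check the compiler also checks that the value's outer instance is `this`.
  def isOwnConst(x: Any): Boolean = x match {
    case Const(_) => true
    case _        => false
  }
}

object PathDependentDemo {
  val a = new Outer
  val b = new Outer

  def main(args: Array[String]): Unit = {
    println(a.isOwnConst(a.Const(1))) // true: created by a, matched inside a
    println(a.isOwnConst(b.Const(1))) // false: different outer instance
  }
}
```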

Is there any reason why Op and HelloWorld extend the Function trait? (You do not use this in your example, but perhaps you have a reason outside of it.) The following would look more natural to me:

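object HelloWorld {
  // Assumes Const and Add are in scope here, e.g. defined at the top level
  // or imported from a Function companion object.
  def som(f1: Function, f2: Function): Function = {
    (f1, f2) match {
      case (Const(a), Const(b)) => Const(a + b)
      case _ => Add(f1, f2)
    }
  }

  def main(args: Array[String]): Unit = {
    val s = som(Const(3), Const(5))
    println(s)
  }
}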

To define this piece of Haskell in Scala:

data Fonction =
      Const Float
    | Param String
    | Var String
    | Add Fonction Fonction

I'd proceed as follows:

trait Function

case class Const(v: Float) extends Function
case class Param(v: String) extends Function
case class Var(v: String) extends Function
case class Add(f1: Function, f2: Function) extends Function

And now your code will work as expected:

scala> HelloWorld.main(Array())
Const(8.0)
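If you also want to mirror the Haskell @+ operator, one option, sketched here under the assumption that the case classes live at the top level as above, is to define @+ as a method on the trait. The OperatorDemo object is hypothetical, and the trait is marked sealed here although the answer above leaves it open. Unlike the Haskell definition in the question, the match needs a fallback case so that non-constant operands are combined with Add.

```scala
// Sketch: the Haskell (@+) as a method on the trait (top-level case classes).
sealed trait Function {
  // Folds two constants into one, like (@+) (Const k) (Const k') = Const (k + k');
  // any other combination is kept symbolic as an Add node.
  def @+(that: Function): Function = (this, that) match {
    case (Const(a), Const(b)) => Const(a + b)
    case _                    => Add(this, that)
  }
}

case class Const(v: Float) extends Function
case class Param(v: String) extends Function
case class Var(v: String) extends Function
case class Add(f1: Function, f2: Function) extends Function

// Hypothetical demo object.
object OperatorDemo {
  def main(args: Array[String]): Unit = {
    println(Const(3) @+ Const(5)) // Const(8.0)
    println(Const(3) @+ Var("x")) // Add(Const(3.0),Var(x))
  }
}
```

Scala has no fixity declarations: an operator's precedence comes from its first character, so there is no direct counterpart to infix 3. Here @+ is left-associative and binds more tightly than + or *.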

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

Guangdong ICP filing No. 18138465 © 2020-2024 STACKOOM.COM