
Compile-time check for vector dimension

I am implementing some lightweight mathematical vectors in Scala. I would like to use the type system to check vector compatibility at compile time. For example, trying to add a vector of dimension 2 to a vector of dimension 3 should result in a compile error.

So far, I defined dimensions as case classes:

sealed trait Dim
case class One() extends Dim
case class Two() extends Dim
case class Three() extends Dim
case class Four() extends Dim
case class Five() extends Dim

And here is the vector definition:

class Vec[D <: Dim](val values: Vector[Double]) {

  def apply(i: Int) = values(i)

  def *(k: Double) = new Vec[D]( values.map(_*k) )

  def +(that: Vec[D]) = {
    val newValues = ( values zip that.values ) map { 
      pair => pair._1 + pair._2
    }
    new Vec[D](newValues)
  }

  override lazy val toString = "Vec(" + values.mkString(", ") + ")"

}
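To see exactly what the phantom type parameter D checks, here is a small driver around the class above. It assumes Scala 2.13 or Scala 3, trims the Dim hierarchy to three cases, and uses a hypothetical PhantomDemo entry point. Adding two Vec[Two] values compiles and adding a Vec[Two] to a Vec[Three] is rejected by the compiler, but nothing ties the number of values to D:

```scala
sealed trait Dim
case class One() extends Dim
case class Two() extends Dim
case class Three() extends Dim

// D is a phantom type: it is never stored, it only labels the vector.
class Vec[D <: Dim](val values: Vector[Double]) {
  def apply(i: Int): Double = values(i)
  def *(k: Double): Vec[D] = new Vec[D](values.map(_ * k))
  // Only a Vec with the same label D is accepted here.
  def +(that: Vec[D]): Vec[D] =
    new Vec[D](values.zip(that.values).map { case (x, y) => x + y })
  override def toString: String = values.mkString("Vec(", ", ", ")")
}

object PhantomDemo extends App {
  val a = new Vec[Two](Vector(1.0, 2.0))
  val b = new Vec[Two](Vector(3.0, 4.0))
  println(a + b) // Vec(4.0, 6.0)

  val c = new Vec[Three](Vector(1.0, 2.0, 3.0))
  // a + c   // does not compile: type mismatch, Vec[Three] is not a Vec[Two]

  // The label is not checked against the data, so this compiles as well:
  val mislabeled = new Vec[Two](Vector(1.0, 2.0, 3.0))
  println(mislabeled) // Vec(1.0, 2.0, 3.0)
}
```

The last case is the gap behind both concerns: D is only a label, so the dimension can neither be read back from it nor checked against values.size.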

This solution works well; however, I have two concerns:

  • How can I add a dimension(): Int method that returns the dimension (i.e. 3 for a Vec[Three])?

  • How can I handle higher dimensions without declaring all the needed case classes in advance?

PS: I know there are nice existing mathematical vector libraries; I am just trying to improve my understanding of Scala.

Answers:

I'd suggest something like this:

sealed abstract class Dim(val dimension:Int)

object Dim {
  class One extends Dim(1)
  class Two extends Dim(2)
  class Three extends Dim(3)

  implicit object One extends One
  implicit object Two extends Two
  implicit object Three extends Three
}

case class Vec[D <: Dim](values: Vector[Double])(implicit dim:D) {

  require(values.size == dim.dimension)

  def apply(i: Int) = values(i)

  def *(k: Double) = Vec[D]( values.map(_*k) )

  def +(that: Vec[D]) = Vec[D](
    (values zip that.values) map { pair => pair._1 + pair._2 }
  )

  override lazy val toString = values.mkString("Vec(",", ",")")
}

Of course, this way you only get a runtime check on the vector length. As others have already pointed out, checking the length at compile time requires something like Church numerals or other type-level programming techniques.

  import Dim._
  val a = Vec[Two](Vector(1.0,2.0))
  val b = Vec[Two](Vector(1.0,3.0))
  println(a + b)
  //--> Vec(2.0, 5.0) 

  val c = Vec[Three](Vector(1.0,3.0)) 
  //--> Exception in thread "main" java.lang.ExceptionInInitializerError
  //-->        at scalatest.vecTest.main(vecTest.scala)
  //--> Caused by: java.lang.IllegalArgumentException: requirement failed
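For the compile-time route mentioned above, a minimal sketch of a Peano-style encoding could look like the following. It is an illustration under assumptions, not a library API: Nat, Zero, Succ, ToInt, Vec.empty, the :: prepend operator and PeanoDemo are all hypothetical names. The dimension lives in the type, an implicit ToInt instance recovers it as an Int, and because the only way to build a vector is to prepend to Vec.empty, the type and the length cannot drift apart:

```scala
// Type-level natural numbers (Peano encoding): Zero, Succ[Zero], Succ[Succ[Zero]], ...
sealed trait Nat
sealed trait Zero extends Nat
sealed trait Succ[N <: Nat] extends Nat

// Evidence that turns a type-level Nat into a runtime Int.
final case class ToInt[N <: Nat](value: Int)

object ToInt {
  implicit val zero: ToInt[Zero] = ToInt[Zero](0)
  // ToInt[Succ[N]] is derived recursively from ToInt[N].
  implicit def succ[N <: Nat](implicit prev: ToInt[N]): ToInt[Succ[N]] =
    ToInt[Succ[N]](prev.value + 1)
}

// Private constructor: every Vec is built by prepending to Vec.empty,
// so the type-level length N always equals values.size.
final class Vec[N <: Nat] private (val values: Vector[Double]) {
  def apply(i: Int): Double = values(i)

  // Dimension read back from the type, no stored Dim instance needed.
  def dimension(implicit n: ToInt[N]): Int = n.value

  // Right-associative prepend (name ends in ':'), grows the dimension by one.
  def ::(x: Double): Vec[Succ[N]] = new Vec[Succ[N]](x +: values)

  def *(k: Double): Vec[N] = new Vec[N](values.map(_ * k))

  // Only vectors with the same type-level length can be added.
  def +(that: Vec[N]): Vec[N] =
    new Vec[N](values.zip(that.values).map { case (x, y) => x + y })

  override def toString: String = values.mkString("Vec(", ", ", ")")
}

object Vec {
  val empty: Vec[Zero] = new Vec[Zero](Vector.empty)
}

object PeanoDemo extends App {
  val a = 1.0 :: 2.0 :: Vec.empty // Vec[Succ[Succ[Zero]]]
  val b = 3.0 :: 4.0 :: Vec.empty
  println(a + b)       // Vec(4.0, 6.0)
  println(a.dimension) // 2

  val c = 1.0 :: 2.0 :: 3.0 :: Vec.empty
  // a + c  // does not compile: type mismatch between the two Succ chains
}
```

Because Succ can be nested arbitrarily, no dimension has to be declared in advance. The trade-off is that vectors are built one component at a time, and long types such as Succ[Succ[Succ[Zero]]] are usually hidden behind type aliases.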

If you don't wish to go down the Peano route, you could always construct your Vec with a D and then use that instance to determine the dimension via the Dim companion object. For instance:

object Dim {
  def dimensionOf(d : Dim) = d match {
    case One => 1
    case Two => 2
    case Three => 3
  }
}
sealed trait Dim

I think you would be better off using case objects rather than case classes:

case object One extends Dim
case object Two extends Dim
case object Three extends Dim

Then your vector may need to actually store the Dim instance:

// With case objects, the type of One is the singleton type One.type
object Vec {
  def vec1 = new Vec[One.type](One)
  def vec2 = new Vec[Two.type](Two)
  def vec3 = new Vec[Three.type](Three)
}

class Vec[D <: Dim](d : D) {
  def dimension : Int = Dim dimensionOf d
  //etc
}
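Putting the pieces of this answer together, a fleshed-out sketch might look like the following. It assumes the Vec also carries its values (as in the question) and borrows the require check from the first answer; the coordinate-taking vec2/vec3 helpers are hypothetical variants of the ones above. With case objects, the type argument has to be the singleton type, e.g. Two.type:

```scala
sealed trait Dim
case object One extends Dim
case object Two extends Dim
case object Three extends Dim

object Dim {
  // Map each dimension object to its numeric size.
  def dimensionOf(d: Dim): Int = d match {
    case One   => 1
    case Two   => 2
    case Three => 3
  }
}

// The Dim instance is stored so `dimension` can be computed at runtime;
// the type parameter D still rejects mismatched additions at compile time.
class Vec[D <: Dim](val d: D, val values: Vector[Double]) {
  require(values.size == Dim.dimensionOf(d),
    s"expected ${Dim.dimensionOf(d)} values, got ${values.size}")

  def dimension: Int = Dim.dimensionOf(d)
  def apply(i: Int): Double = values(i)
  def *(k: Double): Vec[D] = new Vec[D](d, values.map(_ * k))
  def +(that: Vec[D]): Vec[D] =
    new Vec[D](d, values.zip(that.values).map { case (x, y) => x + y })
  override def toString: String = values.mkString("Vec(", ", ", ")")
}

object Vec {
  // Hypothetical helpers that also take the coordinates.
  def vec2(x: Double, y: Double): Vec[Two.type] =
    new Vec[Two.type](Two, Vector(x, y))
  def vec3(x: Double, y: Double, z: Double): Vec[Three.type] =
    new Vec[Three.type](Three, Vector(x, y, z))
}
```

Mismatched additions such as Vec.vec2(1.0, 2.0) + Vec.vec3(1.0, 2.0, 3.0) are still rejected by the compiler, while a wrong number of values is caught by require when the vector is constructed.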

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 