简体   繁体   English

如何使用两种参数类型在 Scala 中实现泛型函数?

[英]How to implement generic function in Scala with two argument types?

I'd like to implement a function in Scala that computes the dot product of two numeric sequences as follows我想在 Scala 中实现一个函数来计算两个数字序列的点积,如下所示

val x = Seq(1,2,3.0)
val y = Seq(4,5,6)
val z = (for (a <- x; b <- y) yield a*b).sum
scala> z  : Double = 90.0

val x = Seq(1,2,3)
val y = Seq(4,5,6)
val z = (for (a <- x; b <- y) yield a*b).sum
scala> z  : Int = 90

Notice that if the two sequences are of different types, the result is an Double.请注意,如果两个序列的类型不同,则结果为 Double。 If the two sequences are of the same type (eg Int), the result is an Int.如果两个序列的类型相同(例如 Int),则结果是 Int。

I came up with two alternatives but neither meets the requirement as defined above.我想出了两个替代方案,但都不符合上述定义的要求。

Alternative #1:替代方案#1:

def dotProduct[T: Numeric](x: Seq[T], y: Seq[T]): T = (for (a <- x; b <- y) yield implicitly[Numeric[T]].times(a, b)).sum

This returns the result in the same type as the input, but it can't take two different types.这将返回与输入相同类型的结果,但它不能采用两种不同的类型。

Alternative #2:备选方案#2:

def dotProduct[A, B](x: Seq[A], y: Seq[B])(implicit nx: Numeric[A], ny: Numeric[B]) = (for (a <- x; b <- y) yield nx.toDouble(a)*ny.toDouble(b)).sum

This works for all numeric sequences.这适用于所有数字序列。 However, it always return a Double, even if the two sequences are of the type Int.然而,它总是返回一个 Double,即使两个序列是 Int 类型。

Any suggestion is greatly appreciated.任何建议都非常感谢。

ps The function I implemented above is not "dot product", but simply sum of product of two sequences. ps 我上面实现的函数不是“点积”,而是两个序列的乘积之和。 Thanks Daniel for pointing it out.感谢丹尼尔指出这一点。

Alternative #3 (slightly better than alternatives #1 and #2):备选方案#3(比备选方案#1 和#2 略好):

def sumProduct[T, A <% T, B <% T](x: Seq[A], y: Seq[B])(implicit num: Numeric[T]) = (for (a <- x; b <- y) yield num.times(a,b)).sum

sumProduct(Seq(1,2,3), Seq(4,5,6))  //> res0: Int = 90
sumProduct(Seq(1,2,3.0), Seq(4,5,6))  //> res1: Double = 90.0
sumProduct(Seq(1,2,3), Seq(4,5,6.0))  // Fails!!!

Unfortunately, the View Bound feature (eg "<%") will be deprecated in Scala 2.10.不幸的是,视图绑定特性(例如“<%”)将在 Scala 2.10 中被弃用。

You could create a typeclass that represents the promotion rules:您可以创建一个代表促销规则的类型类:

trait NumericPromotion[A, B, C] {
  def promote(a: A, b: B): (C, C)
}

implicit object IntDoublePromotion extends NumericPromotion[Int, Double, Double] {
  def promote(a: Int, b: Double): (Double, Double) = (a.toDouble, b)
}

def dotProduct[A, B, C]
              (x: Seq[A], y: Seq[B])
              (implicit numEv: Numeric[C], promEv: NumericPromotion[A, B, C])
              : C = {
  val foo = for {
    a <- x
    b <- y
  } yield {
    val (pa, pb) = promEv.promote(a, b)
    numEv.times(pa, pb)
  }

  foo.sum
}

dotProduct[Int, Double, Double](Seq(1, 2, 3), Seq(1.0, 2.0, 3.0))

My typeclass-fu isn't good enough to eliminate the explicit type parameters in the call to dotProduct , nor could I figure out how to avoid the val foo inside the method;我的 typeclass-fu 不足以消除对dotProduct调用中的显式类型参数,我也无法弄清楚如何避免方法中的val foo inlining foo led to compiler errors.内联foo导致编译器错误。 I chalk this up to no having really internalized the implicit resolution rules.我将此归结为没有真正内化隐式解析规则。 Maybe somebody else can get you further.也许其他人可以让你走得更远。

It's also worth mentioning that this is directional;还值得一提的是,这是定向的; you couldn't compute dotProduct(Seq(1.0, 2.0, 3.0), Seq(1, 2, 3)) .你无法计算dotProduct(Seq(1.0, 2.0, 3.0), Seq(1, 2, 3)) But that's easy to fix:但这很容易解决:

implicit def flipNumericPromotion[A, B, C]
                                 (implicit promEv: NumericPromotion[B, A, C])
                                 : NumericPromotion[A, B, C] = 
  new NumericPromotion[A, B, C] {
    override def promote(a: A, b: B): (C, C) = promEv.promote(b, a)
  }

It's also worth mentioning that your code doesn't compute a dot product.还值得一提的是,您的代码不计算点积。 The dot product of [1, 2, 3] and [4, 5, 6] is 4 + 10 + 18 = 32 . [1, 2, 3][4, 5, 6]的点积是4 + 10 + 18 = 32

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM