
Cross product of an arbitrary number of lists in Scala

I have a list of lists in Scala as follows.

val inputList:List[List[Int]] = List(List(1, 2), List(3, 4, 5), List(1, 9))

I want the cross product (Cartesian product) of all the sub-lists, as a list of lists:

val desiredOutput: List[List[Int]] = List( 
        List(1, 3, 1), List(1, 3, 9),
        List(1, 4, 1), List(1, 4, 9),
        List(1, 5, 1), List(1, 5, 9),
        List(2, 3, 1), List(2, 3, 9),
        List(2, 4, 1), List(2, 4, 9),
        List(2, 5, 1), List(2, 5, 9))

Neither the number of sub-lists in inputList nor the length of each sub-list is fixed. What is the Scala way of doing this?
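For a sense of scale: the result contains one combination for each way of picking one element from every sub-list, so its length is the product of the sub-list lengths, 2 × 3 × 2 = 12 for this input. A small sketch of that count (ProductSize and expectedSize are hypothetical names), handy as a sanity check for any solution:

```scala
object ProductSize {
  // Hypothetical helper: one combination per choice of one element from each
  // sub-list, so the result size is the product of the sub-list lengths.
  def expectedSize(input: List[List[Int]]): Int = input.map(_.length).product

  def main(args: Array[String]): Unit = {
    val inputList = List(List(1, 2), List(3, 4, 5), List(1, 9))
    println(expectedSize(inputList)) // 12
  }
}
```

The count grows multiplicatively: ten sub-lists of ten elements each already give 10^10 combinations. For an empty inputList the product is 1 (a single empty combination), which is the usual mathematical convention, and any empty sub-list makes it 0.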

Here is a method that works using recursion. However, it is not tail-recursive, so beware of stack overflows: the recursion depth grows with the number of sub-lists. It can be transformed into a tail-recursive function by using an auxiliary function.

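def getProduct(input: List[List[Int]]): List[List[Int]] = input match {
  case Nil => Nil // just in case you input an empty list
  case head :: Nil => head.map(_ :: Nil) // last sub-list: wrap each element in a list
  case head :: tail =>
    val rest = getProduct(tail) // compute the product of the tail once, not once per elem
    for (elem <- head; sub <- rest) yield elem :: sub
}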

Test:

scala> getProduct(inputList)
res32: List[List[Int]] = List(List(1, 3, 1), List(1, 3, 9), List(1, 4, 1), List(1, 4, 9), List(1, 5, 1), List(1, 5, 9), List(2, 3, 1), List(2, 3, 9), List(2, 4, 1), List(2, 4, 9), List(2, 5, 1), List(2, 5, 9))
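One way to make the tail-recursive transformation mentioned above, sketched under the assumption that the same results and ordering as getProduct are wanted (TailRecProduct, productTR and loop are hypothetical names): walk the sub-lists from last to first and carry the partial combinations in an accumulator.

```scala
import scala.annotation.tailrec

object TailRecProduct {
  // Hypothetical tail-recursive variant of getProduct.
  def productTR(input: List[List[Int]]): List[List[Int]] = {
    // Processes sub-lists last-to-first; acc holds the product of those seen so far.
    @tailrec
    def loop(remaining: List[List[Int]], acc: List[List[Int]]): List[List[Int]] =
      remaining match {
        case Nil => acc
        case xs :: rest => loop(rest, for (x <- xs; sub <- acc) yield x :: sub)
      }

    if (input.isEmpty) Nil // same choice as getProduct for an empty input
    else loop(input.reverse, List(Nil))
  }

  def main(args: Array[String]): Unit =
    println(productTR(List(List(1, 2), List(3, 4, 5), List(1, 9))))
}
```

The @tailrec annotation makes the compiler reject loop if the recursive call is not in tail position, so the stack depth stays constant however many sub-lists there are.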

If you use Scalaz, this may be a suitable case for the applicative builder (|@|):

import scalaz._
import Scalaz._

def desiredOutput(input: List[List[Int]]) = 
  input.foldLeft(List(List.empty[Int]))((l, r) => (l |@| r)(_ :+ _))

desiredOutput(List(List(1, 2), List(3, 4, 5), List(1, 9)))

I am not very familiar with Scalaz myself, and I expect it has some more powerful magic for this.
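For readers without Scalaz, here is a plain-Scala sketch of what that foldLeft computes, assuming (as holds for List) that (l |@| r)(f) applies f to every pairing of an element of l with an element of r. The object name FoldLeftProduct is hypothetical:

```scala
object FoldLeftProduct {
  // Scalaz-free sketch of the foldLeft above: each step pairs every partial
  // combination built so far with every element of the next sub-list,
  // appending the element with :+ just as `_ :+ _` does in the Scalaz version.
  def desiredOutput(input: List[List[Int]]): List[List[Int]] =
    input.foldLeft(List(List.empty[Int])) { (acc, next) =>
      for (prefix <- acc; x <- next) yield prefix :+ x
    }

  def main(args: Array[String]): Unit =
    println(desiredOutput(List(List(1, 2), List(3, 4, 5), List(1, 9))))
}
```

Appending with :+ costs time linear in the prefix length, which is harmless for short combinations; prepending with :: while folding from the right avoids it. Unlike getProduct, an empty input yields List(List()) here.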

Edit

As Travis Brown suggests, we can simply write

def desiredOutput(input: List[List[Int]]) = input.sequence

I found the answers to this question very helpful for understanding what sequence does.

If you don't mind a bit of functional programming:

def cross[T](inputs: List[List[T]]): List[List[T]] =
    // Starting from List(Nil), prepend each element of a sub-list to every partial combination.
    inputs.foldRight(List[List[T]](Nil))((el, rest) => el.flatMap(p => rest.map(p :: _)))

Much fun finding out how that works. :-)
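One way to see how cross works is to print the partial result after each step of the fold. This instrumented copy (CrossTrace is a hypothetical name; the println is the only behavioral addition) shows that foldRight handles the last sub-list first, starting from List(Nil):

```scala
object CrossTrace {
  def cross[T](inputs: List[List[T]]): List[List[T]] =
    inputs.foldRight(List[List[T]](Nil)) { (el, rest) =>
      // Prepend every element p of el to every partial combination in rest.
      val step = el.flatMap(p => rest.map(p :: _))
      println(s"$el -> $step") // trace the partial product after this step
      step
    }

  def main(args: Array[String]): Unit = {
    val result = cross(List(List(1, 2), List(3, 4, 5), List(1, 9)))
    println(s"${result.length} combinations")
  }
}
```

For the example input, the first two trace lines are

List(1, 9) -> List(List(1), List(9))
List(3, 4, 5) -> List(List(3, 1), List(3, 9), List(4, 1), List(4, 9), List(5, 1), List(5, 9))

and the third shows the full 12-element result.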

After several attempts, I arrived at this solution.

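val inputList: List[List[Int]] = List(List(1, 2), List(3, 4, 5), List(1, 9))
// Seed for the fold: a single empty combination.
val zss: List[List[Int]] = List(List())
// Prepend each element of xs to every combination in zss.
def fun(xs: List[Int], zss: List[List[Int]]): List[List[Int]] = {
    for {
        x <- xs
        zs <- zss
    } yield {
        x :: zs
    }
}
// Fold from the right so the first sub-list supplies the head of each combination.
val crossProd: List[List[Int]] = inputList.foldRight(zss)(fun _)
println(crossProd)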
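All of the versions above build the whole result list in memory, and its length is the product of the sub-list lengths. If the combinations only need to be consumed one at a time, a lazy sketch based on Iterator avoids materializing the full result (LazyCross and crossIterator are hypothetical names, not taken from the answers):

```scala
object LazyCross {
  // Hypothetical lazy variant: yields one combination at a time instead of
  // building the full List[List[T]]. The iterator for `rest` is re-created
  // for each element x, trading repeated work for not storing partial results.
  def crossIterator[T](inputs: List[List[T]]): Iterator[List[T]] = inputs match {
    case Nil => Iterator(List.empty[T]) // a single empty combination
    case xs :: rest => xs.iterator.flatMap(x => crossIterator(rest).map(x :: _))
  }

  def main(args: Array[String]): Unit = {
    val inputList = List(List(1, 2), List(3, 4, 5), List(1, 9))
    // Print only the first three combinations; the rest are never computed.
    crossIterator(inputList).take(3).foreach(println)
  }
}
```

Calling .toList on the iterator gives the same 12 combinations, in the same order, as the list-based answers.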
