
flatMap() in Scala during recursion

I was working through a Scala-99 problem that reduces a nested list to a flat list. The code is given below:

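def flatten(l: List[Any]): List[Any] = l flatMap {
  case ms: List[_] => flatten(ms) // a nested list: flatten it recursively
  case x           => List(x)     // any other element: wrap it so flatMap can concatenate it
}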
val L = List(List(1, 1), 2, List(3, List(5, 8)))
val flattenedList = flatten(L)

For the input L above, I understood the problem by drawing the following tree:

List(List(1, 1), 2, List(3, List(5, 8)))                            (1)
|                \             \ 
List(1, 1)       List(2)    List(3, List(5, 8))                     (2)
|        \                       |        \ 
List(1)  List(1)              List(3)     List(5, 8)                (3)
                                                 |   \
                                            List(5)  List(8)        (4)

My understanding is that the program adds the leaf nodes to a list maintained internally by Scala, like:

li = List(List(1), List(1), List(2), List(3), List(5), List(8))

and that this list is then passed to the flatten method, which produces the final answer:

List(1, 1, 2, 3, 5, 8)
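The model described above can be written out as a short sketch: collect every non-list leaf as a singleton list, in left-to-right order, then apply the standard library's one-level `flatten` once. The helper name `leafLists` is hypothetical, and this only models the description in the question; it says nothing about how `flatMap` evaluates the recursion internally.

```scala
// Hypothetical helper: walks the nested list and wraps every leaf in its own List,
// in left-to-right order, mirroring the list `li` described above.
def leafLists(l: List[Any]): List[List[Any]] = l.flatMap {
  case ms: List[_] => leafLists(ms)  // descend into a sublist
  case x           => List(List(x))  // a leaf becomes one singleton list
}

val L = List(List(1, 1), 2, List(3, List(5, 8)))
val li = leafLists(L)   // List(List(1), List(1), List(2), List(3), List(5), List(8))
val flat = li.flatten   // one-level flatten of the singleton lists
println(flat)           // List(1, 1, 2, 3, 5, 8)
```

The end result matches, which is why the mental model feels right even though the real evaluation interleaves the flattening with the recursion.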

Is my understanding correct?

EDIT: I'm sorry, I forgot to add this:

I wanted to ask: if my understanding is correct, why does replacing flatMap with map in the definition of flatten above produce this list:

List(List(List(1), List(1)), List(2), List(List(3), List(List(5), List(8))))
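Swapping `flatMap` for `map` is easy to try in the REPL or a scala-cli script. In this sketch `flattenWithMap` is a hypothetical name for the modified definition; it reproduces the nested result quoted above.

```scala
// Same definition as flatten, but with map instead of flatMap.
def flattenWithMap(l: List[Any]): List[Any] = l map {
  case ms: List[_] => flattenWithMap(ms) // a sublist is replaced by its (still nested) result
  case x           => List(x)            // a leaf is wrapped, but nothing ever unwraps it
}

val L = List(List(1, 1), 2, List(3, List(5, 8)))
println(flattenWithMap(L))
// List(List(List(1), List(1)), List(2), List(List(3), List(List(5), List(8))))
```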

I mean, isn't flatMap just map followed by flatten? Shouldn't I be getting the list I mentioned above:

li = List(List(1), List(1), List(2), List(3), List(5), List(8))

You're right that flatMap is just map followed by flatten, but note that this flatten is not the flatten you defined: on a list it only concatenates the inner lists, one level deep. A very useful way to unpack this is the substitution model, just like in maths.
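Both halves of that claim can be checked directly. This sketch uses a made-up helper `selfAndTenfold` to show that `flatMap` agrees with `map` followed by `flatten`, and that the standard library's `flatten` removes only one level of nesting.

```scala
// Hypothetical example function: each element expands to two elements.
def selfAndTenfold(n: Int): List[Int] = List(n, n * 10)

val xs = List(1, 2, 3)
val viaFlatMap    = xs.flatMap(selfAndTenfold)      // List(1, 10, 2, 20, 3, 30)
val viaMapFlatten = xs.map(selfAndTenfold).flatten  // the same list

// The standard library's flatten removes exactly one level of nesting:
val nested: List[List[Any]] = List(List(1), List(List(2)))
val oneLevel = nested.flatten                       // List(1, List(2)): the inner List(2) survives
```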

If I define it like this (calling it f to avoid confusion between your flatten and the flatten in the standard library):

def f(l: List[Any]): List[Any] =  l map {
  case ms:List[_] => f(ms)
  case l => List(l)
}

then (where f(1) and f(2) loosely mean applying the case function to a single element):

f(List(  List(1,    1),    2))
= List(f(List(1,    1)), f(2)) // apply f to element of the outer most list
= List(List(f(1), f(1)), f(2)) // apply f to element of the inner list
= List(List(List(1), List(1)), List(2)) // no more recursion

Notice that map doesn't change the structure of your list; it only applies the function to each element. This explains the result you get when you replace flatMap with map.
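The map derivation can also be checked by evaluating each line. In this sketch the pattern-matching function is pulled out into a named helper `g` (a hypothetical name), so that `f(l) = l map g`; the `f(1)` and `f(2)` written in the derivation correspond to `g(1)` and `g(2)` here, since `f` itself only accepts a `List`.

```scala
// g is the pattern-matching function given to map, pulled out and named (hypothetical name).
def g(x: Any): List[Any] = x match {
  case ms: List[_] => f(ms)
  case y           => List(y)
}
// map applies g to each element and keeps one output element per input element.
def f(l: List[Any]): List[Any] = l map g

val input = List(List(1, 1), 2)
val step1 = List(g(List(1, 1)), g(2))              // apply g to each element of the outer list
val step2 = List(List(g(1), g(1)), g(2))           // apply g to each element of the inner list
val step3 = List(List(List(1), List(1)), List(2))  // no more recursion

// map never changes how many elements a list has, only what they are.
println(f(input).length == input.length)
```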

Now if you use flatMap instead of map, the flatten step simply concatenates the lists produced for each element:

def f(l: List[Any]): List[Any] =  l flatMap {
  case ms:List[_] => f(ms)
  case l => List(l)
}

then

f(List(List(1,1), 2))
= f(List(1,1))         ++ f(2) // apply f to each element and concatenate
= (f(1) ++ f(1))       ++ f(2)
= (List(1) ++ List(1)) ++ List(2)
= List(1, 1)           ++ List(2)
= List(1, 1, 2)
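The `++` expansion can be compared with the real result. In this sketch the pattern-matching function is pulled out into a helper named `g` (a hypothetical name), so that `f(l) = l flatMap g`.

```scala
// g is the pattern-matching function given to flatMap, pulled out and named (hypothetical name).
def g(x: Any): List[Any] = x match {
  case ms: List[_] => f(ms)
  case y           => List(y)
}
// flatMap applies g to each element and concatenates the resulting lists.
def f(l: List[Any]): List[Any] = l flatMap g

val result   = f(List(List(1, 1), 2))
val byConcat = (g(1) ++ g(1)) ++ g(2)  // the expansion from the derivation
println(result)                        // List(1, 1, 2)
```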

Or, written with flatten instead of ++:

f(        List(     List(1,1),    2))
= flatten(List( f(      List(  1,    1)) , f(2))) // map and flatten
= flatten(List( flatten(List(f(1), f(1))), f(2))) // again map and flatten
= flatten(List( flatten(List(List(1), List(1))), List(2)))
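The last line of that derivation can be evaluated with the standard library's `List#flatten` method and compared with the `flatMap`-based `f`; this is a small sketch, with `expanded` as an illustrative name.

```scala
// The flatMap version of f from the answer.
def f(l: List[Any]): List[Any] = l flatMap {
  case ms: List[_] => f(ms)
  case x           => List(x)
}

// flatten(List(flatten(List(List(1), List(1))), List(2))) written with method calls:
val expanded = List(List(List(1), List(1)).flatten, List(2)).flatten
println(expanded) // List(1, 1, 2)
```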

Now you can see that flatten is called multiple times, once at every level where f is applied recursively, which collapses your tree one level at a time into a single flat list.
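To watch the tree collapse one level at a time, a tracing variant can record what each recursive call returns together with its depth. The names `traced` and `calls` are hypothetical, introduced only for this illustration; the logic is the same flatMap recursion as above.

```scala
import scala.collection.mutable.ListBuffer

// Records (depth, input, output) for every recursive call (hypothetical tracing helper).
val calls = ListBuffer.empty[(Int, List[Any], List[Any])]

def traced(l: List[Any], depth: Int = 0): List[Any] = {
  val out = l flatMap {
    case ms: List[_] => traced(ms, depth + 1) // the inner call flattens its own subtree first
    case x           => List(x)
  }
  calls += ((depth, l, out))
  out
}

val L = List(List(1, 1), 2, List(3, List(5, 8)))
val result = traced(L)
// A call is logged only after its sublists are flattened, so the root (depth 0) comes last.
calls.foreach { case (d, input, output) => println(s"${"  " * d}$input -> $output") }
```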

To answer your comment about why List(1, 1) is turned into flatten(List(List(1), List(1))): in this simple case the wrapping looks unnecessary, but consider List(1, List(2)). Here f is applied to both 1 and List(2), and because the next step is the standard library's flatten, the results for both 1 and List(2) must be Lists so that everything has the right shape.
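The List(1, List(2)) case can be sketched as follows, with the pattern-matching function pulled out into a hypothetical helper `g`. Calling `.flatten` directly on a `List[Any]` such as `List(1, List(2))` is rejected by the compiler, because `flatten` needs every element to be viewable as a collection; wrapping the leaf with `List(y)` is what makes the mapped list flattenable.

```scala
// g is the pattern-matching function, pulled out and named (hypothetical name).
def g(x: Any): List[Any] = x match {
  case ms: List[_] => f(ms)
  case y           => List(y) // wrap the leaf so it has the same shape as f's results
}
def f(l: List[Any]): List[Any] = l flatMap g

val mixed: List[Any] = List(1, List(2))
val mapped = mixed.map(g)  // every element is now a List
val flat   = mapped.flatten
println(flat)              // List(1, 2)
```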
