简体   繁体   English

在Scala中的二叉树上的尾递归折叠

[英]Tail recursive fold on a binary tree in Scala

I am trying to find a tail recursive fold function for a binary tree. 我试图找到二叉树的尾递归折叠函数。 Given the following definitions: 鉴于以下定义:

// From the book "Functional Programming in Scala", page 45
sealed trait Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

Implementing a non tail recursive function is quite straightforward: 实现非尾递归函数非常简单:

def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B =
  t match {
    case Leaf(v)      => map(v)
    case Branch(l, r) => 
      red(fold(l)(map)(red), fold(r)(map)(red))
  }

But now I am struggling to find a tail recursive fold function so that the annotation @annotation.tailrec can be used. 但现在我正在努力寻找尾递归折叠函数,以便可以使用注释@annotation.tailrec

During my research I have found several examples where tail recursive functions on a tree can eg compute the sum of all leafs using an own stack which is then basically a List[Tree[Int]] . 在我的研究过程中,我发现了一些例子,其中树上的尾递归函数可以例如使用自己的堆栈计算所有叶子的总和,然后基本上是List[Tree[Int]] But as far as I understand in this case it only works for the additions because it is not important whether you first evaluate the left or the right hand side of the operator. 但据我所知,在这种情况下它只适用于添加,因为无论您是首先评估运算符的左侧还是右侧都不重要。 But for a generalised fold it is quite relevant. 但对于广义折叠来说,它是非常相关的。 To show my intension here are some example trees: 为了表明我的意图,这里有一些示例树:

val leafs = Branch(Leaf(1), Leaf(2))
val left = Branch(Branch(Leaf(1), Leaf(2)), Leaf(3))
val right = Branch(Leaf(1), Branch(Leaf(2), Leaf(3)))
val bal = Branch(Branch(Leaf(1), Leaf(2)), Branch(Leaf(3), Leaf(4)))
val cmb = Branch(right, Branch(bal, Branch(leafs, left)))
val trees = List(leafs, left, right, bal, cmb)

Based on those trees I want to create a deep copy with the given fold method like: 基于这些树,我想用给定的折叠方法创建一个深层副本,如:

val oldNewPairs = 
  trees.map(t => (t, fold(t)(Leaf(_): Tree[Int])(Branch(_, _))))

And then proof that the condition of equality holds for all created copies: 然后证明所有创建的副本的平等条件都适用:

val conditionHolds = oldNewPairs.forall(p => {
  if (p._1 == p._2) true
  else {
    println(s"Original:\n${p._1}\nNew:\n${p._2}")
    false
  }
})
println("Condition holds: " + conditionHolds)

Could someone give me some pointers, please? 有人可以给我一些指示吗?

You can find the code used in this question at ScalaFiddle: https://scalafiddle.io/sf/eSKJyp2/15 您可以在ScalaFiddle中找到此问题中使用的代码: https ://scalafiddle.io/sf/eSKJyp2/15

You could reach a tail recursive solution if you stop using the function call stack and start using a stack managed by your code and an accumulator: 如果停止使用函数调用堆栈并开始使用由代码和累加器管理的堆栈,则可以达到尾递归解决方案:

def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B = {

  case object BranchStub extends Tree[Nothing]

  @tailrec
  def foldImp(toVisit: List[Tree[A]], acc: Vector[B]): Vector[B] =
    if(toVisit.isEmpty) acc
    else {
      toVisit.head match {
        case Leaf(v) =>
          val leafRes = map(v)
          foldImp(
            toVisit.tail,
            acc :+ leafRes
          )
        case Branch(l, r) =>
          foldImp(l :: r :: BranchStub :: toVisit.tail, acc)
        case BranchStub =>
          foldImp(toVisit.tail, acc.dropRight(2) ++   Vector(acc.takeRight(2).reduce(red)))
      }
    }

  foldImp(t::Nil, Vector.empty).head

}

The idea is to accumulate values from left to right, keep track of the parenthood relation by the introduction of a stub node and reduce the result using your red function using the last two elements of the accumulator whenever a stub node is found in the exploration. 我们的想法是从左到右累积值,通过引入存根节点跟踪父母关系,并在探索中找到存根节点时使用累加器的最后两个元素使用red函数减少结果。

This solution could be optimized but it is already a tail recursive function implementation. 此解决方案可以进行优化,但它已经是尾递归函数实现。

EDIT: 编辑:

It can be slightly simplified by changing the accumulator data structure to a list seen as a stack: 通过将累加器数据结构更改为看作堆栈的列表,可以略微简化:

def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B = {

  case object BranchStub extends Tree[Nothing]

  @tailrec
  def foldImp(toVisit: List[Tree[A]], acc: List[B]): List[B] =
    if(toVisit.isEmpty) acc
    else {
      toVisit.head match {
        case Leaf(v) =>
          foldImp(
            toVisit.tail,
            map(v)::acc 
          )
        case Branch(l, r) =>
          foldImp(r :: l :: BranchStub :: toVisit.tail, acc)
        case BranchStub =>
          foldImp(toVisit.tail, acc.take(2).reduce(red) :: acc.drop(2))
      }
    }

  foldImp(t::Nil, Nil).head

}

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM