
Simple tax calculation in Scala

Suppose I am writing a toy tax calculator with two functions:

// calculate the tax amount for a particular income given tax brackets
def tax(income: BigDecimal, brackets: Seq[Bracket]): BigDecimal = ???

// calculate the min. income for a particular tax rate given tax brackets 
def income(taxRate: BigDecimal, brackets: Seq[Bracket]) = ???

I define a tax bracket like this:

case class Bracket(maxIncomeOpt: Option[BigDecimal], rate: BigDecimal)

Bracket(Some(BigDecimal(10)), BigDecimal(10)) means a tax bracket of 10% for income up to 10
Bracket(Some(BigDecimal(20)), BigDecimal(20)) means a tax bracket of 20% for income up to 20
Bracket(None, BigDecimal(30)) means a tax bracket of 30% for any income

Now I am writing the tax function like this:

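def tax(income: BigDecimal, brackets: Seq[Bracket]): BigDecimal = {
  // The accumulator is (tax so far, income not yet taxed).
  val (result, _) = brackets.foldLeft((BigDecimal(0), income)) { case ((result, rest), curr) =>
    // Amount taxed in this bracket: at most maxIncomeOpt, or everything left if the bracket is open.
    val taxable = curr.maxIncomeOpt.fold(rest)(_.min(rest))
    (result + taxable * curr.rate / 100.0, rest - taxable)
  }
  result
}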
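Note that the fold above consumes maxIncomeOpt as the width of each band, not as a cumulative threshold. If the brackets are meant as "up to 10", then "up to 20", a minimal sketch that tracks the previous threshold might look like this. The name taxCumulative is hypothetical, and the brackets are assumed to be sorted by threshold with the open bracket last:

```scala
case class Bracket(maxIncomeOpt: Option[BigDecimal], rate: BigDecimal)

// Hypothetical variant: maxIncomeOpt is read as a cumulative upper threshold.
// Assumes brackets are sorted by threshold and the open-ended one comes last.
def taxCumulative(income: BigDecimal, brackets: Seq[Bracket]): BigDecimal = {
  val zero = BigDecimal(0)
  // The accumulator is (tax so far, lower edge of the current band).
  val (total, _) = brackets.foldLeft((zero, zero)) { case ((acc, lower), b) =>
    // Upper edge of this band, capped at the income itself.
    val upper = b.maxIncomeOpt.fold(income)(_.min(income))
    // Part of the income inside this band; zero once the income is exhausted.
    val taxable = (upper - lower).max(zero)
    (acc + taxable * b.rate / 100, b.maxIncomeOpt.getOrElse(upper))
  }
  total
}

val brackets = Seq(
  Bracket(Some(BigDecimal(10)), BigDecimal(10)),
  Bracket(Some(BigDecimal(20)), BigDecimal(20)),
  Bracket(None, BigDecimal(30))
)
```

With these brackets, an income of 25 is taxed as 10 at 10%, 10 at 20% and 5 at 30%.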

The tax function seems to work, but I think Seq[Bracket] is not the best way to define tax brackets. The tax brackets are a sorted sequence of disjoint "back-to-back" intervals with an open interval at the end. How would you define tax brackets?

I would suggest much the same, a List[(Double, Double)] as the raw form.

Here the tuple is semantically (lower_bound, tax_rate_in_range). The key difference is that the atomic unit of behaviour is the tax schedule, not the individual bracket. Within a schedule defined by this core data, you can add optimisations if the number of brackets becomes large, and you can preserve important invariants for the naive solution, such as keeping the list ordered by lower_bound.
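As a rough sketch of that idea, the schedule can wrap the raw list and enforce the ordering itself. The TaxSchedule name and its tax method are assumptions for illustration, and rates are taken to be fractions (0.1 for 10%):

```scala
// Hypothetical wrapper: the schedule, not the single bracket, owns the behaviour.
final case class TaxSchedule(raw: List[(Double, Double)]) {
  // Invariant: bands are kept sorted by lower bound, whatever order `raw` came in.
  private val bands: List[(Double, Double)] = raw.sortBy(_._1)

  def tax(income: Double): Double = {
    // The upper edge of each band is the next band's lower bound; the last band is open.
    val uppers = bands.drop(1).map(_._1) :+ Double.PositiveInfinity
    bands.zip(uppers).map { case ((lower, rate), upper) =>
      // Slice of the income inside this band, never negative.
      math.max(0.0, math.min(income, upper) - lower) * rate
    }.sum
  }
}

// The bands are deliberately given out of order to exercise the invariant.
val schedule = TaxSchedule(List((20.0, 0.3), (0.0, 0.1), (10.0, 0.2)))
```

Because only lower bounds are stored, the bands are back-to-back by construction and the last one is open-ended without needing an Option.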

I would define tax brackets as a piecewise constant function:

def taxBracket(i: Int): Double = i match {
  case _ if i < 10 => 0.1
  case _ if i < 20 => 0.2
  case _ => 0.3
}

It's easy to read, and it's easy to add custom behaviour for any range of values (say the rate becomes linear somewhere; you can tie the pieces together however you want). Calculating the tax for an amount N is then simply the numerical integration of this function between 0 and N.
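A minimal sketch of that integration, assuming whole-unit incomes so that a sum over unit steps is enough (taxByIntegration is a hypothetical name, and taxBracket is repeated so the snippet stands alone):

```scala
// Marginal rate for the i-th unit of income.
def taxBracket(i: Int): Double =
  if (i < 10) 0.1 else if (i < 20) 0.2 else 0.3

// Riemann sum with unit steps: adds the marginal rate of every unit in [0, n).
def taxByIntegration(n: Int): Double =
  (0 until n).map(taxBracket).sum
```

For n = 25 this adds ten units at 0.1, ten at 0.2 and five at 0.3, so the result is about 4.5, up to floating-point rounding. The cost grows linearly with n, which is the price of treating the schedule as an arbitrary function.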

Consider a solution that uses algebraic data types to define the brackets and Double.PositiveInfinity to simulate the open interval:

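abstract class TaxBracket(val from: Double, val to: Double, val rate: Double) {
  // Tax owed on the part of `income` that falls inside this bracket.
  def tax(income: Double): Double = {
    if (income >= from)
      if (to.isPosInfinity) (income - from) * rate   // open-ended top bracket
      else if (income - to > 0) (to - from) * rate   // income goes past this bracket
      else (income - (from - 1)) * rate              // income ends inside this bracket
    else
      0.0                                            // income never reaches this bracket
  }
}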
case object A extends TaxBracket(0, 12500, 0.0)
case object B extends TaxBracket(12501, 50000, 0.2)
case object C extends TaxBracket(50001, 150000, 0.4)
case object D extends TaxBracket(150001, Double.PositiveInfinity, 0.45)

Now the tax calculation simplifies to:

def tax(income: Double, bands: List[TaxBracket]): Double =
  bands.map(_.tax(income)).sum

For example, using the UK tax bands defined above, we get:

tax(60000, List(A, B, C, D)) // res0: Double = 11499.8

which can be verified here.

To get the minimum income for a given effective tax rate, try:

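def income(etr: Double, bands: List[TaxBracket]): Option[Double] = {
  // Find the first band whose upper edge already reaches the effective rate;
  // the open-ended band always matches.
  bands.map(b => (b.from, b.to)).find { case (from, to) =>
    if (to.isPosInfinity) true
    else (tax(to, bands) / to) >= etr
  }.map { case (lowerBound, upperBound) => lowerBound } // report that band's lower bound
}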

income(0.4, List(A, B, C, D)) // res1: Option[Double] = Some(150001.0)
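The function above reports the lower bound of the band in which the effective rate is reached, not the exact income. If the exact figure is wanted, one option is to bisect on the effective rate. The sketch below is an illustration only: minIncome, cap and ukTax are hypothetical names, ukTax restates the bands with contiguous thresholds (so its figures differ slightly from the code above), and the effective rate is assumed never to decrease as income grows.

```scala
// Hypothetical helper: smallest income whose effective rate tax(x) / x reaches `etr`,
// found by bisection on (0, cap]. Assumes the effective rate never decreases
// with income, which holds for progressive schedules.
def minIncome(etr: Double, tax: Double => Double, cap: Double = 1e9): Option[Double] = {
  def rate(x: Double): Double = tax(x) / x
  if (rate(cap) < etr) None // not reached below the cap, e.g. etr above the top rate
  else {
    var lo = 0.0 // effective rate is below etr here (or undefined at 0)
    var hi = cap // effective rate is at least etr here
    while (hi - lo > 0.01) {
      val mid = (lo + hi) / 2
      if (rate(mid) >= etr) hi = mid else lo = mid
    }
    Some(hi)
  }
}

// Stand-in for the UK bands, written with contiguous thresholds (an assumption).
def ukTax(income: Double): Double =
  math.max(0.0, math.min(income, 50000.0) - 12500.0) * 0.2 +
  math.max(0.0, math.min(income, 150000.0) - 50000.0) * 0.4 +
  math.max(0.0, income - 150000.0) * 0.45
```

With these stand-in bands, minIncome(0.4, ukTax) converges to roughly 400000, well inside the top band, while minIncome(0.5, ukTax) is None because the effective rate never reaches the 45% top rate.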
