Tail Recursion Vs. Refactoring

I have this method

  private def getAddresses(data: List[Int], count: Int, len: Int): Tuple2[List[Address], List[Int]] = {
    if (count == len) {
      (List.empty, List.empty)
    } else {
      val byteAddress = data.takeWhile(_ != 0)
      val newData = data.dropWhile(_ != 0).tail
      val newCount = count + 1
      val newPosition = byteAddress.length + 1
      val destFlag = byteAddress.head
      if (destFlag == SMEAddressFlag) {
        (SMEAddress().fromBytes(byteAddress) :: getAddresses(newData, newCount, len)._1, newPosition :: getAddresses(newData, newCount, len)._2)
      } else {
        (DistributionList().fromBytes(byteAddress) :: getAddresses(newData, newCount, len)._1, newPosition :: getAddresses(newData, newCount, len)._2)
      }
    }
  }

I am tempted to rewrite it thus

  private def getAddresses(data: List[Int], count: Int, len: Int): Tuple2[List[Address], List[Int]] = {
    if (count == len) {
      (List.empty, List.empty)
    } else {
      val byteAddress = data.takeWhile(_ != 0)
      val newData = data.dropWhile(_ != 0).tail
      val newCount = count + 1
      val newPosition = byteAddress.length + 1
      val destFlag = byteAddress.head
      val nextIter = getAddresses(newData, newCount, len)
      if (destFlag == SMEAddressFlag) {
        (SMEAddress().fromBytes(byteAddress) :: nextIter._1, newPosition :: nextIter._2)
      } else {
        (DistributionList().fromBytes(byteAddress) :: nextIter._1, newPosition :: nextIter._2)
      }
    }
  }

My questions are

  1. Is the first one tail recursive?
  2. How does it work? I am calling the method twice in the last line. Does that evaluate to two separate method calls, or what?
  3. Which version is more efficient, or how can I write it more efficiently?

Pardon me if the code smells; I am new to Scala.

Thanks


Neither of these is tail recursive. A call is a tail call only when it is the very last thing done along its execution path, with nothing left to do to its result afterwards. If you could replace the call by a jump to the start of the code, saving no state but relabeling the input variables, then it's tail recursion. (Internally, that's exactly what the compiler does.)
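As a minimal sketch of the difference (the object and method names here are made up for illustration and are not part of the question's code), compare a plain recursive sum with an accumulator-passing one. The @tailrec annotation from the standard library asks the compiler to reject the method if it cannot compile the recursion into a loop:

```scala
import scala.annotation.tailrec

object TailRecSketch {
  // Not tail recursive: the addition happens after the recursive call returns,
  // so every call needs its own stack frame.
  def sumPlain(xs: List[Int]): Int = xs match {
    case Nil          => 0
    case head :: tail => head + sumPlain(tail)
  }

  // Tail recursive: the recursive call is the last thing done, and the running
  // total is passed forward in the accumulator.
  @tailrec
  def sumAcc(xs: List[Int], acc: Int = 0): Int = xs match {
    case Nil          => acc
    case head :: tail => sumAcc(tail, acc + head)
  }
}
```

Adding @tailrec to sumPlain would make compilation fail, which is a quick way to check whether a given method really is tail recursive.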

To convert an ordinary recursive function into a tail recursive function, when doing so is possible, you need to pass forward any stored data, like so:

private def getAddresses(
  data: List[Int], count: Int, len: Int,    // You had this already
  addresses: List[Address] = Nil,           // Build this as we go, starting with nothing
  positions: List[Int] = Nil                // Same here
): (List[Address], List[Int]) = {
  if (count == len) {
    (addresses.reverse, positions.reverse)  // Return what we've built, reverse to fix order
  }
  else {
    val (addr, rest) = data.span(_ != 0)    // Split at the first 0 terminator
    val newdata = rest.tail                 // Drop the terminator itself
    val position = addr.length + 1
    val flag = addr.head
    val address = (
      if (flag == SMEAddressFlag) SMEAddress().fromBytes(addr)
      else DistributionList().fromBytes(addr)
    )
    getAddresses(newdata, count+1, len, address :: addresses, position :: positions)
  }
}

Tail-recursive versions are more efficient than non-tail-recursive versions, if all else is equal. (In this case, it might not be since the list has to be reversed at the end, but it has the huge advantage that it won't overflow the stack if you use a large len.)

Calling a method twice always runs it twice. There is no automatic memoization of results of method calls--this would be extremely difficult to do automatically.
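To see what that costs here, the following sketch counts invocations for the two shapes of recursion in the question. The names (CallCountSketch, twice, once, countCalls) are hypothetical, and plain Ints stand in for the addresses:

```scala
object CallCountSketch {
  var calls = 0

  // Mirrors the first version: the recursive call is written twice per step.
  def twice(n: Int): (List[Int], List[Int]) = {
    calls += 1
    if (n == 0) (Nil, Nil)
    else (n :: twice(n - 1)._1, n :: twice(n - 1)._2)
  }

  // Mirrors the second version: the recursive result is computed once and reused.
  def once(n: Int): (List[Int], List[Int]) = {
    calls += 1
    if (n == 0) (Nil, Nil)
    else {
      val next = once(n - 1)
      (n :: next._1, n :: next._2)
    }
  }

  // Resets the counter, runs f(n), and reports how many calls were made.
  def countCalls(f: Int => Any, n: Int): Int = {
    calls = 0
    f(n)
    calls
  }
}
```

With n = 10, countCalls(CallCountSketch.twice, 10) reports 2047 calls and countCalls(CallCountSketch.once, 10) reports 11, so the first version does exponential work while the second does linear work.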


To make it more idiomatic Scala, you could define the tail-recursive function inside getAddresses, like so:

def getAddresses(data: List[Int], len: Int) = {
  // The remaining input has to be passed along, otherwise every step re-reads the first address
  def inner(remaining: List[Int], count: Int, addresses: List[Address] = Nil, positions: List[Int] = Nil): (List[Address], List[Int]) = {
    if (count == len) {
      (addresses.reverse, positions.reverse)
    } else {
      val (byteAddress, rest) = remaining.span(_ != 0)
      val newData = rest.tail
      val newPosition = byteAddress.length + 1
      val destFlag = byteAddress.head
      val newAddress = (if (destFlag == SMEAddressFlag) SMEAddress() else DistributionList()).fromBytes(byteAddress)
      inner(newData, count+1, newAddress :: addresses, newPosition :: positions)
    }
  }

  inner(data, 0)   // Could have made count have a default too
}


Since all of your inputs are immutable, and your resulting lists will both always have the same length, I thought of this solution instead.

private def getAddresses(data: List[Int], count: Int, len: Int): Stream[(Address, Int)] = {
   if (count == len) {
      Stream.empty
   } else {
      val (byteAddress, _ :: newData) = data.span(_ != 0)
      val newAddress =
         if (byteAddress.head == SMEAddressFlag) SMEAddress().fromBytes(byteAddress)
         else DistributionList().fromBytes(byteAddress)

      (newAddress, byteAddress.length + 1) #:: getAddresses(newData, count+1, len)
   }
}

  1. Instead of returning a pair of lists, it returns a list of pairs. This makes it easy to recurse once. If you need separate lists, you can use map to extract them, but you might be able to restructure other parts of your program to work more cleanly with a list of pairs rather than taking 2 lists as parameters all over.

  2. Instead of returning a list, it returns a stream which is lazily evaluated. This isn't tail recursion, but the way that streams are lazily evaluated also prevents stack overflows. If you need a strict list, you can call toList on the result of this function.

  3. This demonstrates other useful techniques, for example the use of span with pattern matching to compute byteAddress and newData in a single line of code. You can add back some of the vals that I removed if it's useful to have their names for readability.
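
If the caller still wants two separate lists, a rough sketch of getting them back from the stream of pairs could look like this. The names PairsSketch and chunks are hypothetical, the raw chunk of Ints stands in for a parsed Address, and rest.drop(1) is used instead of a pattern match so that a missing terminator does not throw. On Scala 2.13 and later Stream is deprecated in favour of LazyList, which supports the same #:: operator.

```scala
object PairsSketch {
  // Hypothetical stand-in for getAddresses: each "address" is just the raw chunk.
  def chunks(data: List[Int], count: Int, len: Int): Stream[(List[Int], Int)] =
    if (count == len) Stream.empty
    else {
      val (chunk, rest) = data.span(_ != 0)   // split at the 0 terminator
      (chunk, chunk.length + 1) #:: chunks(rest.drop(1), count + 1, len)
    }

  val data = List(1, 7, 7, 0, 2, 9, 0)

  // toList forces the lazy stream into a strict list of pairs.
  val pairs: List[(List[Int], Int)] = chunks(data, 0, 2).toList

  // unzip turns the list of pairs back into a pair of lists.
  val (addresses, positions) = pairs.unzip
}
```

Here addresses is List(List(1, 7, 7), List(2, 9)) and positions is List(4, 3).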


  1. No. A method is tail-recursive if the last call in a given execution path is a recursive call. The last call evaluated here will be ::, which is not a recursive call, so it's not tail recursive.
  2. Yes, if you call a method twice, it will be evaluated twice.
  3. The second one is more efficient as here you're only calling the method once.