
Simplify expression in Scala

I have these case classes:

abstract class Tree

case class Sum(l: Tree, r: Tree) extends Tree

case class Var(n: String) extends Tree

case class Const(v: Int) extends Tree

Now I have written this object:

object Main {

  type Environment = String => Int

  def derive(t: Tree, v: String): Tree = t match {
    case Sum(l, r) => Sum(derive(l, v), derive(r, v))
    case Var(n) if (v == n) => Const(1)
    case _ => Const(0)
  }

  def eval(t: Tree, env: Environment): Int = t match {
    case Sum(l, r) => eval(l, env) + eval(r, env)
    case Var(n) => env(n)
    case Const(v) => v
  }

  def simple(t: Tree): Const = t match {
    case Sum(l, r) if (l.isInstanceOf[Const] && r.isInstanceOf[Const]) => Const(l.asInstanceOf[Const].v + r.asInstanceOf[Const].v)
    case Sum(l, r) if (l.isInstanceOf[Sum] && r.isInstanceOf[Sum]) => Const(simple(l).v+ simple(r).v)
    case Sum(l, r) if (l.isInstanceOf[Sum]) => Const(simple(l).v + r.asInstanceOf[Const].v)
    case Sum(l, r) if (r.isInstanceOf[Sum]) => Const(simple(r).v + l.asInstanceOf[Const].v)
  }

  def main(args: Array[String]): Unit = {
    val exp: Tree = Sum(Sum(Var("x"), Var("x")), Sum(Const(7), Var("y")))
    val env: Environment = {
      case "x" => 5
      case "y" => 7
    }
    println("Expression: " + exp)
    println("Evaluation with x=5, y=7: " + eval(exp, env))
    println("Derivative relative to x:\n " + derive(exp, "x"))
    println("Derivative relative to y:\n " + derive(exp, "y"))
    println("Simplified expression:\n" + simple(derive(exp, "x")))
  }


}

I am new to Scala. Is it possible to write the method simple with less code, and perhaps in a more idiomatic Scala way?

Thanks for any advice.


You're almost there. In Scala, extractors can be nested:

def simple(t: Tree): Const = t match {
  case Sum(Const(v1), Const(v2)) => Const(v1 + v2)
  case Sum(s1 @ Sum(_, _), s2 @ Sum(_, _)) => Const(simple(s1).v + simple(s2).v)
  case Sum(s @ Sum(_, _), Const(v)) => Const(simple(s).v + v)
  case Sum(Const(v), s @ Sum(_, _)) => Const(simple(s).v + v)
}

Of course, these matches are incomplete (the compiler warns about this once Tree is declared sealed), and the repeated s @ Sum(_, _) patterns suggest that there may be a better approach: match on Const and Var at the root level and make more recursive calls to simple.
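A minimal sketch of that suggestion, keeping the question's signature simple(t: Tree): Const. Every node kind is matched at the root, and Sum simply recurses into both children. Assumptions: the object name RootLevelSimple is hypothetical, Tree is declared sealed here (the question's is not) so the compiler can check exhaustiveness, and throwing on Var is a choice made for this sketch, since a variable cannot become a Const without an environment.

```scala
// Declared sealed (unlike the question) so the compiler checks exhaustiveness.
sealed abstract class Tree
case class Sum(l: Tree, r: Tree) extends Tree
case class Var(n: String) extends Tree
case class Const(v: Int) extends Tree

// Hypothetical object name, used only for this sketch.
object RootLevelSimple {
  // Match every node kind at the root; Sum recurses into both children.
  def simple(t: Tree): Const = t match {
    case c @ Const(_) => c
    case Sum(l, r)    => Const(simple(l).v + simple(r).v)
    // Assumption: a variable cannot be reduced to a Const without an environment.
    case Var(n) =>
      throw new IllegalArgumentException(s"cannot reduce variable $n to a constant")
  }

  def main(args: Array[String]): Unit = {
    // The derivative of the question's expression with respect to x.
    val d = Sum(Sum(Const(1), Const(1)), Sum(Const(0), Const(0)))
    println(simple(d)) // Const(2)
  }
}
```

Because Sum always recurses, the four separate Const/Sum combinations collapse into a single case.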


Although this question has been closed, I think this version is better:

def simplify(t: Tree): Tree = t match {
    case Sum(Const(v1), Const(v2)) => Const(v1 + v2)
    case Sum(Const(v1), Sum(Const(v2), rr)) => simplify(Sum(Const(v1 + v2), simplify(rr)))
    case Sum(l, Const(v)) => simplify(Sum(Const(v), simplify(l)))
    case Sum(l, Sum(Const(v), rr)) => simplify(Sum(Const(v), simplify(Sum(l, rr))))
    case Sum(Sum(ll, lr), r) => simplify(Sum(ll, simplify(Sum(lr, r))))
    case Sum(Var(n), r) => Sum(simplify(r), Var(n))
    case _ => t
}

It seems to work with more complex expressions that contain variables.
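To try the version above, a small harness like the following can be used. It repeats the case classes, eval and derive from the question, wraps the posted simplify unchanged in a hypothetical SimplifyCheck object, and checks that simplifying does not change the value of the expression.

```scala
abstract class Tree
case class Sum(l: Tree, r: Tree) extends Tree
case class Var(n: String) extends Tree
case class Const(v: Int) extends Tree

// Hypothetical object name, used only to exercise the posted simplify.
object SimplifyCheck {
  type Environment = String => Int

  def eval(t: Tree, env: Environment): Int = t match {
    case Sum(l, r) => eval(l, env) + eval(r, env)
    case Var(n)    => env(n)
    case Const(v)  => v
  }

  def derive(t: Tree, v: String): Tree = t match {
    case Sum(l, r)            => Sum(derive(l, v), derive(r, v))
    case Var(n) if (v == n)   => Const(1)
    case _                    => Const(0)
  }

  // The simplify function from the answer above, unchanged.
  def simplify(t: Tree): Tree = t match {
    case Sum(Const(v1), Const(v2))          => Const(v1 + v2)
    case Sum(Const(v1), Sum(Const(v2), rr)) => simplify(Sum(Const(v1 + v2), simplify(rr)))
    case Sum(l, Const(v))                   => simplify(Sum(Const(v), simplify(l)))
    case Sum(l, Sum(Const(v), rr))          => simplify(Sum(Const(v), simplify(Sum(l, rr))))
    case Sum(Sum(ll, lr), r)                => simplify(Sum(ll, simplify(Sum(lr, r))))
    case Sum(Var(n), r)                     => Sum(simplify(r), Var(n))
    case _                                  => t
  }

  def main(args: Array[String]): Unit = {
    val exp: Tree = Sum(Sum(Var("x"), Var("x")), Sum(Const(7), Var("y")))
    val env: Environment = {
      case "x" => 5
      case "y" => 7
    }
    println(simplify(derive(exp, "x"))) // Const(2)
    println(simplify(derive(exp, "y"))) // Const(1)
    val s = simplify(exp)
    println(s)
    // Simplification should not change the value of the expression.
    println(eval(s, env) == eval(exp, env)) // true
  }
}
```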


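Just a small improvement to derive: a backquoted identifier such as `v` in a pattern is compared against the existing value v instead of binding a new variable, so the guard can be dropped: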

def derive(t: Tree, v: String): Tree = t match {
    case Sum(l, r) => Sum(derive(l, v), derive(r, v))
    case Var(`v`) => Const(1)
    case _ => Const(0)
}
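A small sketch of why the backquotes matter. Without them, a lowercase name inside a pattern binds a fresh variable that matches anything, so every variable would get derivative 1. The object BackquoteDemo and the deliberately wrong deriveWrong are hypothetical names used only for this comparison.

```scala
abstract class Tree
case class Sum(l: Tree, r: Tree) extends Tree
case class Var(n: String) extends Tree
case class Const(v: Int) extends Tree

object BackquoteDemo {
  // `v` is a stable identifier pattern: it matches only when the name equals v.
  def derive(t: Tree, v: String): Tree = t match {
    case Sum(l, r) => Sum(derive(l, v), derive(r, v))
    case Var(`v`)  => Const(1)
    case _         => Const(0)
  }

  // Deliberately wrong: v inside Var(v) is a new binding that shadows the
  // parameter, so this case matches every Var regardless of its name.
  def deriveWrong(t: Tree, v: String): Tree = t match {
    case Sum(l, r) => Sum(deriveWrong(l, v), deriveWrong(r, v))
    case Var(v)    => Const(1)
    case _         => Const(0)
  }

  def main(args: Array[String]): Unit = {
    val e = Sum(Var("x"), Var("y"))
    println(derive(e, "x"))      // Sum(Const(1),Const(0))
    println(deriveWrong(e, "x")) // Sum(Const(1),Const(1))
  }
}
```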


How about this:

def simplify(t: Tree): Tree = t match {
    case Sum(Const(v1),Const(v2)) => Const(v1+v2)
    case Sum(left,right) => simplify(Sum(simplify(left),simplify(right)))
    case _ => t // Needed: without it, Var and Const (e.g. in simplify(left)) throw a MatchError
}

Note that it returns a Tree, not a Const, so it should be able to simplify trees with variables too.

I'm learning Scala, so any suggestions as to why this might not work are more than welcome :-)


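EDIT: I just discovered that the second case causes an infinite loop when the tree contains variables: a Sum containing a Var never becomes Sum(Const, Const), so simplify keeps rebuilding and re-simplifying the same Sum. Replace it with: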

case Sum(left,right) => Sum(simplify(left),simplify(right))

Unfortunately, this stops too early when left and right both simplify to Const: the result is then something like Sum(Const(2), Const(3)), which could be simplified further.
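One way to get both properties, sketched under the assumption that returning a Tree is acceptable: simplify the children first, then fold them only if both turned into constants. Each call recurses only into strictly smaller subtrees, so it cannot loop. The object name SimplifyFixed is hypothetical.

```scala
abstract class Tree
case class Sum(l: Tree, r: Tree) extends Tree
case class Var(n: String) extends Tree
case class Const(v: Int) extends Tree

// Hypothetical object name, used only for this sketch.
object SimplifyFixed {
  def simplify(t: Tree): Tree = t match {
    case Sum(left, right) =>
      // Simplify the children first, then fold two constants into one.
      (simplify(left), simplify(right)) match {
        case (Const(v1), Const(v2)) => Const(v1 + v2)
        case (l, r)                 => Sum(l, r)
      }
    case _ => t // Var and Const are already as simple as they get
  }

  def main(args: Array[String]): Unit = {
    println(simplify(Sum(Sum(Const(2), Const(3)), Var("x"))))                // Sum(Const(5),Var(x))
    println(simplify(Sum(Sum(Const(1), Const(1)), Sum(Const(0), Const(0))))) // Const(2)
  }
}
```

This does not reorder sums, so Sum(Sum(Var("x"), Const(2)), Const(3)) is left unchanged; the earlier answer's extra cases that move constants to the front handle that situation.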
