
Transforming expressions given in prefix notation, identifying common subexpressions and dependencies

I am given a bunch of expressions in prefix notation in an ANSI text file. I would like to produce another ANSI text file containing the step-by-step evaluation of these expressions. For example:

- + ^ x 2 ^ y 2 1

should be turned into

t1 = x^2
t2 = y^2
t3 = t1 + t2
t4 = t3 - 1
t4 is the result

I also have to identify common subexpressions. For example given

expression_1: z = ^ x 2
expression_2: - + z ^ y 2 1
expression_3: - z y

I have to generate an output saying that x appears in expressions 1, 2 and 3 (through z).

I have to identify dependencies: expression_1 depends only on x, expression_2 depends on x and y, etc.

The original problem is more difficult than the examples above, and I have no control over the input format: it is prefix notation, but in a much more complicated form than the examples above.

I already have a working implementation in C++; however, doing such things in C++ is a lot of pain.

What programming language is best suited for this type of problem?

Could you recommend a tutorial / website / book where I could start?

What keywords should I look for?

UPDATE: Based on the answers, the above examples are somewhat unfortunate: I have unary, binary and n-ary operators in the input. (If you are wondering, exp is a unary operator and sum over a range is an n-ary operator.)


To give you an idea of what this would look like in Python, here is some example code:

operators = "+-*/^"

def parse(it, count=1):
    # Consume one expression from the token iterator and return
    # (name of its value, index of the next free temporary).
    token = next(it)
    if token in operators:
        op1, count = parse(it, count)
        op2, count = parse(it, count)
        tmp = "t%s" % count
        print(tmp, "=", op1, token, op2)
        return tmp, count + 1
    return token, count

s = "- + ^ x 2 ^ y 2 1"
a = s.split()
res, dummy = parse(iter(a))
print(res, "is the result")

The output is the same as your example output.

This example aside, I think any of the high-level languages you listed are almost equally suited for the task.
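Since the update to the question mentions unary and n-ary operators, the same recursive idea could be extended with an arity table. This is only a sketch under assumptions: the ARITY table, the function name emit_steps, and the "sum <count> arg1 ... argN" convention (borrowed from another answer in this thread) are all hypothetical, and the real input format may well differ.

```python
# Hypothetical arity table; adjust to the real input format.
ARITY = {"+": 2, "-": 2, "*": 2, "/": 2, "^": 2, "exp": 1}
NARY = {"sum"}  # assumed form: "sum <count> arg1 ... argN"

def emit_steps(tokens):
    """Return (result_name, steps) for one prefix expression."""
    it = iter(tokens)
    steps = []

    def parse():
        token = next(it)
        if token in NARY:
            n = int(next(it))  # the argument count follows the operator
            args = [parse() for _ in range(n)]
            rhs = "%s(%s)" % (token, ", ".join(args))
        elif token in ARITY:
            args = [parse() for _ in range(ARITY[token])]
            if len(args) == 1:
                rhs = "%s(%s)" % (token, args[0])
            else:
                rhs = "%s %s %s" % (args[0], token, args[1])
        else:
            return token  # variable or literal
        name = "t%d" % (len(steps) + 1)
        steps.append("%s = %s" % (name, rhs))
        return name

    return parse(), steps

result, steps = emit_steps("- + exp x sum 3 y 2 z 1".split())
for line in steps:
    print(line)
print(result, "is the result")
```

Note that this sketch does not check for leftover tokens or for input that ends too early (next raises StopIteration in that case), so real code would need some error handling around it.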


The SymPy Python package does symbolic algebra, including common subexpression elimination and generating evaluation steps for a set of expressions.

See: http://docs.sympy.org/dev/modules/rewriting.html (look at the cse function at the bottom of the page).


The Python example is elegantly short, but I suspect that you won't actually get enough control over your expressions that way. You're much better off actually building an expression tree, even though it takes more work, and then querying the tree. Here's an example in Scala (suitable for cutting and pasting into the REPL):

object OpParser {
  private def estr(oe: Option[Expr]) = oe.map(_.toString).getOrElse("_")
  case class Expr(text: String, left: Option[Expr] = None, right: Option[Expr] = None) {
    import Expr._
    def varsUsed: Set[String] = text match {
      case Variable(v) => Set(v)
      case Op(o) =>
        left.map(_.varsUsed).getOrElse(Set()) ++ right.map(_.varsUsed).getOrElse(Set())
      case _ => Set()
    }
    def printTemp(n: Int = 0, depth: Int = 0): (String,Int) = text match {
      case Op(o) => 
        val (tl,nl) = left.map(_.printTemp(n,depth+1)).getOrElse(("_",n))
        val (tr,nr) = right.map(_.printTemp(nl,depth+1)).getOrElse(("_",nl))
        val t = "t"+(nr+1)
        println(t + " = " + tl + " " + text + " " + tr)
        if (depth==0) println(t + " is the result")
        (t, nr+1)
      case _ => (text, n)
    }
    override def toString: String = {
      if (left.isDefined || right.isDefined) {
        "(" + estr(left) + " " + text + " " + estr(right) + ")"
      }
      else text
    }
  }
  object Expr {
    val Digit = "([0-9]+)"r
    val Variable = "([a-z])"r
    val Op = """([+\-*/^])"""r
    def parse(s: String) = {
      val bits = s.split(" ")
      val parsed = (
        if (bits.length > 2 && Variable.unapplySeq(bits(0)).isDefined && bits(1)=="=") {
          parseParts(bits,2)
        }
        else parseParts(bits)
      )
      parsed.flatMap(p => if (p._2<bits.length) None else Some(p._1))
    }
    def parseParts(as: Array[String], from: Int = 0): Option[(Expr,Int)] = {
      if (from >= as.length) None
      else as(from) match {
        case Digit(n) => Some(Expr(n), from+1)
        case Variable(v) => Some(Expr(v), from+1)
        case Op(o) =>
          parseParts(as, from+1).flatMap(lhs =>
            parseParts(as, lhs._2).map(rhs => (Expr(o,Some(lhs._1),Some(rhs._1)), rhs._2))
          )
        case _ => None
      }
    }
  }
}

This may be a little much to digest all at once, but then again, this does rather a lot.

Firstly, it's completely bulletproof (note the heavy use of Option where a result might fail). If you throw garbage at it, it will just return None. (With a bit more work, you could make it complain about the problem in an informative way--basically the case Op(o) which then does parseParts nested twice could instead store the results and print out an informative error message if the op didn't get two arguments. Likewise, parse could complain about trailing values instead of just throwing back None.)

Secondly, when you're done with it, you have a complete expression tree. Note that printTemp prints out the temporary variables you wanted, and varsUsed lists the variables used in a particular expression, which you can use to expand to a full list once you parse multiple lines. (You might need to fiddle with the regexp a little if your variables can be more than just a to z.) Note also that the expression tree prints itself out in normal infix notation. Let's look at some examples:

scala> OpParser.Expr.parse("4")
res0: Option[OpParser.Expr] = Some(4)

scala> OpParser.Expr.parse("+ + + + + 1 2 3 4 5 6")
res1: Option[OpParser.Expr] = Some((((((1 + 2) + 3) + 4) + 5) + 6))

scala> OpParser.Expr.parse("- + ^ x 2 ^ y 2 1")
res2: Option[OpParser.Expr] = Some((((x ^ 2) + (y ^ 2)) - 1))

scala> OpParser.Expr.parse("+ + 4 4 4 4") // Too many 4s!
res3: Option[OpParser.Expr] = None

scala> OpParser.Expr.parse("Q#$S!M$#!*)000") // Garbage!
res4: Option[OpParser.Expr] = None

scala> OpParser.Expr.parse("z =") // Assigned nothing?! 
res5: Option[OpParser.Expr] = None

scala> res2.foreach(_.printTemp())
t1 = x ^ 2
t2 = y ^ 2
t3 = t1 + t2
t4 = t3 - 1
t4 is the result

scala> res2.map(_.varsUsed)       
res10: Option[Set[String]] = Some(Set(x, y))

Now, you could do this in Python also without too much additional work, and in a number of the other languages besides. I prefer to use Scala, but you may prefer otherwise. Regardless, I do recommend creating the full expression tree if you want to retain maximum flexibility for handling tricky cases.
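For comparison, a rough Python counterpart of the tree-building approach might look like the following. It is a minimal sketch, not a port of the Scala code: the Node type and the helper names parse_all, vars_used and infix are made up for illustration, and it only handles the binary operators from the example.

```python
from collections import namedtuple

# Hypothetical node type: leaves have an empty children tuple.
Node = namedtuple("Node", ["text", "children"])

OPERATORS = set("+-*/^")  # binary operators only, as in the example

def parse(tokens, pos=0):
    """Parse one prefix expression starting at pos; return (node, next_pos)."""
    if pos >= len(tokens):
        raise ValueError("unexpected end of input")
    token = tokens[pos]
    if token in OPERATORS:
        left, pos = parse(tokens, pos + 1)
        right, pos = parse(tokens, pos)
        return Node(token, (left, right)), pos
    return Node(token, ()), pos + 1

def parse_all(text):
    """Parse a whole line and reject trailing tokens."""
    tokens = text.split()
    node, pos = parse(tokens)
    if pos != len(tokens):
        raise ValueError("trailing tokens: %r" % tokens[pos:])
    return node

def vars_used(node):
    """Collect leaf names that are not integer literals."""
    if not node.children:
        return set() if node.text.isdigit() else {node.text}
    return set().union(*(vars_used(c) for c in node.children))

def infix(node):
    """Render the tree in fully parenthesized infix notation."""
    if not node.children:
        return node.text
    left, right = node.children
    return "(%s %s %s)" % (infix(left), node.text, infix(right))

tree = parse_all("- + ^ x 2 ^ y 2 1")
print(infix(tree))              # (((x ^ 2) + (y ^ 2)) - 1)
print(sorted(vars_used(tree)))  # ['x', 'y']
```

Unlike the Option-based Scala version, this sketch reports bad input by raising ValueError, which is just one possible design choice.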


Prefix notation is really simple to do with plain recursive parsers. For instance:

object Parser {
    val Subexprs = collection.mutable.Map[String, String]()
    val Dependencies = collection.mutable.Map[String, Set[String]]().withDefaultValue(Set.empty)
    val TwoArgsOp = "([-+*/^])".r     // - at the beginning, ^ at the end
    val Ident = "(\\p{Alpha}\\w*)".r
    val Literal = "(\\d+)".r

    var counter = 1
    def getIdent = {
        val ident = "t" + counter
        counter += 1
        ident
    }

    def makeOp(op: String) = {
        val op1 = expr
        val op2 = expr
        val ident = getIdent 
        val subexpr = op1 + " " + op + " " + op2
        Subexprs(ident)  = subexpr
        Dependencies(ident) = Dependencies(op1) ++ Dependencies(op2) + op1 + op2
        ident
    }

    def expr: String = nextToken match {
        case TwoArgsOp(op) => makeOp(op)
        case Ident(id)     => id
        case Literal(lit)  => lit
        case x             => error("Unknown token "+x)
    }

    def nextToken = tokens.next
    var tokens: Iterator[String] = _

    def parse(input: String) = {
        tokens = input.trim split "\\s+" toIterator;
        counter = 1
        expr
        if (tokens.hasNext)
            error("Input not fully parsed: "+tokens.mkString(" "))

        (Subexprs, Dependencies)
    }
}

This will generate output like this:

scala> val (subexpressions, dependencies) = Parser.parse("- + ^ x 2 ^ y 2 1")
subexpressions: scala.collection.mutable.Map[String,String] = Map(t3 -> t1 + t2, t4 -> t3 - 1, t1 -> x ^ 2, t2 -> y ^ 2)
dependencies: scala.collection.mutable.Map[String,Set[String]] = Map(t3 -> Set(x, y, t2, 2, t1), t4 -> Set(x, y, t3, t2, 1, 2, t1), t1 -> Set(x, 2), t2 -> Set(y, 2))

scala> subexpressions.toSeq.sorted foreach println
(t1,x ^ 2)
(t2,y ^ 2)
(t3,t1 + t2)
(t4,t3 - 1)

scala> dependencies.toSeq.sortBy(_._1) foreach println
(t1,Set(x, 2))
(t2,Set(y, 2))
(t3,Set(x, y, t2, 2, t1))
(t4,Set(x, y, t3, t2, 1, 2, t1))

This can be easily expanded. For instance, to handle multiple expression statements you can use this:

object Parser {
    val Subexprs = collection.mutable.Map[String, String]()
    val Dependencies = collection.mutable.Map[String, Set[String]]().withDefaultValue(Set.empty)
    val TwoArgsOp = "([-+*/^])".r     // - at the beginning, ^ at the end
    val Ident = "(\\p{Alpha}\\w*)".r
    val Literal = "(\\d+)".r

    var counter = 1
    def getIdent = {
        val ident = "t" + counter
        counter += 1
        ident
    }

    def makeOp(op: String) = {
        val op1 = expr
        val op2 = expr
        val ident = getIdent 
        val subexpr = op1 + " " + op + " " + op2
        Subexprs(ident)  = subexpr
        Dependencies(ident) = Dependencies(op1) ++ Dependencies(op2) + op1 + op2
        ident
    }

    def expr: String = nextToken match {
        case TwoArgsOp(op) => makeOp(op)
        case Ident(id)     => id
        case Literal(lit)  => lit
        case x             => error("Unknown token "+x)
    }

    def assignment: Unit = {
        val ident = nextToken
        nextToken match {
            case "=" => 
                val tmpIdent = expr
                Dependencies(ident) = Dependencies(tmpIdent)
                Subexprs(ident) = Subexprs(tmpIdent)
                Dependencies.remove(tmpIdent)
                Subexprs.remove(tmpIdent)
            case x   => error("Expected assignment, got "+x)
        }
     }

    def stmts: Unit = while(tokens.hasNext) tokens.head match {
        case TwoArgsOp(_) => expr
        case Ident(_)     => assignment
        case x            => error("Unknown statement starting with "+x)
    }

    def nextToken = tokens.next
    var tokens: BufferedIterator[String] = _

    def parse(input: String) = {
        tokens = (input.trim split "\\s+" toIterator).buffered
        counter = 1
        stmts
        if (tokens.hasNext)
            error("Input not fully parsed: "+tokens.mkString(" "))

        (Subexprs, Dependencies)
    }
}

Yielding:

scala> val (subexpressions, dependencies) = Parser.parse("""
     | z = ^ x 2
     | - + z ^ y 2 1
     | - z y
     | """)
subexpressions: scala.collection.mutable.Map[String,String] = Map(t3 -> z + t2, t5 -> z - y, t4 -> t3 - 1, z -> x ^ 2, t2 -> y ^ 2)
dependencies: scala.collection.mutable.Map[String,Set[String]] = Map(t3 -> Set(x, y, t2, 2, z), t5 -> Set(x, 2, z, y), t4 -> Set(x, y, t3, t2, 1, 2, z), z -> Set(x, 2), t2 -> Set(y, 2))

scala> subexpressions.toSeq.sorted foreach println
(t2,y ^ 2)
(t3,z + t2)
(t4,t3 - 1)
(t5,z - y)
(z,x ^ 2)

scala> dependencies.toSeq.sortBy(_._1) foreach println
(t2,Set(y, 2))
(t3,Set(x, y, t2, 2, z))
(t4,Set(x, y, t3, t2, 1, 2, z))
(t5,Set(x, 2, z, y))
(z,Set(x, 2))


OK, since recursive parsers are not your thing, here's an alternative with parser combinators:

object PrefixParser extends JavaTokenParsers {
    import scala.collection.mutable

    // Maps generated through parsing

    val Subexprs = mutable.Map[String, String]()
    val Dependencies = mutable.Map[String, Set[String]]().withDefaultValue(Set.empty)

    // Initialize, read, parse & evaluate string

    def read(input: String) = {
        counter = 1
        Subexprs.clear
        Dependencies.clear
        parseAll(stmts, input)
    }

    // Grammar

    def stmts               = stmt+
    def stmt                = assignment | expr
    def assignment          = (ident <~ "=") ~ expr ^^ assignOp
    def expr: P             = subexpr | identifier | number
    def subexpr: P          = twoArgs | nArgs
    def twoArgs: P          = operator ~ expr ~ expr ^^ twoArgsOp
    def nArgs: P            = "sum" ~ ("\\d+".r >> args) ^^ nArgsOp
    def args(n: String): Ps = repN(n.toInt, expr)
    def operator            = "[-+*/^]".r
    def identifier          = ident ^^ (id => Result(id, Set(id)))
    def number              = wholeNumber ^^ (Result(_, Set.empty))

    // Evaluation helper class and types

    case class Result(ident: String, dependencies: Set[String])
    type P = Parser[Result]
    type Ps = Parser[List[Result]]

    // Evaluation methods

    def assignOp: (String ~ Result) => Result = {
        case ident ~ result => 
            val value = assign(ident, 
                               Subexprs(result.ident),
                               result.dependencies - result.ident)
            Subexprs.remove(result.ident)
            Dependencies.remove(result.ident)
            value
    }

    def assign(ident: String, 
               value: String, 
               dependencies: Set[String]): Result = {
        Subexprs(ident) = value
        Dependencies(ident) = dependencies
        Result(ident, dependencies)
    }

    def twoArgsOp: (String ~ Result ~ Result) => Result = { 
        case op ~ op1 ~ op2 => makeOp(op, op1, op2) 
    }

    def makeOp(op: String, 
               op1: Result, 
               op2: Result): Result = {
        val ident = getIdent
        assign(ident, 
               "%s %s %s" format (op1.ident, op, op2.ident),
               op1.dependencies ++ op2.dependencies + ident)
    } 

    def nArgsOp: (String ~ List[Result]) => Result = { 
        case op ~ ops => makeNOp(op, ops) 
    }

    def makeNOp(op: String, ops: List[Result]): Result = {
        val ident = getIdent
        assign(ident, 
               "%s(%s)" format (op, ops map (_.ident) mkString ", "),
               ops.foldLeft(Set(ident))(_ ++ _.dependencies))
    } 

    var counter = 1
    def getIdent = {
        val ident = "t" + counter
        counter += 1
        ident
    }

    // Debugging helper methods

    def printAssignments = Subexprs.toSeq.sorted foreach println
    def printDependencies = Dependencies.toSeq.sortBy(_._1) map {
        case (id, dependencies) => (id, dependencies - id)
    } foreach println

}

These are the kind of results you get:

scala> PrefixParser.read("""
     | z = ^ x 2
     | - + z ^ y 2 1
     | - z y
     | """)
res77: PrefixParser.ParseResult[List[PrefixParser.Result]] = [5.1] parsed: List(Result(z,Set(x)), Result(t4,Set(t4, y, t3, t2, z)), Result(t5,Set(z, y, t5)))

scala> PrefixParser.printAssignments
(t2,y ^ 2)
(t3,z + t2)
(t4,t3 - 1)
(t5,z - y)
(z,x ^ 2)

scala> PrefixParser.printDependencies
(t2,Set(y))
(t3,Set(z, y, t2))
(t4,Set(y, t3, t2, z))
(t5,Set(z, y))
(z,Set(x))

An n-ary operator:

scala> PrefixParser.read("""
     | x = sum 3 + 1 2 * 3 4 5
     | * x x
     | """)
res93: PrefixParser.ParseResult[List[PrefixParser.Result]] = [4.1] parsed: List(Result(x,Set(t1, t2)), Result(t4,Set(x, t4)))

scala> PrefixParser.printAssignments
(t1,1 + 2)
(t2,3 * 4)
(t4,x * x)
(x,sum(t1, t2, 5))

scala> PrefixParser.printDependencies
(t1,Set())
(t2,Set())
(t4,Set(x))
(x,Set(t1, t2))


It turns out that this sort of parsing is of interest to me also, so I've done a bit more work on it.

There seems to be a sentiment that things like simplification of expressions are hard. I'm not so sure. Let's take a look at a fairly complete solution. (The printing out of tn expressions is not useful for me, and you've got several Scala examples already, so I'll skip that.)

First, we need to extract the various parts of the language. I'll pick regular expressions, though parser combinators could be used also:

object OpParser {
  val Natural = "([0-9]+)"r
  val Number = """((?:-)?[0-9]+(?:\.[0-9]+)?(?:[eE](?:-)?[0-9]+)?)"""r
  val Variable = "([a-z])"r
  val Unary = "(exp|sin|cos|tan|sqrt)"r
  val Binary = "([-+*/^])"r
  val Nary = "(sum|prod|list)"r

Pretty straightforward. We define the various things that might appear. (I've decided that user-defined variables can only be a single lowercase letter, and that numbers can be floating-point since you have the exp function.) The r at the end means this is a regular expression, and it will give us the stuff in parentheses.

Now we need to represent our tree. There are a number of ways to do this, but I'll choose an abstract base class with specific expressions as case classes, since this makes pattern matching easy. Furthermore, we might want nice printing, so we'll override toString. Mostly, though, we'll use recursive functions to do the heavy lifting.

  abstract class Expr {
    def text: String
    def args: List[Expr]
    override def toString = args match {
      case l :: r :: Nil => "(" + l + " " + text + " " + r + ")"
      case Nil => text
      case _ => args.mkString(text+"(", ",", ")")
    }
  }
  case class Num(text: String, args: List[Expr]) extends Expr {
    val quantity = text.toDouble
  }
  case class Var(text: String, args: List[Expr]) extends Expr {
    override def toString = args match {
      case arg :: Nil => "(" + text + " <- " + arg + ")"
      case _ => text
    }
  }
  case class Una(text: String, args: List[Expr]) extends Expr
  case class Bin(text: String, args: List[Expr]) extends Expr
  case class Nar(text: String, args: List[Expr]) extends Expr {
    override def toString = text match {
      case "list" =>
        (for ((a,i) <- args.zipWithIndex) yield {
           "%3d: %s".format(i+1,a.toString)
        }).mkString("List[\n","\n","\n]")
      case _ => super.toString
    }
  }

Mostly this is pretty dull--each case class overrides the base class, and the text and args automatically fill in for the def. Note that I've decided that a list is a possible n-ary function, and that it will be printed out with line numbers. (The reason is that if you have multiple lines of input, it's sometimes more convenient to work with them all together as one expression; this lets them be one function.)

Once our data structures are defined, we need to parse the expressions. It's convenient to represent the stuff to parse as a list of tokens; as we parse, we'll return both an expression and the remaining tokens that we haven't parsed--this is a particularly useful structure for recursive parsing. Of course, we might fail to parse anything, so it had better be wrapped in an Option also.

  def parse(tokens: List[String]): Option[(Expr,List[String])] = tokens match {
    case Variable(x) :: "=" :: rest =>
      for ((expr,remains) <- parse(rest)) yield (Var(x,List(expr)), remains)
    case Variable(x) :: rest => Some(Var(x,Nil), rest)
    case Number(n) :: rest => Some(Num(n,Nil), rest)
    case Unary(u) :: rest =>
      for ((expr,remains) <- parse(rest)) yield (Una(u,List(expr)), remains)
    case Binary(b) :: rest =>
      for ((lexp,lrem) <- parse(rest); (rexp,rrem) <- parse(lrem)) yield
        (Bin(b,List(lexp,rexp)), rrem)
    case Nary(a) :: Natural(b) :: rest =>
      val buffer = new collection.mutable.ArrayBuffer[Expr]
      def parseN(tok: List[String], n: Int = b.toInt): List[String] = {
        if (n <= 0) tok
        else {
          for ((expr,remains) <- parse(tok)) yield { buffer += expr; parseN(remains, n-1) }
        }.getOrElse(tok)
      }
      val remains = parseN(rest)
      if (buffer.length == b.toInt) Some( Nar(a,buffer.toList), remains )
      else None
    case _ => None
  }

Note that we use pattern matching and recursion to do most of the heavy lifting--we pick off part of the list, figure out how many arguments we need, and pass those along recursively. The N-ary operation is a little less friendly, but we create a little recursive function that will parse N things at a time for us, storing the results in a buffer.

Of course, this is a little unfriendly to use, so we add some wrapper functions that let us interface with it nicely:

  def parse(s: String): Option[Expr] = parse(s.split(" ").toList).flatMap(x => {
    if (x._2.isEmpty) Some(x._1) else None
  })
  def parseLines(ls: List[String]): Option[Expr] = {
    val attempt = ls.map(parse).flatten
    if (attempt.length<ls.length) None
    else if (attempt.length==1) attempt.headOption
    else Some(Nar("list",attempt))
  }

Okay, now, what about simplification? One thing we might want to do is numeric simplification, where we precompute the expressions and replace the original expression with the reduced version thereof. That sounds like some sort of a recursive operation--find numbers, and combine them. First we get some helper functions to do calculations on numbers:

  def calc(n: Num, f: Double => Double): Num = Num(f(n.quantity).toString, Nil)
  def calc(n: Num, m: Num, f: (Double,Double) => Double): Num =
    Num(f(n.quantity,m.quantity).toString, Nil)
  def calc(ln: List[Num], f: (Double,Double) => Double): Num =
    Num(ln.map(_.quantity).reduceLeft(f).toString, Nil)

and then we do the simplification:

  def numericSimplify(expr: Expr): Expr = expr match {
    case Una(t,List(e)) => numericSimplify(e) match {
      case n @ Num(_,_) => t match {
        case "exp" => calc(n, math.exp _)
        case "sin" => calc(n, math.sin _)
        case "cos" => calc(n, math.cos _)
        case "tan" => calc(n, math.tan _)
        case "sqrt" => calc(n, math.sqrt _)
      }
      case a => Una(t,List(a))
    }
    case Bin(t,List(l,r)) => (numericSimplify(l), numericSimplify(r)) match {
      case (n @ Num(_,_), m @ Num(_,_)) => t match {
        case "+" => calc(n, m, _ + _)
        case "-" => calc(n, m, _ - _)
        case "*" => calc(n, m, _ * _)
        case "/" => calc(n, m, _ / _)
        case "^" => calc(n, m, math.pow)
      }
      case (a,b) => Bin(t,List(a,b))
    }
    case Nar("list",list) => Nar("list",list.map(numericSimplify))
    case Nar(t,list) =>
      val simple = list.map(numericSimplify)
      val nums = simple.collect { case n @ Num(_,_) => n }
      if (simple.length == 0) t match {
        case "sum" => Num("0",Nil)
        case "prod" => Num("1",Nil)
      }
      else if (nums.length == simple.length) t match {
        case "sum" => calc(nums, _ + _)
        case "prod" => calc(nums, _ * _)
      }
      else Nar(t, simple)
    case Var(t,List(e)) => Var(t, List(numericSimplify(e)))
    case _ => expr
  }

Notice again the heavy use of pattern matching to find when we're in a good case, and to dispatch the appropriate calculation.
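The same constant-folding idea can be sketched in Python, for readers who want to compare. The representation here is an assumption made for brevity and is not the Scala Expr hierarchy: a number is a float, a variable is a string, and an operation is a tuple (op, arg1, arg2, ...). The name fold and the operator tables are hypothetical as well.

```python
import math
import operator

# Assumed representation: float = number, str = variable,
# tuple (op, arg1, arg2, ...) = operation.
BINARY = {"+": operator.add, "-": operator.sub, "*": operator.mul,
          "/": operator.truediv, "^": math.pow}
UNARY = {"exp": math.exp, "sqrt": math.sqrt}

def fold(expr):
    """Replace every all-numeric subtree by its computed value."""
    if not isinstance(expr, tuple):
        return expr  # a number or a variable: nothing to do
    op = expr[0]
    args = [fold(a) for a in expr[1:]]  # simplify the children first
    if all(isinstance(a, float) for a in args):
        if op in BINARY and len(args) == 2:
            return BINARY[op](*args)
        if op in UNARY and len(args) == 1:
            return UNARY[op](*args)
    return (op,) + tuple(args)

# (x + y) * (4 + 5) keeps the symbolic part and folds the numeric part
print(fold(("*", ("+", "x", "y"), ("+", 4.0, 5.0))))
# 2 * 3 + 3 ^ 2 folds completely
print(fold(("+", ("*", 2.0, 3.0), ("^", 3.0, 2.0))))
```

As in the Scala version, children are simplified before the parent is inspected, so a fully numeric expression collapses from the leaves upward.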

Now, surely algebraic substitution is much more difficult! Actually, all you need to do is notice that an expression has already been used, and assign a variable. Since the syntax I've defined above allows in-place variable substitution, we can actually just modify our expression tree to include more variable assignments. So we do (edited to only insert variables if the user hasn't):

  def algebraicSimplify(expr: Expr): Expr = {
    val all, dup, used = new collection.mutable.HashSet[Expr]
    val made = new collection.mutable.HashMap[Expr,Int]
    val user = new collection.mutable.HashMap[Expr,Expr]
    def findExpr(e: Expr) {
      e match {
        case Var(t,List(v)) =>
          user += v -> e
          if (all contains e) dup += e else all += e
        case Var(_,_) | Num(_,_) => // Do nothing in these cases
        case _ => if (all contains e) dup += e else all += e
      }
      e.args.foreach(findExpr)
    }
    findExpr(expr)
    def replaceDup(e: Expr): Expr = {
      if (made contains e) Var("x"+made(e),Nil)
      else if (used contains e) Var(user(e).text,Nil)
      else if (dup contains e) {
        val fixed = replaceDupChildren(e)
        made += e -> made.size
        Var("x"+made(e),List(fixed))
      }
      else replaceDupChildren(e)
    }
    def replaceDupChildren(e: Expr): Expr = e match {
      case Una(t,List(u)) => Una(t,List(replaceDup(u)))
      case Bin(t,List(l,r)) => Bin(t,List(replaceDup(l),replaceDup(r)))
      case Nar(t,list) => Nar(t,list.map(replaceDup))
      case Var(t,List(v)) =>
        used += v
        Var(t,List(if (made contains v) replaceDup(v) else replaceDupChildren(v)))
      case _ => e
    }
    replaceDup(expr)
  }

That's it--a fully functional algebraic replacement routine. Note that it builds up sets of expressions that it's seen, keeping special track of which ones are duplicates. Thanks to the magic of case classes, all the equalities are defined for us, so it just works. Then we can replace any duplicates as we recurse through to find them. Note that the replace routine is split in half, and that it matches on an unreplaced version of the tree, but uses a replaced version.
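To make the duplicate-detection idea concrete, here is a small Python sketch of it. It is a simplification rather than a translation of algebraicSimplify: expressions are assumed to be nested tuples (op, left, right) with string leaves, user-defined assignments are ignored, and the generated names x0, x1, ... are assumed not to clash with existing variables. The function names are made up.

```python
from collections import Counter

# Assumed representation: leaves are strings, operations are tuples.
# Tuples are hashable and compare structurally, so equal subtrees are equal keys.

def count_subtrees(expr, counts):
    """Count how often each operation subtree occurs."""
    if isinstance(expr, tuple):
        counts[expr] += 1
        for arg in expr[1:]:
            count_subtrees(arg, counts)

def eliminate(expr):
    """Return (assignments, rewritten_expr); repeated subtrees get names x0, x1, ..."""
    counts = Counter()
    count_subtrees(expr, counts)
    names = {}         # subtree -> generated variable name
    assignments = []   # (name, rewritten subtree), definitions before uses

    def rewrite(e):
        if not isinstance(e, tuple):
            return e
        if e in names:
            return names[e]  # already named: just refer to it
        new = (e[0],) + tuple(rewrite(a) for a in e[1:])
        if counts[e] > 1:
            names[e] = "x%d" % len(names)
            assignments.append((names[e], new))
            return names[e]
        return new

    rewritten = rewrite(expr)
    return assignments, rewritten

xy = ("+", "x", "y")
fy = ("+", "4", "y")
expr = ("+", xy, ("+", ("+", ("*", xy, ("+", "4", "5")), ("*", xy, fy)),
                  ("+", xy, fy)))
assignments, rewritten = eliminate(expr)
for name, value in assignments:
    print(name, "=", value)
print(rewritten)
```

One known limitation of this sketch: a subtree nested inside a repeated subtree is counted once per occurrence of its parent, so it also gets its own name even if it is never used anywhere else.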

Okay, now let's add a few tests:

  def main(args: Array[String]) {
    val test1 = "- + ^ x 2 ^ y 2 1"
    val test2 = "+ + +"  // Bad!
    val test3 = "exp sin cos sum 5"  // Bad!
    val test4 = "+ * 2 3 ^ 3 2"
    val test5 = List(test1, test4, "^ y 2").mkString("list 3 "," ","")
    val test6 = "+ + x y + + * + x y + 4 5 * + x y + 4 y + + x y + 4 y"

    def performTest(test: String) = {
      println("Start with: " + test)
      val p = OpParser.parse(test)
      if (p.isEmpty) println("  Parsing failed")
      else {
        println("Parsed:     " + p.get)
        val q = OpParser.numericSimplify(p.get)
        println("Numeric:    " + q)
        val r = OpParser.algebraicSimplify(q)
        println("Algebraic:  " + r)
      }
      println
    }

    List(test1,test2,test3,test4,test5,test6).foreach(performTest)
  }
}

How does it do?

$ scalac OpParser.scala; scala OpParser
Start with: - + ^ x 2 ^ y 2 1
Parsed:     (((x ^ 2) + (y ^ 2)) - 1)
Numeric:    (((x ^ 2) + (y ^ 2)) - 1)
Algebraic:  (((x ^ 2) + (y ^ 2)) - 1)

Start with: + + +
  Parsing failed

Start with: exp sin cos sum 5
  Parsing failed

Start with: + * 2 3 ^ 3 2
Parsed:     ((2 * 3) + (3 ^ 2))
Numeric:    15.0
Algebraic:  15.0

Start with: list 3 - + ^ x 2 ^ y 2 1 + * 2 3 ^ 3 2 ^ y 2
Parsed:     List[
  1: (((x ^ 2) + (y ^ 2)) - 1)
  2: ((2 * 3) + (3 ^ 2))
  3: (y ^ 2)
]
Numeric:    List[
  1: (((x ^ 2) + (y ^ 2)) - 1)
  2: 15.0
  3: (y ^ 2)
]
Algebraic:  List[
  1: (((x ^ 2) + (x0 <- (y ^ 2))) - 1)
  2: 15.0
  3: x0
]

Start with: + + x y + + * + x y + 4 5 * + x y + 4 y + + x y + 4 y
Parsed:     ((x + y) + ((((x + y) * (4 + 5)) + ((x + y) * (4 + y))) + ((x + y) + (4 + y))))
Numeric:    ((x + y) + ((((x + y) * 9.0) + ((x + y) * (4 + y))) + ((x + y) + (4 + y))))
Algebraic:  ((x0 <- (x + y)) + (((x0 * 9.0) + (x0 * (x1 <- (4 + y)))) + (x0 + x1)))

So I don't know if that's useful for you, but it turns out to be useful for me. And this is the sort of thing that I would be very hesitant to tackle in C++ because various things that were supposed to be easy ended up being painful instead.


Edit: Here's an example of using this structure to print temporary assignments, just to demonstrate that this structure is perfectly okay for doing such things.

Code:

  def useTempVars(expr: Expr): Expr = {
    var n = 0
    def temp = { n += 1; "t"+n }
    def replaceTemp(e: Expr, exempt: Boolean = false): Expr = {
      def varify(x: Expr) = if (exempt) x else Var(temp,List(x))
      e match {
        case Var(t,List(e)) => Var(t,List(replaceTemp(e, exempt = true)))
        case Una(t,List(u)) => varify( Una(t, List(replaceTemp(u,false))) )
        case Bin(t,lr) => varify( Bin(t, lr.map(replaceTemp(_,false))) )
        case Nar(t,ls) => varify( Nar(t, ls.map(replaceTemp(_,false))) )
        case _ => e
      }
    }
    replaceTemp(expr)
  }
  def varCut(expr: Expr): Expr = expr match {
    case Var(t,_) => Var(t,Nil)
    case Una(t,List(u)) => Una(t,List(varCut(u)))
    case Bin(t,lr) => Bin(t, lr.map(varCut))
    case Nar(t,ls) => Nar(t, ls.map(varCut))
    case _ => expr
  }
  def getAssignments(expr: Expr): List[Expr] = {
    val children = expr.args.flatMap(getAssignments)
    expr match {
      case Var(t,List(e)) => children :+ expr
      case _ => children
    }
  }
  def listAssignments(expr: Expr): List[String] = {
    getAssignments(expr).collect(e => e match {
      case Var(t,List(v)) => t + " = " + varCut(v)
    }) :+ (expr.text + " is the answer")
  }

Selected results (from listAssignments(useTempVars(r)).foreach(printf(" %s\n",_))):

Start with: - + ^ x 2 ^ y 2 1
Assignments:
  t1 = (x ^ 2)
  t2 = (y ^ 2)
  t3 = (t1 + t2)
  t4 = (t3 - 1)
  t4 is the answer

Start with: + + x y + + * + x y + 4 5 * + x y + 4 y + + x y + 4 y
Algebraic:  ((x0 <- (x + y)) + (((x0 * 9.0) + (x0 * (x1 <- (4 + y)))) + (x0 + x1)))
Assignments:
  x0 = (x + y)
  t1 = (x0 * 9.0)
  x1 = (4 + y)
  t2 = (x0 * x1)
  t3 = (t1 + t2)
  t4 = (x0 + x1)
  t5 = (t3 + t4)
  t6 = (x0 + t5)
  t6 is the answer

Second edit: finding dependencies is also not too bad.

Code:

  def directDepends(expr: Expr): Set[Expr] = expr match {
    case Var(t,_) => Set(expr)
    case _ => expr.args.flatMap(directDepends).toSet
  }
  def indirectDepends(expr: Expr) = {
    val depend = getAssignments(expr).map(e => 
      e -> e.args.flatMap(directDepends).toSet
    ).toMap
    val tagged = for ((k,v) <- depend) yield (k.text -> v.map(_.text))
    def percolate(tags: Map[String,Set[String]]): Option[Map[String,Set[String]]] = {
      val expand = for ((k,v) <- tags) yield (
        k -> (v union v.flatMap(x => tags.get(x).getOrElse(Set())))
      )
      if (tags.exists(kv => expand(kv._1) contains kv._1)) None  // Cyclic dependency!
      else if (tags == expand) Some(tags)
      else percolate(expand)
    }
    percolate(tagged)
  }
  def listDependents(expr: Expr): List[(String,String)] = {
    def sayNothing(s: String) = if (s=="") "nothing" else s
    val e = expr match {
      case Var(_,_) => expr
      case _ => Var("result",List(expr))
    }
    indirectDepends(e).map(_.toList.map(x =>
      (x._1, sayNothing(x._2.toList.sorted.mkString(" ")))
    )).getOrElse(List((e.text,"cyclic")))
  }

And if we add new test cases val test7 = "list 3 z = ^ x 2 - + z ^ y 2 1 w = - z y" and val test8 = "list 2 x = y y = x" and show the answers with for ((v,d) <- listDependents(r)) println(" "+v+" requires "+d) we get (selected results):

Start with: - + ^ x 2 ^ y 2 1
Dependencies:
  result requires x y

Start with: list 3 z = ^ x 2 - + z ^ y 2 1 w = - z y
Parsed:     List[
  1: (z <- (x ^ 2))
  2: ((z + (y ^ 2)) - 1)
  3: (w <- (z - y))
]
Dependencies:
  z requires x
  w requires x y z
  result requires w x y z

Start with: list 2 x = y y = x
Parsed:     List[
  1: (x <- y)
  2: (y <- x)
]
Dependencies:
  result requires cyclic

Start with: + + x y + + * + x y + 4 5 * + x y + 4 y + + x y + 4 y
Algebraic:  ((x0 <- (x + y)) + (((x0 * 9.0) + (x0 * (x1 <- (4 + y)))) + (x0 + x1)))
Dependencies:
  x0 requires x y
  x1 requires y
  result requires x x0 x1 y

So I think that on top of this sort of structure, all of your individual requirements are met by blocks of one or two dozen lines of Scala code.
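The percolation step can also be sketched in a few lines of Python. This is an illustrative sketch only: it assumes the direct dependencies have already been extracted into a dict, and the names transitive_dependencies, e2, e3 and used_in are hypothetical.

```python
def transitive_dependencies(direct):
    """direct maps a name to the set of names it mentions directly.
    Returns the transitive closure, or None if a definition is cyclic."""
    closure = {name: set(deps) for name, deps in direct.items()}
    changed = True
    while changed:
        changed = False
        for name, deps in closure.items():
            extra = set()
            for d in deps:
                extra |= closure.get(d, set())  # inputs have no entry
            if not extra <= deps:
                deps |= extra
                changed = True
            if name in deps:
                return None  # the name depends on itself: cycle
    return closure

# The three expressions from the question; e2 and e3 are made-up labels
# for the unnamed expressions 2 and 3.
direct = {"z": {"x"}, "e2": {"z", "y"}, "e3": {"z", "y"}}
closure = transitive_dependencies(direct)
for name in sorted(closure):
    print(name, "requires", " ".join(sorted(closure[name])))

# Inverting the closure answers "where does x appear?"
used_in = {}
for name, deps in closure.items():
    for d in deps:
        used_in.setdefault(d, set()).add(name)
print("x is used in", sorted(used_in["x"]))

print(transitive_dependencies({"x": {"y"}, "y": {"x"}}))  # None
```

The inverted map used_in is one way to produce the "x appears in expressions 1, 2 and 3 (through z)" report from the question.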


Edit: here's expression evaluation, if you're given a mapping from vars to values:

  def numericEvaluate(expr: Expr, initialValues: Map[String,Double]) = {
    val chain = new collection.mutable.ArrayBuffer[(String,Double)]
    val evaluated = new collection.mutable.HashMap[String,Double]
    def note(xv: (String,Double)) { chain += xv; evaluated += xv }
    evaluated ++= initialValues
    def substitute(expr: Expr): Expr = expr match {
      case Var(t,List(n @ Num(v,_))) => { note(t -> v.toDouble); n }
      case Var(t,_) if (evaluated contains t) => Num(evaluated(t).toString,Nil)
      case Var(t,ls) => Var(t,ls.map(substitute))
      case Una(t,List(u)) => Una(t,List(substitute(u)))
      case Bin(t,ls) => Bin(t,ls.map(substitute))
      case Nar(t,ls) => Nar(t,ls.map(substitute))
      case _ => expr
    }
    def recurse(e: Expr): Expr = {
      val sub = numericSimplify(substitute(e))
      if (sub == e) e else recurse(sub)
    }
    (recurse(expr), chain.toList)
  }

and it's used like so in the testing routine:

        val (num,ops) = numericEvaluate(r,Map("x"->3,"y"->1.5))
        println("Evaluated:")
        for ((v,n) <- ops) println("  "+v+" = "+n)
        println("  result = " + num)

giving results like these (with input of x = 3 and y = 1.5):

Start with: list 3 - + ^ x 2 ^ y 2 1 + * 2 3 ^ 3 2 ^ y 2
Algebraic:  List[
  1: (((x ^ 2) + (x0 <- (y ^ 2))) - 1)
  2: 15.0
  3: x0
]
Evaluated:
  x0 = 2.25
  result = List[
  1: 10.25
  2: 15.0
  3: 2.25
]

Start with: list 3 z = ^ x 2 - + z ^ y 2 1 w = - z y
Algebraic:  List[
  1: (z <- (x ^ 2))
  2: ((z + (y ^ 2)) - 1)
  3: (w <- (z - y))
]
Evaluated:
  z = 9.0
  w = 7.5
  result = List[
  1: 9.0
  2: 10.25
  3: 7.5
]
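For comparison, evaluating a list of statements against given variable values might be sketched in Python as follows. The assumptions: expressions are nested tuples (op, left, right) with binary operators only, variables are strings, numbers are floats, and each statement is a pair (name, expr), with name set to None for an unnamed expression. All function names are hypothetical.

```python
import math
import operator

BINARY = {"+": operator.add, "-": operator.sub, "*": operator.mul,
          "/": operator.truediv, "^": math.pow}

def value_of(expr, env):
    """Evaluate one expression tree using the variable values in env."""
    if isinstance(expr, tuple):
        op, left, right = expr
        return BINARY[op](value_of(left, env), value_of(right, env))
    if isinstance(expr, str):
        return env[expr]  # raises KeyError for an undefined variable
    return float(expr)

def evaluate(statements, initial):
    """Evaluate statements in order; assignments extend the environment."""
    env = dict(initial)
    results = []
    for name, expr in statements:
        value = value_of(expr, env)
        if name is not None:
            env[name] = value
        results.append(value)
    return results, env

statements = [
    ("z", ("^", "x", 2.0)),                          # z = ^ x 2
    (None, ("-", ("+", "z", ("^", "y", 2.0)), 1.0)),  # - + z ^ y 2 1
    ("w", ("-", "z", "y")),                          # w = - z y
]
results, env = evaluate(statements, {"x": 3.0, "y": 1.5})
print(results)  # [9.0, 10.25, 7.5]
```

This assumes the statements are listed in an order where every variable is defined before it is used; otherwise the dependency information would be needed to sort them first.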

The other challenge--picking out the vars that haven't already been used--is just set subtraction off of the dependencies result list. diff is the name of the set subtraction method.


The problem consists of two subproblems: parsing and symbolic manipulation. It seems to me the answer boils down to two possible solutions.

One is to implement everything from scratch: "I do recommend creating the full expression tree if you want to retain maximum flexibility for handling tricky cases." - proposed by Rex. As Sven points out: "any of the high-level languages you listed are almost equally suited for the task," however "Python (or any of the high-level languages you listed) won't take away the complexity of the problem."
I have received very nice solutions in Scala (many thanks to Rex and Daniel) and a nice little example in Python (from Sven). However, I am still interested in Lisp, Haskell or Erlang solutions.

The other solution is to use some existing library/software for the task, with all the implied pros and cons. Candidates are Maxima (Common Lisp), SymPy (Python, proposed by payne) and GiNaC (C++).
