Handling StackOverflowError in Java for a trampoline
I would like to implement a trampoline in Java by returning a thunk whenever I hit a StackOverflowError. Does Java give any guarantees about StackOverflowError? For example, if the only things I do after catching it are creating objects on the heap and returning from functions, will I be fine?
If the above sounds vague, here is some code that computes even/odd in a tail-recursive manner in continuation-passing style, returning a delayed thunk whenever the stack overflows. The code works on my machine, but does Java guarantee that it will always work?
/**
 * Even/odd in continuation-passing style, trampolined by catching {@code StackOverflowError}.
 *
 * <p>When the Java stack runs out, the frame that catches the error packages "the rest of the
 * work" as a delayed {@link Thunk} and returns it. Every caller simply returns that value, so the
 * stack unwinds back to {@link Thunk#force()}, which resumes from a shallow stack.
 *
 * <p>NOTE(review): this relies on the JVM recovering cleanly from {@code StackOverflowError}. The
 * Java specs only promise that the error is thrown; they do not fix the stack depth or the cost
 * of recovering. An explicit depth counter is the more predictable alternative.
 */
public class CPS {
    /**
     * One step of a trampolined computation: either a finished value ({@code isDelayed == false})
     * or a suspended computation whose next step is produced by {@link #compute()}.
     */
    public static class Thunk {
        /** The answer for a finished thunk; the suspended argument for a delayed one. */
        final Object r;
        /** The continuation captured by a delayed thunk; {@code null} for a finished one. */
        final Continuation c;
        /** Whether {@link #force()} must keep calling {@link #compute()}. */
        final boolean isDelayed;

        /**
         * Runs the trampoline: steps delayed thunks from this (shallow) frame until a finished
         * thunk appears.
         *
         * @return the value held by the first finished thunk reached
         */
        public Object force() {
            Thunk t = this;
            while (t.isDelayed) {
                t = t.compute();
            }
            return t.r;
        }

        /** Produces the next step. The base implementation just returns {@code this}. */
        public Thunk compute() {
            return this;
        }

        /** Creates a finished thunk holding {@code answer}. */
        public Thunk(Object answer) {
            isDelayed = false;
            r = answer;
            c = null;
        }

        /** Creates a delayed thunk; subclasses override {@link #compute()} to resume the work. */
        public Thunk(Object intermediate, Continuation cont) {
            r = intermediate;
            c = cont;
            isDelayed = true;
        }
    }

    /** The identity continuation: wraps the final result in a finished {@link Thunk}. */
    public static class Continuation {
        public Thunk apply(Object result) {
            return new Thunk(result);
        }
    }

    /**
     * Decides whether {@code n} is even, passing the answer to {@code c}.
     *
     * @return a finished thunk, or a delayed thunk if the stack overflowed mid-computation
     */
    public static Thunk even(final int n, final Continuation c) {
        try {
            if (n == 0) return c.apply(true);
            else return odd(n - 1, c);
        } catch (StackOverflowError x) {
            // Out of stack: suspend "even(n, c)". Callers return this value unchanged, so the
            // stack unwinds back to force(), which resumes it on a fresh stack.
            return new Thunk(n, c) {
                @Override
                public Thunk compute() {
                    return even(n, c);
                }
            };
        }
    }

    /**
     * Decides whether {@code n} is odd, passing the answer to {@code c}.
     *
     * @return a finished thunk, or a delayed thunk if the stack overflowed mid-computation
     */
    public static Thunk odd(final int n, final Continuation c) {
        try {
            if (n == 0) return c.apply(false);
            else return even(n - 1, c);
        } catch (StackOverflowError x) {
            // Same suspension trick as in even().
            return new Thunk(n, c) {
                @Override
                public Thunk compute() {
                    return odd(n, c);
                }
            };
        }
    }

    public static void main(String args[]) {
        System.out.println(even(100001, new Continuation()).force());
    }
}
I tried the following implementation possibilities: A) with thunks (see code CPS below); B) without thunks, as suggested by chris (see code CPS2 below); C) with thunks, with the stack overflow replaced by a depth check (see code CPS3 below).
In each case I checked whether 100,000,000 is an even number. This check took A) about 2 seconds, B) about 17 seconds, C) about 0.2 seconds.
So returning from a long chain of functions is much faster than throwing an exception that unwinds that chain. Also, instead of waiting for a stack overflow, it is much faster to just record the recursion depth and unwind at depth 1000.
Code for CPS:
/**
 * Even/odd in continuation-passing style. Running out of Java stack is turned into a suspended
 * {@link Thunk}, which {@link Thunk#force()} resumes from a shallow stack.
 */
public class CPS {
    /** A trampoline step: a finished value, or a suspended computation (via {@link #compute()}). */
    public static class Thunk {
        final Object r;
        final boolean isDelayed;

        /** Builds a finished step holding {@code answer}. */
        public Thunk(Object answer) {
            r = answer;
            isDelayed = false;
        }

        /** Builds a suspended step; subclasses supply {@link #compute()}. */
        public Thunk() {
            r = null;
            isDelayed = true;
        }

        /** Advances one step; the default simply yields this same step. */
        public Thunk compute() {
            return this;
        }

        /** Keeps advancing suspended steps until a finished one is reached, then returns its value. */
        public Object force() {
            Thunk current = this;
            while (current.isDelayed) {
                current = current.compute();
            }
            return current.r;
        }
    }

    /** Final continuation: wraps the result in a finished step. */
    public static class Continuation {
        public Thunk apply(Object result) {
            return new Thunk(result);
        }
    }

    /** Is {@code n} even? Yields a suspended step if the stack overflows along the way. */
    public static Thunk even(final int n, final Continuation c) {
        try {
            return n == 0 ? c.apply(true) : odd(n - 1, c);
        } catch (StackOverflowError overflow) {
            // Package the remaining work; every caller passes it straight back to force().
            return new Thunk() {
                @Override
                public Thunk compute() {
                    return even(n, c);
                }
            };
        }
    }

    /** Is {@code n} odd? Yields a suspended step if the stack overflows along the way. */
    public static Thunk odd(final int n, final Continuation c) {
        try {
            return n == 0 ? c.apply(false) : even(n - 1, c);
        } catch (StackOverflowError overflow) {
            return new Thunk() {
                @Override
                public Thunk compute() {
                    return odd(n, c);
                }
            };
        }
    }

    public static void main(String[] args) {
        final long start = System.currentTimeMillis();
        final Object result = even(100000000, new Continuation()).force();
        final long elapsed = System.currentTimeMillis() - start;
        System.out.println("time = " + elapsed + ", result = " + result);
    }
}
Code for CPS2:
/**
 * Even/odd in continuation-passing style, trampolined by exceptions instead of return values.
 *
 * <p>When the stack overflows, the catching frame throws an {@link Unwind} describing the rest of
 * the work. The exception unwinds the whole chain back to {@link Unwind#force()}, which restarts
 * the computation on a shallow stack.
 */
public class CPS2 {
    /** A control-flow exception carrying a suspended computation. */
    public abstract static class Unwind extends RuntimeException {
        /** Resumes the suspended computation; may itself throw a further {@code Unwind}. */
        public abstract Object compute();

        /**
         * Runs this computation to completion. Each time an {@code Unwind} escapes
         * {@link #compute()}, it is resumed from this frame.
         *
         * @return the final result
         */
        public Object force() {
            Unwind w = this;
            while (true) {
                try {
                    return w.compute();
                } catch (Unwind unwind) {
                    w = unwind;
                }
            }
        }

        /**
         * Skips stack-trace capture. An {@code Unwind} is created at maximum stack depth, so the
         * default behaviour would walk thousands of frames to record a trace nobody ever reads.
         */
        @Override
        public Throwable fillInStackTrace() {
            return this;
        }
    }

    /** The identity continuation: the result is returned unchanged. */
    public static class Continuation {
        public Object apply(Object result) {
            return result;
        }
    }

    /**
     * Decides whether {@code n} is even.
     *
     * @throws Unwind if the stack overflows; {@link Unwind#force()} resumes it
     */
    public static Object even(final int n, final Continuation c) {
        try {
            if (n == 0) return c.apply(true);
            else return odd(n - 1, c);
        } catch (StackOverflowError x) {
            throw new Unwind() {
                @Override
                public Object compute() {
                    return even(n, c);
                }
            };
        }
    }

    /**
     * Decides whether {@code n} is odd.
     *
     * @throws Unwind if the stack overflows; {@link Unwind#force()} resumes it
     */
    public static Object odd(final int n, final Continuation c) {
        try {
            if (n == 0) return c.apply(false);
            else return even(n - 1, c);
        } catch (StackOverflowError x) {
            // Must THROW (not return) the Unwind. A returned Unwind would travel up the stack as
            // an ordinary result and be reported by force() as the final answer.
            throw new Unwind() {
                @Override
                public Object compute() {
                    return odd(n, c);
                }
            };
        }
    }

    public static void main(String args[]) {
        long time1 = System.currentTimeMillis();
        Unwind w = new Unwind() {
            @Override
            public Object compute() {
                return even(100000000, new Continuation());
            }
        };
        Object b = w.force();
        long time2 = System.currentTimeMillis();
        System.out.println("time = " + (time2 - time1) + ", result = " + b);
    }
}
Code for CPS3:
/**
 * Even/odd in continuation-passing style, trampolined by an explicit recursion-depth budget
 * instead of waiting for a {@code StackOverflowError}.
 */
public class CPS3 {
    /** Depth at which even/odd stop recursing and hand back a suspended step. */
    private static final int MAX_DEPTH = 1000;

    /** A trampoline step: a finished value, or a suspended computation (via {@link #compute()}). */
    public static class Thunk {
        final Object r;
        final boolean isDelayed;

        /** Builds a finished step holding {@code answer}. */
        public Thunk(Object answer) {
            r = answer;
            isDelayed = false;
        }

        /** Builds a suspended step; subclasses supply {@link #compute()}. */
        public Thunk() {
            r = null;
            isDelayed = true;
        }

        /** Advances one step; the default simply yields this same step. */
        public Thunk compute() {
            return this;
        }

        /** Keeps advancing suspended steps until a finished one is reached, then returns its value. */
        public Object force() {
            Thunk current = this;
            while (current.isDelayed) {
                current = current.compute();
            }
            return current.r;
        }
    }

    /** Final continuation: wraps the result in a finished step. */
    public static class Continuation {
        public Thunk apply(Object result) {
            return new Thunk(result);
        }
    }

    /** Is {@code n} even? {@code depth} counts frames used since the last trampoline bounce. */
    public static Thunk even(final int n, final Continuation c, final int depth) {
        if (depth < MAX_DEPTH) {
            return n == 0 ? c.apply(true) : odd(n - 1, c, depth + 1);
        }
        // Budget spent: suspend, and restart from depth 0 once force() picks this up.
        return new Thunk() {
            @Override
            public Thunk compute() {
                return even(n, c, 0);
            }
        };
    }

    /** Is {@code n} odd? {@code depth} counts frames used since the last trampoline bounce. */
    public static Thunk odd(final int n, final Continuation c, final int depth) {
        if (depth < MAX_DEPTH) {
            return n == 0 ? c.apply(false) : even(n - 1, c, depth + 1);
        }
        return new Thunk() {
            @Override
            public Thunk compute() {
                return odd(n, c, 0);
            }
        };
    }

    public static void main(String[] args) {
        final long start = System.currentTimeMillis();
        final Object result = even(100000000, new Continuation(), 0).force();
        final long elapsed = System.currentTimeMillis() - start;
        System.out.println("time = " + elapsed + ", result = " + result);
    }
}
That's an interesting way to jump up the stack. It seems to work, but it is probably slower than the usual way to implement this technique, which is to throw an exception that is caught a very large number of layers up the call stack.
精彩评论