
Strassen algorithm not the fastest?

I copied an implementation of Strassen's algorithm from somewhere and ran it. Here is the output:

n = 256
classical took 360ms
strassen 1 took 33609ms
strassen2 took 1172ms
classical took 437ms
strassen 1 took 32891ms
strassen2 took 1156ms
classical took 266ms
strassen 1 took 27234ms
strassen2 took 734ms

Here strassen1 is the dynamically allocating version, strassen2 is the cached version (it reuses preallocated buffers), and classical is the ordinary matrix multiplication. This suggests that the old and simple classical algorithm is the fastest. Is that true, or am I going wrong somewhere? Here's the code in Java.

import java.util.Random;

class TestIntMatrixMultiplication {

    public static void main (String...args) throws Exception {
        final int n = args.length > 0 ? Integer.parseInt(args[0]) : 256;
        final int seed = args.length > 1 ? Integer.parseInt(args[1]) : 256;
        final Random random = new Random(seed);

        int[][] a, b, c;

        a = new int[n][n];
        b = new int[n][n];
        c = new int[n][n];

        for(int i=0; i<n; i++) {
            for(int j=0; j<n; j++) {
                a[i][j] = random.nextInt(100);
                b[i][j] = random.nextInt(100);
            }
        }



        System.out.println("n = " + n);

        if (a.length < 64) {
            System.out.println("A");
            dumpMatrix(a);
            System.out.println("B");
            dumpMatrix(b);
            System.out.println("classic");
            Classical.mult(c, a, b);
            dumpMatrix(c);
            System.out.println("strassen");
            strassen2.mult(c, a, b);
            dumpMatrix(c);

            return;
        }

        for (int i = 0; i <3; ++i) {
            timeMultiplies1(a, b, c);
            if (n <= 256)
                timeMultiplies2( a, b, c);
            timeMultiplies3( a, b, c);
        }
    }

    static void timeMultiplies1 (int[][] a, int[][] b, int[][] c) {
        final long start = System.currentTimeMillis();
        Classical.mult(c, a, b);
        final long finish = System.currentTimeMillis();
        System.out.println("classical took " + (finish - start) + "ms");
    }
    static void timeMultiplies2(int[][] a, int[][] b, int[][] c) {
        final long start = System.currentTimeMillis();
        strassen1.mult(c, a, b);
        final long finish = System.currentTimeMillis();
        System.out.println("strassen 1 took " + (finish - start) + "ms");
    }
    static void timeMultiplies3 (int[][] a, int[][] b, int[][] c) {
        final long start = System.currentTimeMillis();
        strassen2.mult(c, a, b);
        final long finish = System.currentTimeMillis();
        System.out.println("strassen2 took " + (finish - start) + "ms");
    }

    static void dumpMatrix (int[][] m) {
        for (int[] row : m) {
            System.out.print("[\t");
            for (int val : row) {
                System.out.print(val);
                System.out.print('\t');
            }
            System.out.println(']');
        }
    }
}

class strassen1{

    public String getName () {
        return "Strassen(dynamic)";
    }

    public static int[][] mult (int[][] c, int[][] a, int[][] b) {
        return strassenMatrixMultiplication(a, b);
    }

    public static int [][] strassenMatrixMultiplication(int [][] A, int [][] B) {
        int n = A.length;

        int [][] result = new int[n][n];

        if(n == 1) {
            result[0][0] = A[0][0] * B[0][0];
        } else {
            int [][] A11 = new int[n/2][n/2];
            int [][] A12 = new int[n/2][n/2];
            int [][] A21 = new int[n/2][n/2];
            int [][] A22 = new int[n/2][n/2];

            int [][] B11 = new int[n/2][n/2];
            int [][] B12 = new int[n/2][n/2];
            int [][] B21 = new int[n/2][n/2];
            int [][] B22 = new int[n/2][n/2];

            divideArray(A, A11, 0 , 0);
            divideArray(A, A12, 0 , n/2);
            divideArray(A, A21, n/2, 0);
            divideArray(A, A22, n/2, n/2);

            divideArray(B, B11, 0 , 0);
            divideArray(B, B12, 0 , n/2);
            divideArray(B, B21, n/2, 0);
            divideArray(B, B22, n/2, n/2);

            int [][] P1 = strassenMatrixMultiplication(addMatrices(A11, A22), addMatrices(B11, B22));
            int [][] P2 = strassenMatrixMultiplication(addMatrices(A21, A22), B11);
            int [][] P3 = strassenMatrixMultiplication(A11, subtractMatrices(B12, B22));
            int [][] P4 = strassenMatrixMultiplication(A22, subtractMatrices(B21, B11));
            int [][] P5 = strassenMatrixMultiplication(addMatrices(A11, A12), B22);
            int [][] P6 = strassenMatrixMultiplication(subtractMatrices(A21, A11), addMatrices(B11, B12));
            int [][] P7 = strassenMatrixMultiplication(subtractMatrices(A12, A22), addMatrices(B21, B22));

            int [][] C11 = addMatrices(subtractMatrices(addMatrices(P1, P4), P5), P7);
            int [][] C12 = addMatrices(P3, P5);
            int [][] C21 = addMatrices(P2, P4);
            int [][] C22 = addMatrices(subtractMatrices(addMatrices(P1, P3), P2), P6);

            copySubArray(C11, result, 0 , 0);
            copySubArray(C12, result, 0 , n/2);
            copySubArray(C21, result, n/2, 0);
            copySubArray(C22, result, n/2, n/2);
        }

        return result;
    }

    public static int [][] addMatrices(int [][] A, int [][] B) {
        int n = A.length;

        int [][] result = new int[n][n];

        for(int i=0; i<n; i++)
            for(int j=0; j<n; j++)
                result[i][j] = A[i][j] + B[i][j];

        return result;
    }

    public static int [][] subtractMatrices(int [][] A, int [][] B) {
        int n = A.length;

        int [][] result = new int[n][n];

        for(int i=0; i<n; i++)
            for(int j=0; j<n; j++)
                result[i][j] = A[i][j] - B[i][j];

        return result;
    }

    public static void divideArray(int[][] parent, int[][] child, int iB, int jB) {
        for(int i1 = 0, i2=iB; i1<child.length; i1++, i2++)
            for(int j1 = 0, j2=jB; j1<child.length; j1++, j2++)
                child[i1][j1] = parent[i2][j2];
    }

    public static void copySubArray(int[][] child, int[][] parent, int iB, int jB) {
        for(int i1 = 0, i2=iB; i1<child.length; i1++, i2++)
            for(int j1 = 0, j2=jB; j1<child.length; j1++, j2++)
                parent[i2][j2] = child[i1][j1];
    }
}
class strassen2{

    public String getName () {
        return "Strassen(cached)";
    }

    static int [][] p1;
    static int [][] p2;
    static int [][] p3;
    static int [][] p4;
    static int [][] p5;
    static int [][] p6;
    static int [][] p7;
    static int [][] t0;
    static int [][] t1;

    public static int[][] mult (int[][] c, int[][] a, int[][] b) {
        final int n = c.length;

        if (p1 == null || p1.length < n/2) {
            p1 = new int[n/2][n-1];
            p2 = new int[n/2][n-1];
            p3 = new int[n/2][n-1];
            p4 = new int[n/2][n-1];
            p5 = new int[n/2][n-1];
            p6 = new int[n/2][n-1];
            p7 = new int[n/2][n-1];
            t0 = new int[n/2][n-1];
            t1 = new int[n/2][n-1];
        }

        mult(c, a, b, 0, 0, n, 0);

        return c;
    }

    public static void mult (int[][] c, int[][] a, int[][] b, int i0, int j0, int n, int offs) {
        if(n == 1) {
            c[i0][j0] = a[i0][j0] * b[i0][j0];
        } else {
            final int nBy2 = n/2;

            final int i1 = i0 + nBy2;
            final int j1 = j0 + nBy2;

            // offset applied to 'p' j index so recursive calls don't overwrite data
            final int jp0 = offs;
            final int jp1 = nBy2 + offs;

            // P1 <- (A11 + A22)(B11 + B22)
            //  T0 <- (A11 + A22), T1 <- (B11 + B22), P1 <- T0*T1
            for (int i = 0; i < nBy2; ++i) {
                for (int j = 0; j < nBy2; ++j) {
                    t0[i + i0][j + jp0] = a[i + i0][j + j0] + a[i + i1][j + j1];
                    t1[i + i0][j + jp0] = b[i + i0][j + j0] + b[i + i1][j + j1];
                }
            }

            mult(p1, t0, t1, i0, jp0, nBy2, offs + nBy2);

            // P2 <- (A21 + A22)B11
            //  T0 <- (A21 + A22), T1 <- B11, P2 <- T0*T1
            for (int i = 0; i < nBy2; ++i) {
                for (int j = 0; j < nBy2; ++j) {
                    t0[i + i0][j + jp0] = a[i + i1][j + j0] + a[i + i1][j + j1];
                    t1[i + i0][j + jp0] = b[i + i0][j + j0];
                }
            }

            mult(p2, t0, t1, i0, jp0, nBy2, offs + nBy2);

            // P3 <- A11(B12 - B22)
            //  T0 <- A11, T1 <- (B12 - B22), P3 <- T0*T1
            for (int i = 0; i < nBy2; ++i) {
                for (int j = 0; j < nBy2; ++j) {
                    t0[i + i0][j + jp0] = a[i + i0][j + j0];
                    t1[i + i0][j + jp0] = b[i + i0][j + j1] - b[i + i1][j + j1];
                }
            }

            mult(p3, t0, t1, i0, jp0, nBy2, offs + nBy2);

            // P4 <- A22(B21 - B11)
            //  T0 <- A22, T1 <- (B21 - B11), P4 <- T0*T1
            for (int i = 0; i < nBy2; ++i) {
                for (int j = 0; j < nBy2; ++j) {
                    t0[i + i0][j + jp0] = a[i + i1][j + j1];
                    t1[i + i0][j + jp0] = b[i + i1][j + j0] - b[i + i0][j + j0];
                }
            }

            mult(p4, t0, t1, i0, jp0, nBy2, offs + nBy2);

            // P5 <- (A11 + A12) B22
            //  T0 <- (A11 + A12), T1 <- B22, P5 <- T0*T1
            for (int i = 0; i < nBy2; ++i) {
                for (int j = 0; j < nBy2; ++j) {
                    t0[i + i0][j + jp0] = a[i + i0][j + j0] + a[i + i0][j + j1];
                    t1[i + i0][j + jp0] = b[i + i1][j + j1];
                }
            }

            mult(p5, t0, t1, i0, jp0, nBy2, offs + nBy2);

            // P6 <- (A21 - A11)(B11 + B12)
            //  T0 <- (A21 - A11), T1 <- (B11 + B12), P6 <- T0 * T1
            for (int i = 0; i < nBy2; ++i) {
                for (int j = 0; j < nBy2; ++j) {
                    t0[i + i0][j + jp0] = a[i + i1][j + j0] - a[i + i0][j + j0];
                    t1[i + i0][j + jp0] = b[i + i0][j + j0] + b[i + i0][j + j1];
                }
            }

            mult(p6, t0, t1, i0, jp0, nBy2, offs + nBy2);

            // P7 <- (A12 - A22)(B21 + B22)
            //  T0 <- (A12 - A22), T1 <- (B21 + B22), P7 <- T0 * T1
            for (int i = 0; i < nBy2; ++i) {
                for (int j = 0; j < nBy2; ++j) {
                    t0[i + i0][j + jp0] = a[i + i0][j + j1] - a[i + i1][j + j1];
                    t1[i + i0][j + jp0] = b[i + i1][j + j0] + b[i + i1][j + j1];
                }
            }

            mult(p7, t0, t1, i0, jp0, nBy2, offs + nBy2);

            // combine
            for (int i = 0; i < nBy2; ++i) {
                for (int j = 0; j < nBy2; ++j) {
                    // C11 = P1 + P4 - P5 + P7;
                    c[i + i0][j + j0] = p1[i + i0][j + jp0] + p4[i + i0][j + jp0] - p5[i + i0][j + jp0] + p7[i + i0][j + jp0];
                    // C12 = P3 + P5;
                    c[i + i0][j + j1] = p3[i + i0][j + jp0] + p5[i + i0][j + jp0];
                    // C21 = P2 + P4;
                    c[i + i1][j + j0] = p2[i + i0][j + jp0] + p4[i + i0][j + jp0];
                    // C22 = P1 + P3 - P2 + P6;
                    c[i + i1][j + j1] = p1[i + i0][j + jp0] + p3[i + i0][j + jp0] - p2[i + i0][j + jp0] + p6[i + i0][j + jp0];
                }
            }
        }
    }

    void dumpInternal () {
        System.out.println("P1");
        TestIntMatrixMultiplication.dumpMatrix(p1);
        System.out.println("P2");
        TestIntMatrixMultiplication.dumpMatrix(p2);
        System.out.println("P3");
        TestIntMatrixMultiplication.dumpMatrix(p3);
        System.out.println("P4");
        TestIntMatrixMultiplication.dumpMatrix(p4);
        System.out.println("P5");
        TestIntMatrixMultiplication.dumpMatrix(p5);
        System.out.println("P6");
        TestIntMatrixMultiplication.dumpMatrix(p6);
        System.out.println("P7");
        TestIntMatrixMultiplication.dumpMatrix(p7);
        System.out.println("T0");
        TestIntMatrixMultiplication.dumpMatrix(t0);
        System.out.println("T1");
        TestIntMatrixMultiplication.dumpMatrix(t1);
    }
}


class Classical{
    public String getName () {
        return "classic";
    }

    public static int[][] mult (int[][] c, int[][] a, int[][] b) {
        int n = a.length;

        for(int i=0; i<n; i++) {
            final int[] a_i = a[i];
            final int[] c_i = c[i];

            for(int j=0; j<n; j++) {
                int sum = 0;

                for(int k=0; k<n; k++) {
                    sum += a_i[k] * b[k][j];
                }

                c_i[j] = sum;
            }
        }

        return c;
    }
}


Issues I see:

1) Your Strassen multiply is dynamically allocating memory all the time. This is going to kill performance.

2) Your Strassen multiply should switch over to conventional multiply for small sizes rather than being recursive all the way down (though this optimization sort of invalidates your test).

3) Your matrix size may simply be too small to see the difference.

You should do comparisons with several different sizes. Perhaps 256, 512, 1024, 2048, 4096, 8192... Then plot the times and look at the trends. You will probably want matrix size on a log scale if it's all powers of 2.
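
A rough sketch of such a size sweep is below. The names SizeSweep, timeMillis, bestOf and multiply are made up for illustration; multiply is just a plain triple loop standing in for whichever implementation you want to measure. It does one untimed warm-up call per size, since the JIT compiler can distort the first runs, and it stops at 512 by default so it finishes quickly.

```java
import java.util.Random;

final class SizeSweep {
    // Times a single call of task in milliseconds using System.nanoTime.
    static double timeMillis(Runnable task) {
        long start = System.nanoTime();
        task.run();
        return (System.nanoTime() - start) / 1e6;
    }

    // Runs task once untimed (warm-up), then returns the best of `runs` timings.
    static double bestOf(int runs, Runnable task) {
        task.run();
        double best = Double.MAX_VALUE;
        for (int r = 0; r < runs; r++) {
            best = Math.min(best, timeMillis(task));
        }
        return best;
    }

    // Placeholder multiply (plain triple loop); swap in the routine under test.
    static void multiply(int[][] c, int[][] a, int[][] b) {
        int n = a.length;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                int sum = 0;
                for (int k = 0; k < n; k++) {
                    sum += a[i][k] * b[k][j];
                }
                c[i][j] = sum;
            }
        }
    }

    public static void main(String[] args) {
        // Largest size to test; pass a bigger power of two on the command line.
        int maxN = args.length > 0 ? Integer.parseInt(args[0]) : 512;
        Random random = new Random(256);
        for (int n = 256; n <= maxN; n *= 2) {
            final int[][] a = new int[n][n];
            final int[][] b = new int[n][n];
            final int[][] c = new int[n][n];
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < n; j++) {
                    a[i][j] = random.nextInt(100);
                    b[i][j] = random.nextInt(100);
                }
            }
            double ms = bestOf(3, () -> multiply(c, a, b));
            // Tab-separated so the output can be pasted into a plotting tool.
            System.out.println(n + "\t" + ms);
        }
    }
}
```

Plotting the second column against the first on log-log axes should show each algorithm as roughly a straight line, and the slopes are what you want to compare.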

Strassen is only faster for large N. How large will depend a lot on the implementation. What you have done for classical is only a basic implementation and is not optimal on a modern machine either.
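
As one example of what "not optimal" means here: in the question's Classical.mult the inner loop reads b[k][j] with k varying, so every step jumps to a different row array. A common remedy is to reorder the loops to i, k, j so the inner loop walks along one row of b and one row of c. The sketch below assumes the same (c, a, b) signature as the question and that c is a different array from a and b; ClassicalIkj is a made-up name. Whether and by how much it helps on your machine is something to measure.

```java
import java.util.Arrays;

final class ClassicalIkj {
    // Writes a * b into c and returns c. c must not alias a or b.
    static int[][] mult(int[][] c, int[][] a, int[][] b) {
        int n = a.length;
        for (int i = 0; i < n; i++) {
            int[] ci = c[i];
            int[] ai = a[i];
            Arrays.fill(ci, 0); // row i of c is accumulated into, so clear it first
            for (int k = 0; k < n; k++) {
                int aik = ai[k];
                int[] bk = b[k];
                // Inner loop touches bk and ci sequentially.
                for (int j = 0; j < n; j++) {
                    ci[j] += aik * bk[j];
                }
            }
        }
        return c;
    }
}
```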


Implementation questions aside, I think you're misunderstanding the algorithm's performance characteristics; as phkahler said, your expectations are a little off. Divide-and-conquer algorithms work well for large inputs because they recursively break the problem into smaller sub-problems that can be solved more quickly.

However, the overhead associated with this splitting can cause the algorithm to run (sometimes much) slower for small or even medium-sized inputs. Typically, the theoretical analysis of an algorithm like Strassen will include a so-called "breakpoint" calculation. This is the input size above which splitting, despite its overhead, becomes cheaper than the naive technique.

Your code needs to include a check on the size of the input that switches to the naive technique below the breakpoint.
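
A minimal sketch of that check, assuming n is a power of two. HybridStrassen and its helpers are made-up names, and the CUTOFF of 64 is only a guess; the right value depends on the machine and has to be found by measurement. This version still allocates temporaries at each level, so it addresses only the breakpoint, not the allocation cost.

```java
final class HybridStrassen {
    // Hypothetical breakpoint: at or below this size, use the naive method.
    static final int CUTOFF = 64;

    // Returns a * b for n x n matrices, n a power of two.
    static int[][] mult(int[][] a, int[][] b) {
        int n = a.length;
        if (n <= CUTOFF) {
            return classical(a, b);
        }
        int h = n / 2;
        int[][] a11 = sub(a, 0, 0, h), a12 = sub(a, 0, h, h);
        int[][] a21 = sub(a, h, 0, h), a22 = sub(a, h, h, h);
        int[][] b11 = sub(b, 0, 0, h), b12 = sub(b, 0, h, h);
        int[][] b21 = sub(b, h, 0, h), b22 = sub(b, h, h, h);

        // The seven Strassen products (same formulas as strassen1 in the question).
        int[][] p1 = mult(add(a11, a22, 1), add(b11, b22, 1));
        int[][] p2 = mult(add(a21, a22, 1), b11);
        int[][] p3 = mult(a11, add(b12, b22, -1));
        int[][] p4 = mult(a22, add(b21, b11, -1));
        int[][] p5 = mult(add(a11, a12, 1), b22);
        int[][] p6 = mult(add(a21, a11, -1), add(b11, b12, 1));
        int[][] p7 = mult(add(a12, a22, -1), add(b21, b22, 1));

        int[][] c = new int[n][n];
        for (int i = 0; i < h; i++) {
            for (int j = 0; j < h; j++) {
                c[i][j] = p1[i][j] + p4[i][j] - p5[i][j] + p7[i][j];
                c[i][j + h] = p3[i][j] + p5[i][j];
                c[i + h][j] = p2[i][j] + p4[i][j];
                c[i + h][j + h] = p1[i][j] + p3[i][j] - p2[i][j] + p6[i][j];
            }
        }
        return c;
    }

    // Copies the size x size block of m whose top-left corner is (r, c).
    static int[][] sub(int[][] m, int r, int c, int size) {
        int[][] out = new int[size][size];
        for (int i = 0; i < size; i++) {
            System.arraycopy(m[r + i], c, out[i], 0, size);
        }
        return out;
    }

    // Returns x + sign * y, where sign is +1 or -1.
    static int[][] add(int[][] x, int[][] y, int sign) {
        int n = x.length;
        int[][] out = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                out[i][j] = x[i][j] + sign * y[i][j];
            }
        }
        return out;
    }

    // Naive triple loop, used below the breakpoint.
    static int[][] classical(int[][] a, int[][] b) {
        int n = a.length;
        int[][] c = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                int sum = 0;
                for (int k = 0; k < n; k++) {
                    sum += a[i][k] * b[k][j];
                }
                c[i][j] = sum;
            }
        }
        return c;
    }
}
```

With n = 256 and this cutoff, the recursion goes only two levels deep and makes 49 naive multiplications of 64 x 64 blocks, instead of recursing down to single elements.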


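Write down what the Strassen algorithm does for a 2 x 2 matrix and count the operations: 7 multiplications and 18 additions or subtractions, against 8 multiplications and 4 additions for the classical method. The number is absolutely ridiculous. It's stupid to use Strassen's method for a 2 x 2 matrix. The same goes for a 3 x 3 or 4 x 4 matrix, and probably quite a way up.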
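
To put numbers on "quite a way up", here is a small sketch that counts scalar operations only. It assumes n is a power of two and that Strassen recurses all the way down to 1 x 1, as the code in the question does; OpCount is a made-up name. It ignores memory allocation, copying and cache effects, which only make things worse for Strassen.

```java
final class OpCount {
    // Strassen recursing down to 1 x 1: 7 half-size products plus
    // 18 additions/subtractions of half-size matrices per level.
    static long strassen(long n) {
        if (n == 1) {
            return 1; // a single scalar multiplication
        }
        long h = n / 2;
        return 7 * strassen(h) + 18 * h * h;
    }

    // Classical method: n^3 multiplications and n^2 * (n - 1) additions.
    static long classical(long n) {
        return n * n * n + n * n * (n - 1);
    }

    public static void main(String[] args) {
        // Columns: n, Strassen operations, classical operations.
        for (long n = 2; n <= 2048; n *= 2) {
            System.out.println(n + "\t" + strassen(n) + "\t" + classical(n));
        }
    }
}
```

For n = 2 this gives 25 operations against 12. By this count alone, fully recursive Strassen still needs more operations than the classical method at n = 512 and only drops below it at n = 1024.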
