Java中的Fork/Join矩阵乘法

我正在对Java 7中的fork/join框架进行一些性能研究。为了改进测试结果,我想在测试期间使用几种不同的递归算法,其中之一是矩阵乘法。

我从Doug Lea的网站下载了以下示例:

public class MatrixMultiply { static final int DEFAULT_GRANULARITY = 16; /** The quadrant size at which to stop recursing down * and instead directly multiply the matrices. * Must be a power of two. Minimum value is 2. **/ static int granularity = DEFAULT_GRANULARITY; public static void main(String[] args) { final String usage = "Usage: java MatrixMultiply   [] \n Size and granularity must be powers of two.\n For example, try java MatrixMultiply 2 512 16"; try { int procs; int n; try { procs = Integer.parseInt(args[0]); n = Integer.parseInt(args[1]); if (args.length > 2) granularity = Integer.parseInt(args[2]); } catch (Exception e) { System.out.println(usage); return; } if ( ((n & (n - 1)) != 0) || ((granularity & (granularity - 1)) != 0) || granularity < 2) { System.out.println(usage); return; } float[][] a = new float[n][n]; float[][] b = new float[n][n]; float[][] c = new float[n][n]; init(a, b, n); FJTaskRunnerGroup g = new FJTaskRunnerGroup(procs); g.invoke(new Multiplier(a, 0, 0, b, 0, 0, c, 0, 0, n)); g.stats(); // check(c, n); } catch (InterruptedException ex) {} } // To simplify checking, fill with all 1's. Answer should be all n's. static void init(float[][] a, float[][] b, int n) { for (int i = 0; i < n; ++i) { for (int j = 0; j < n; ++j) { a[i][j] = 1.0F; b[i][j] = 1.0F; } } } static void check(float[][] c, int n) { for (int i = 0; i < n; i++ ) { for (int j = 0; j < n; j++ ) { if (c[i][j] != n) { throw new Error("Check Failed at [" + i +"]["+j+"]: " + c[i][j]); } } } } /** * Multiply matrices AxB by dividing into quadrants, using algorithm: * 
 * A x B * * A11 | A12 B11 | B12 A11*B11 | A11*B12 A12*B21 | A12*B22 * |----+----| x |----+----| = |--------+--------| + |---------+-------| * A21 | A22 B21 | B21 A21*B11 | A21*B21 A22*B21 | A22*B22 * 

*/ static class Multiplier extends FJTask { final float[][] A; // Matrix A final int aRow; // first row of current quadrant of A final int aCol; // first column of current quadrant of A final float[][] B; // Similarly for B final int bRow; final int bCol; final float[][] C; // Similarly for result matrix C final int cRow; final int cCol; final int size; // number of elements in current quadrant Multiplier(float[][] A, int aRow, int aCol, float[][] B, int bRow, int bCol, float[][] C, int cRow, int cCol, int size) { this.A = A; this.aRow = aRow; this.aCol = aCol; this.B = B; this.bRow = bRow; this.bCol = bCol; this.C = C; this.cRow = cRow; this.cCol = cCol; this.size = size; } public void run() { if (size <= granularity) { multiplyStride2(); } else { int h = size / 2; coInvoke(new FJTask[] { seq(new Multiplier(A, aRow, aCol, // A11 B, bRow, bCol, // B11 C, cRow, cCol, // C11 h), new Multiplier(A, aRow, aCol+h, // A12 B, bRow+h, bCol, // B21 C, cRow, cCol, // C11 h)), seq(new Multiplier(A, aRow, aCol, // A11 B, bRow, bCol+h, // B12 C, cRow, cCol+h, // C12 h), new Multiplier(A, aRow, aCol+h, // A12 B, bRow+h, bCol+h, // B22 C, cRow, cCol+h, // C12 h)), seq(new Multiplier(A, aRow+h, aCol, // A21 B, bRow, bCol, // B11 C, cRow+h, cCol, // C21 h), new Multiplier(A, aRow+h, aCol+h, // A22 B, bRow+h, bCol, // B21 C, cRow+h, cCol, // C21 h)), seq(new Multiplier(A, aRow+h, aCol, // A21 B, bRow, bCol+h, // B12 C, cRow+h, cCol+h, // C22 h), new Multiplier(A, aRow+h, aCol+h, // A22 B, bRow+h, bCol+h, // B22 C, cRow+h, cCol+h, // C22 h)) }); } } /** * Version of matrix multiplication that steps 2 rows and columns * at a time. Adapted from Cilk demos. * Note that the results are added into C, not just set into C. * This works well here because Java array elements * are created with all zero values. 
**/ void multiplyStride2() { for (int j = 0; j < size; j+=2) { for (int i = 0; i < size; i +=2) { float[] a0 = A[aRow+i]; float[] a1 = A[aRow+i+1]; float s00 = 0.0F; float s01 = 0.0F; float s10 = 0.0F; float s11 = 0.0F; for (int k = 0; k < size; k+=2) { float[] b0 = B[bRow+k]; s00 += a0[aCol+k] * b0[bCol+j]; s10 += a1[aCol+k] * b0[bCol+j]; s01 += a0[aCol+k] * b0[bCol+j+1]; s11 += a1[aCol+k] * b0[bCol+j+1]; float[] b1 = B[bRow+k+1]; s00 += a0[aCol+k+1] * b1[bCol+j]; s10 += a1[aCol+k+1] * b1[bCol+j]; s01 += a0[aCol+k+1] * b1[bCol+j+1]; s11 += a1[aCol+k+1] * b1[bCol+j+1]; } C[cRow+i] [cCol+j] += s00; C[cRow+i] [cCol+j+1] += s01; C[cRow+i+1][cCol+j] += s10; C[cRow+i+1][cCol+j+1] += s11; } } } } }

此代码是为旧版本的fork/join框架编写的,所以我必须改写它。我重写的代码实现了我自己的接口,如下所示:

 public class Java7MatrixMultiply implements Algorithm { private static final int SIZE = 32; private static final int THRESHOLD = 8; private float[][] a = new float[SIZE][SIZE]; private float[][] b = new float[SIZE][SIZE]; private float[][] c = new float[SIZE][SIZE]; ForkJoinPool forkJoinPool; @Override public void initialize() { init(a, b, SIZE); } @Override public void execute() { MatrixMultiplyTask mainTask = new MatrixMultiplyTask(a, 0, 0, b, 0, 0, c, 0, 0, SIZE); forkJoinPool = new ForkJoinPool(); forkJoinPool.invoke(mainTask); System.out.println("Terminated!"); } @Override public void printResult() { check(c, SIZE); for (int i = 0; i < SIZE; i++) { for (int j = 0; j < SIZE; j++) { System.out.print(c[i][j] + " "); } System.out.println(); } } // To simplify checking, fill with all 1's. Answer should be all n's. static void init(float[][] a, float[][] b, int n) { for (int i = 0; i < n; ++i) { for (int j = 0; j < n; ++j) { a[i][j] = 1.0F; b[i][j] = 1.0F; } } } static void check(float[][] c, int n) { for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { if (c[i][j] != n) { //throw new Error("Check Failed at [" + i + "][" + j + "]: " + c[i][j]); System.out.println("Check Failed at [" + i + "][" + j + "]: " + c[i][j]); } } } } private class MatrixMultiplyTask extends RecursiveAction { private final float[][] A; // Matrix A private final int aRow; // first row of current quadrant of A private final int aCol; // first column of current quadrant of A private final float[][] B; // Similarly for B private final int bRow; private final int bCol; private final float[][] C; // Similarly for result matrix C private final int cRow; private final int cCol; private final int size; MatrixMultiplyTask(float[][] A, int aRow, int aCol, float[][] B, int bRow, int bCol, float[][] C, int cRow, int cCol, int size) { this.A = A; this.aRow = aRow; this.aCol = aCol; this.B = B; this.bRow = bRow; this.bCol = bCol; this.C = C; this.cRow = cRow; this.cCol = cCol; this.size = size; } 
@Override protected void compute() { if (size <= THRESHOLD) { multiplyStride2(); } else { int h = size / 2; invokeAll(new MatrixMultiplyTask[] { new MatrixMultiplyTask(A, aRow, aCol, // A11 B, bRow, bCol, // B11 C, cRow, cCol, // C11 h), new MatrixMultiplyTask(A, aRow, aCol + h, // A12 B, bRow + h, bCol, // B21 C, cRow, cCol, // C11 h), new MatrixMultiplyTask(A, aRow, aCol, // A11 B, bRow, bCol + h, // B12 C, cRow, cCol + h, // C12 h), new MatrixMultiplyTask(A, aRow, aCol + h, // A12 B, bRow + h, bCol + h, // B22 C, cRow, cCol + h, // C12 h), new MatrixMultiplyTask(A, aRow + h, aCol, // A21 B, bRow, bCol, // B11 C, cRow + h, cCol, // C21 h), new MatrixMultiplyTask(A, aRow + h, aCol + h, // A22 B, bRow + h, bCol, // B21 C, cRow + h, cCol, // C21 h), new MatrixMultiplyTask(A, aRow + h, aCol, // A21 B, bRow, bCol + h, // B12 C, cRow + h, cCol + h, // C22 h), new MatrixMultiplyTask(A, aRow + h, aCol + h, // A22 B, bRow + h, bCol + h, // B22 C, cRow + h, cCol + h, // C22 h) }); } } /** * Version of matrix multiplication that steps 2 rows and columns at a * time. Adapted from Cilk demos. Note that the results are added into * C, not just set into C. This works well here because Java array * elements are created with all zero values. 
**/ void multiplyStride2() { for (int j = 0; j < size; j += 2) { for (int i = 0; i < size; i += 2) { float[] a0 = A[aRow + i]; float[] a1 = A[aRow + i + 1]; float s00 = 0.0F; float s01 = 0.0F; float s10 = 0.0F; float s11 = 0.0F; for (int k = 0; k < size; k += 2) { float[] b0 = B[bRow + k]; s00 += a0[aCol + k] * b0[bCol + j]; s10 += a1[aCol + k] * b0[bCol + j]; s01 += a0[aCol + k] * b0[bCol + j + 1]; s11 += a1[aCol + k] * b0[bCol + j + 1]; float[] b1 = B[bRow + k + 1]; s00 += a0[aCol + k + 1] * b1[bCol + j]; s10 += a1[aCol + k + 1] * b1[bCol + j]; s01 += a0[aCol + k + 1] * b1[bCol + j + 1]; s11 += a1[aCol + k + 1] * b1[bCol + j + 1]; } C[cRow + i][cCol + j] += s00; C[cRow + i][cCol + j + 1] += s01; C[cRow + i + 1][cCol + j] += s10; C[cRow + i + 1][cCol + j + 1] += s11; } } } } } 

有时我的计算结果无法通过检查:矩阵中的某些元素与预期值不同。这些不一致是随机出现的,并不总是发生。我怀疑是计算方法出了问题,因为我不得不重写使用Seq类的那部分代码。与invokeAll()方法不同,Seq类按顺序执行任务;而在当前版本的fork/join框架中,该类已不存在。我对矩阵乘法算法不太熟悉,所以很难看出哪里出了问题。有什么建议吗?

正如您已经注意到的那样,对于此算法来说,按顺序执行写入同一象限的子任务非常重要。因此,您需要实现自己的seq()函数(例如下面这样),并像原始代码那样使用它:

 /**
  * Combines two tasks into one task that runs them strictly one after the other.
  *
  * <p>Neither task is started here; when the returned task executes, it invokes {@code a},
  * waits for it to complete, and only then invokes {@code b}.
  *
  * @param a task to run first
  * @param b task to run after {@code a} has completed
  * @return a task that performs {@code a} followed by {@code b}
  */
 public ForkJoinTask<?> seq(final ForkJoinTask<?> a, final ForkJoinTask<?> b) {
   // Qualified call so the helper also compiles outside of ForkJoinTask subclasses.
   return ForkJoinTask.adapt(new Runnable() {
     @Override
     public void run() {
       a.invoke();
       b.invoke();
     }
   });
 }

你通过C[cRow + i][cCol + j] += s00;等语句把结果累加到C中。这不是线程安全的操作,因此您必须对相应的行进行同步,或者确保同一时刻只有一个任务更新某个单元格。否则,你会发现随机的单元格被设置了不正确的值。

我会先检查在并发度为1时你能否得到正确的答案。

顺便说一句:float可能不是这里的最佳选择。它的精度位数相当低,而在繁重的矩阵运算中(我假设你正在做这样的运算,否则使用多个线程没有多大意义),舍入误差可能会消耗掉大部分甚至全部精度。我建议考虑使用double。

例如,float大约有7位数的精度,一条经验法则是误差与计算次数成正比。因此,对于1K x 1K矩阵,您可能只剩4位精度;对于10K x 10K,您可能只剩3位。double有16位精度,这意味着在10K x 10K的运算之后,您可能仍有12位精度。