![](/img/trans.png)
[英]Vector-Matrix-Multiplication in java with parallel colt
[英]Java 8 matrix * vector multiplication
我想知道在Java 8中使用流來執行以下操作是否有更簡潔的方法:
public static double[] multiply(double[][] matrix, double[] vector) {
int rows = matrix.length;
int columns = matrix[0].length;
double[] result = new double[rows];
for (int row = 0; row < rows; row++) {
double sum = 0;
for (int column = 0; column < columns; column++) {
sum += matrix[row][column]
* vector[column];
}
result[row] = sum;
}
return result;
}
編輯。 我收到了一個非常好的答案,但是性能比舊的實現慢了大約10倍,所以我在這里添加測試代碼以防有人想要調查它:
@Test
public void profile() {
long start;
long stop;
int tenmillion = 10000000;
double[] vector = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 };
double[][] matrix = new double[tenmillion][10];
for (int i = 0; i < tenmillion; i++) {
matrix[i] = vector.clone();
}
start = System.currentTimeMillis();
multiply(matrix, vector);
stop = System.currentTimeMillis();
}
使用Stream的直接方法如下:
public static double[] multiply(double[][] matrix, double[] vector) {
return Arrays.stream(matrix)
.mapToDouble(row ->
IntStream.range(0, row.length)
.mapToDouble(col -> row[col] * vector[col])
.sum()
).toArray();
}
這將創建矩陣的每一行的Stream<double[]>
( Stream<double[]>
),然后將每一行映射到使用vector
數組計算乘積的double值。
我們必須在索引上使用Stream來計算產品,因為遺憾的是沒有內置工具可以將兩個Streot壓縮在一起。
衡量性能的方式對於衡量性能並不是非常可靠,手動編寫微基准通常不是一個好主意。 例如,在編譯代碼時,JVM可能會選擇更改執行順序,並且可能無法將啟動和停止變量分配到您希望分配的位置,從而在測量中產生意外結果。 預熱JVM以及讓JIT編譯器進行所有優化也非常重要。 GC還可以在引入應用程序吞吐量和響應時間的變化方面發揮重要作用。 我強烈建議使用JMH和Caliper等專用工具進行微基准測試。
我還用JVM預熱,隨機數據集和更多迭代次數編寫了一些基准測試代碼。 事實證明,Java 8流提供了更好的結果。
/**
*
*/
public class MatrixMultiplicationBenchmark {
private static AtomicLong start = new AtomicLong();
private static AtomicLong stop = new AtomicLong();
private static Random random = new Random();
/**
* Main method that warms-up each implementation and then runs the benchmark.
*
* @param args main class args
*/
public static void main(String[] args) {
// Warming up with more iterations and smaller data set
System.out.println("Warming up...");
IntStream.range(0, 10_000_000).forEach(i -> run(10, MatrixMultiplicationBenchmark::multiplyWithStreams));
IntStream.range(0, 10_000_000).forEach(i -> run(10, MatrixMultiplicationBenchmark::multiplyWithForLoops));
// Running with less iterations and larger data set
startWatch("Running MatrixMultiplicationBenchmark::multiplyWithForLoops...");
IntStream.range(0, 10).forEach(i -> run(10_000_000, MatrixMultiplicationBenchmark::multiplyWithForLoops));
endWatch("MatrixMultiplicationBenchmark::multiplyWithForLoops");
startWatch("Running MatrixMultiplicationBenchmark::multiplyWithStreams...");
IntStream.range(0, 10).forEach(i -> run(10_000_000, MatrixMultiplicationBenchmark::multiplyWithStreams));
endWatch("MatrixMultiplicationBenchmark::multiplyWithStreams");
}
/**
* Creates the random matrix and vector and applies them in the given implementation as BiFunction object.
*
* @param multiplyImpl implementation to use.
*/
public static void run(int size, BiFunction<double[][], double[], double[]> multiplyImpl) {
// creating random matrix and vector
double[][] matrix = new double[size][10];
double[] vector = random.doubles(10, 0.0, 10.0).toArray();
IntStream.range(0, size).forEach(i -> matrix[i] = random.doubles(10, 0.0, 10.0).toArray());
// applying matrix and vector to the given implementation. Returned value should not be ignored in test cases.
double[] result = multiplyImpl.apply(matrix, vector);
}
/**
* Multiplies the given vector and matrix using Java 8 streams.
*
* @param matrix the matrix
* @param vector the vector to multiply
*
* @return result after multiplication.
*/
public static double[] multiplyWithStreams(final double[][] matrix, final double[] vector) {
final int rows = matrix.length;
final int columns = matrix[0].length;
return IntStream.range(0, rows)
.mapToDouble(row -> IntStream.range(0, columns)
.mapToDouble(col -> matrix[row][col] * vector[col])
.sum()).toArray();
}
/**
* Multiplies the given vector and matrix using vanilla for loops.
*
* @param matrix the matrix
* @param vector the vector to multiply
*
* @return result after multiplication.
*/
public static double[] multiplyWithForLoops(double[][] matrix, double[] vector) {
int rows = matrix.length;
int columns = matrix[0].length;
double[] result = new double[rows];
for (int row = 0; row < rows; row++) {
double sum = 0;
for (int column = 0; column < columns; column++) {
sum += matrix[row][column] * vector[column];
}
result[row] = sum;
}
return result;
}
private static void startWatch(String label) {
System.out.println(label);
start.set(System.currentTimeMillis());
}
private static void endWatch(String label) {
stop.set(System.currentTimeMillis());
System.out.println(label + " took " + ((stop.longValue() - start.longValue()) / 1000) + "s");
}
}
這是輸出
Warming up...
Running MatrixMultiplicationBenchmark::multiplyWithForLoops...
MatrixMultiplicationBenchmark::multiplyWithForLoops took 100s
Running MatrixMultiplicationBenchmark::multiplyWithStreams...
MatrixMultiplicationBenchmark::multiplyWithStreams took 89s
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.