Multi-threaded matrix multiplication

I've coded a multi-threaded matrix multiplication. I believe my approach is right, but I'm not 100% sure. With respect to the threads, I don't understand why I can't just run (new MatrixThread(...)).start() instead of using an ExecutorService.

Additionally, when I benchmark the multithreaded approach versus the classical approach, the classical is much faster...

What am I doing wrong?

Matrix Class:

import java.util.*;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

class Matrix
{
   private int dimension;
   private int[][] template;

   public Matrix(int dimension)
   {
      this.template = new int[dimension][dimension];
      this.dimension = template.length;
   }

   public Matrix(int[][] array) 
   {
      this.dimension = array.length;
      this.template = array;      
   }

   public int getMatrixDimension() { return this.dimension; }

   public int[][] getArray() { return this.template; }

   public void fillMatrix()
   {
      Random randomNumber = new Random();
      for(int i = 0; i < dimension; i++)
      {
         for(int j = 0; j < dimension; j++)
         {
            template[i][j] = randomNumber.nextInt(10) + 1;
         }
      }
   }

   @Override
   public String toString()
   {
      String retString = "";
      for(int i = 0; i < this.getMatrixDimension(); i++)
      {
         for(int j = 0; j < this.getMatrixDimension(); j++)
         {
            retString += " " + this.getArray()[i][j];
         }
         retString += "\n";
      }
      return retString;
   }

   public static Matrix classicalMultiplication(Matrix a, Matrix b)
   {      
      int[][] result = new int[a.dimension][b.dimension];
      for(int i = 0; i < a.dimension; i++)
      {
         for(int j = 0; j < b.dimension; j++)
         {
            for(int k = 0; k < b.dimension; k++)
            {
               result[i][j] += a.template[i][k] * b.template[k][j];
            }
         }
      }
      return new Matrix(result);
   }

   public Matrix multiply(Matrix multiplier) throws InterruptedException
   {
      Matrix result = new Matrix(dimension);
      ExecutorService es = Executors.newFixedThreadPool(dimension*dimension);
      for(int currRow = 0; currRow < multiplier.dimension; currRow++)
      {
         for(int currCol = 0; currCol < multiplier.dimension; currCol++)
         {            
            //(new MatrixThread(this, multiplier, currRow, currCol, result)).start();            
            es.execute(new MatrixThread(this, multiplier, currRow, currCol, result));
         }
      }
      es.shutdown();
      es.awaitTermination(2, TimeUnit.DAYS);
      return result;
   }

   private class MatrixThread extends Thread
   {
      private Matrix a, b, result;
      private int row, col;      

      private MatrixThread(Matrix a, Matrix b, int row, int col, Matrix result)
      {         
         this.a = a;
         this.b = b;
         this.row = row;
         this.col = col;
         this.result = result;
      }

      @Override
      public void run()
      {
         int cellResult = 0;
         for (int i = 0; i < a.getMatrixDimension(); i++)
            cellResult += a.template[row][i] * b.template[i][col];

         result.template[row][col] = cellResult;
      }
   }
} 

Main class:

import java.util.Scanner;

public class MatrixDriver
{
   private static final Scanner kb = new Scanner(System.in);

   public static void main(String[] args) throws InterruptedException
   {      
      Matrix first, second;
      long timeLastChanged,timeNow;
      double elapsedTime;

      System.out.print("Enter value of n (must be a power of 2):");
      int n = kb.nextInt();

      first = new Matrix(n);
      first.fillMatrix();      
      second = new Matrix(n);
      second.fillMatrix();

      timeLastChanged = System.currentTimeMillis();
      //System.out.println("Product of the two using threads:\n" +
                                                        first.multiply(second);
      timeNow = System.currentTimeMillis();
      elapsedTime = (timeNow - timeLastChanged)/1000.0;
      System.out.println("Threaded took "+elapsedTime+" seconds");

      timeLastChanged = System.currentTimeMillis();
      //System.out.println("Product of the two using classical:\n" +
                                  Matrix.classicalMultiplication(first,second);
      timeNow = System.currentTimeMillis();
      elapsedTime = (timeNow - timeLastChanged)/1000.0;
      System.out.println("Classical took "+elapsedTime+" seconds");
   }
} 

PS Please let me know if any further clarification is needed.

There is a bunch of overhead involved in creating threads, even when using an ExecutorService. I suspect the reason your multithreaded approach is so slow is that you're spending 99% of the time creating new threads and only 1%, or less, doing the actual math.

Typically, to solve this problem you'd batch a whole bunch of operations together and run each batch on a single thread. I'm not 100% sure how to do that in this case, but I suggest breaking your matrix into smaller chunks (say, 10 smaller matrices) and running those on threads, instead of running each cell in its own thread.
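One way the chunking suggestion could look is sketched below. It is only an illustration, not the answerer's code: the class name BandedMultiply, the chunks parameter, and the choice to split by contiguous bands of rows are all assumptions, and it works on plain int[][] arrays rather than the question's Matrix class. Each task computes a whole band of result rows, so the number of tasks is at most chunks rather than n².

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

// Hypothetical helper: multiplies square matrices, one task per band of rows.
class BandedMultiply {
    static int[][] multiply(int[][] a, int[][] b, int chunks)
            throws InterruptedException, ExecutionException {
        int n = a.length;
        int[][] result = new int[n][n];
        int bands = Math.max(1, Math.min(chunks, n));
        int bandSize = (n + bands - 1) / bands; // ceiling division
        ExecutorService pool = Executors.newFixedThreadPool(
                Math.min(bands, Runtime.getRuntime().availableProcessors()));
        try {
            List<Callable<Void>> tasks = new ArrayList<>();
            for (int start = 0; start < n; start += bandSize) {
                final int from = start;
                final int to = Math.min(n, start + bandSize);
                // Each task writes only rows [from, to), so tasks never touch the same cell.
                tasks.add(() -> {
                    for (int i = from; i < to; i++) {
                        for (int j = 0; j < n; j++) {
                            int sum = 0;
                            for (int k = 0; k < n; k++) sum += a[i][k] * b[k][j];
                            result[i][j] = sum;
                        }
                    }
                    return null;
                });
            }
            // invokeAll blocks until every task is done; get() rethrows any task failure.
            for (Future<Void> f : pool.invokeAll(tasks)) f.get();
        } finally {
            pool.shutdown();
        }
        return result;
    }

    public static void main(String[] args) throws Exception {
        int[][] a = {{1, 2}, {3, 4}};
        int[][] b = {{5, 6}, {7, 8}};
        System.out.println(Arrays.deepToString(multiply(a, b, 10))); // prints [[19, 22], [43, 50]]
    }
}
```

This sketch creates and shuts down a pool on every call to keep the example self-contained; a long-lived shared pool would avoid that cost.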

You're creating a lot of threads. Not only is it expensive to create threads, but for a CPU-bound application you don't want more threads than you have available processors. If you do, you spend processing power switching between threads, and the switching is also likely to cause cache misses, which are very expensive.

It's also unnecessary to pass a Thread to execute; all it needs is a Runnable. You'll get a big performance boost by applying these changes:

  1. Make the ExecutorService a static member, size it for the current processor, and send it a ThreadFactory so it doesn't keep the program running after main has finished. (It would probably be architecturally cleaner to send it as a parameter to the method rather than keeping it as a static field; I leave that as an exercise for the reader. ☺)

     // needs: import java.util.concurrent.ThreadFactory;
     private static final ExecutorService workerPool = Executors.newFixedThreadPool(
           Runtime.getRuntime().availableProcessors(),
           new ThreadFactory() {
              public Thread newThread(Runnable r) {
                 Thread t = new Thread(r);
                 t.setDaemon(true); // daemon threads don't keep the JVM alive after main returns
                 return t;
              }
           });
  2. Make MatrixThread implement Runnable rather than extend Thread. Threads are expensive to create; POJOs are very cheap. You can also make it static, which makes the instances smaller (as non-static inner classes get an implicit reference to the enclosing object).

     private static class MatrixThread implements Runnable 
  3. From change (1), you can no longer use awaitTermination to make sure all tasks are finished (as this worker pool is never shut down). Instead, use the submit method, which returns a Future<?>. Collect all the Future objects in a list, and when you've submitted all the tasks, iterate over the list and call get on each one.

Your multiply method should now look something like this:

public Matrix multiply(Matrix multiplier) throws InterruptedException {
    Matrix result = new Matrix(dimension);
    List<Future<?>> futures = new ArrayList<Future<?>>();
    for(int currRow = 0; currRow < multiplier.dimension; currRow++) {
        for(int currCol = 0; currCol < multiplier.dimension; currCol++) {            
            Runnable worker = new MatrixThread(this, multiplier, currRow, currCol, result);
            futures.add(workerPool.submit(worker));
        }
    }
    for (Future<?> f : futures) {
        try {
            f.get();
        } catch (ExecutionException e){
            throw new RuntimeException(e); // shouldn't happen, but might do
        }
    }
    return result;
}

Will it be faster than the single-threaded version? Well, on my arguably crappy box the multithreaded version is slower for values of n < 1024.
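Comparisons like this are sensitive to how the timing is done: a single System.currentTimeMillis() measurement on the first run also includes JIT warm-up. A minimal sketch of a slightly fairer harness follows. The names TimingSketch, bestMillis, classical and random are made up for illustration, and the warm-up and run counts are arbitrary assumptions, not tuned values.

```java
import java.util.Random;
import java.util.concurrent.Callable;

// Hypothetical timing helper: warm up first, then report the best of several timed runs.
class TimingSketch {
    // Storing results here makes it harder for the JIT to discard the work as unused.
    static volatile Object sink;

    static double bestMillis(Callable<?> task, int warmups, int runs) throws Exception {
        if (runs < 1) throw new IllegalArgumentException("runs must be >= 1");
        for (int i = 0; i < warmups; i++) sink = task.call(); // untimed warm-up runs
        long best = Long.MAX_VALUE;
        for (int i = 0; i < runs; i++) {
            long t0 = System.nanoTime();
            sink = task.call();
            best = Math.min(best, System.nanoTime() - t0);
        }
        return best / 1_000_000.0; // nanoseconds to milliseconds
    }

    // Plain single-threaded triple loop on square arrays, used here as the workload.
    static int[][] classical(int[][] a, int[][] b) {
        int n = a.length;
        int[][] result = new int[n][n];
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                for (int k = 0; k < n; k++)
                    result[i][j] += a[i][k] * b[k][j];
        return result;
    }

    // Fills an n x n array with values 1..10, like fillMatrix in the question.
    static int[][] random(int n, long seed) {
        Random rnd = new Random(seed);
        int[][] m = new int[n][n];
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                m[i][j] = rnd.nextInt(10) + 1;
        return m;
    }

    public static void main(String[] args) throws Exception {
        int[][] a = random(256, 1), b = random(256, 2);
        double ms = bestMillis(() -> classical(a, b), 3, 5);
        System.out.println("Classical, best of 5: " + ms + " ms");
    }
}
```

With the question's classes, the same helper could wrap `() -> first.multiply(second)` and `() -> Matrix.classicalMultiplication(first, second)` so that both versions are measured the same way.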

This is just scratching the surface, though. The real problem is that you create a lot of MatrixThread instances: your memory consumption is O(n²), which is a very bad sign. Moving the inner for loop (the one over columns) into MatrixThread.run would improve performance by a factor of craploads (ideally, you don't create more tasks than you have worker threads).


Edit: As I have more pressing things to do, I couldn't resist optimizing this further. I came up with this (... horrendously ugly piece of code) that "only" creates O(n) jobs:

 public Matrix multiply(Matrix multiplier) throws InterruptedException {
     Matrix result = new Matrix(dimension);
     List<Future<?>> futures = new ArrayList<Future<?>>();
     for(int currRow = 0; currRow < multiplier.dimension; currRow++) {
         Runnable worker = new MatrixThread2(this, multiplier, currRow, result);
         futures.add(workerPool.submit(worker)); 
     }
     for (Future<?> f : futures) {
         try {
             f.get();
         } catch (ExecutionException e){
             throw new RuntimeException(e); // shouldn't happen, but might do
         }
     }
     return result;
 }


private static class MatrixThread2 implements Runnable
{
   private Matrix self, mul, result;
   private int row; // the column is a loop variable in run(), so no col field is needed

   private MatrixThread2(Matrix a, Matrix b, int row, Matrix result)
   {         
      this.self = a;
      this.mul = b;
      this.row = row;
      this.result = result;
   }

   @Override
   public void run()
   {
      for(int col = 0; col < mul.dimension; col++) {
         int cellResult = 0;
         for (int i = 0; i < self.getMatrixDimension(); i++)
            cellResult += self.template[row][i] * mul.template[i][col];
         result.template[row][col] = cellResult;
      }
   }
}

It's still not great, but basically the multi-threaded version can compute anything you'll be patient enough to wait for, and it'll do it faster than the single-threaded version.

First of all, you should size the newFixedThreadPool to the number of cores you have; on a quad-core, use 4. Second, don't create a new pool for each matrix.

If I make the ExecutorService a static member variable, I get almost consistently faster execution of the threaded version at a matrix size of 512.

Also, changing MatrixThread to implement Runnable instead of extending Thread speeds up execution further, to the point where the threaded version is about 2x as fast on my machine at size 512.
