Fast conversion of Java array to NumPy array (Py4J)

There are some nice examples of how to convert a NumPy array to a Java array, but not of the reverse: how to convert data from a Java object back into a NumPy array. I have a Python script like this:

    import numpy as np
    from py4j.java_gateway import JavaGateway

    gateway = JavaGateway()             # connect to the JVM
    my_java = gateway.jvm.JavaClass()   # my Java object
    ....
    int_array = my_java.doSomething(int_array)  # do something

    # Copy the Java array into a NumPy array element by element (slow)
    my_numpy = np.zeros((size_y, size_x))
    for jj in range(size_y):
        for ii in range(size_x):
            my_numpy[jj, ii] = int_array[jj][ii]

my_numpy is the NumPy array and int_array is a Java array of integers (an int[][]). The Java arrays are created in the Python script as:

    int_class=gateway.jvm.int       # make int class
    double_class=gateway.jvm.double # make double class

    int_array = gateway.new_array(int_class,size_y,size_x)
    double_array = gateway.new_array(double_class,size_y,size_x)

Although this works, it is very slow: for a ~1000x1000 array, the conversion took more than 5 minutes.

Is there a way to do this in a reasonable time?

If I try:

    test=np.array(int_array)

I get:

    ValueError: invalid __array_struct__

I had a similar problem and found a solution that was around 220 times faster in the case I tested: transferring a 1628x120 array of short integers from Java to NumPy went from 11 seconds to 0.05 seconds. Thanks to a related StackOverflow question, I started looking into py4j byte arrays, and it turns out that py4j efficiently converts Java byte arrays to Python bytes objects and vice versa (passing by value, not by reference). It's a fairly roundabout way of doing things, but not too difficult.

Thus, if you want to transfer an integer array intArray with dimensions iMax x jMax (and for the sake of the example, I assume that these are all stored as instance variables in your object), you can first write a Java function to convert it to a byte[] like so:

    // Needs: import java.nio.ByteBuffer; import java.nio.ByteOrder;
    public byte[] getByteArray() {
        // Allocate a buffer large enough for every element (4 bytes per int)
        ByteBuffer intBuffer = ByteBuffer.allocate(4 * iMax * jMax);
        intBuffer.order(ByteOrder.LITTLE_ENDIAN); // Java's default is big-endian

        // Copy ints from intArray into intBuffer as bytes, row by row
        for (int i = 0; i < iMax; i++) {
            for (int j = 0; j < jMax; j++) {
                intBuffer.putInt(intArray[i][j]);
            }
        }

        // Return the byte array that backs the buffer
        return intBuffer.array();
    }

Then, you can write Python 3 code to receive the byte array and convert it to a numpy array of the correct shape:

    import numpy as np

    byteArray = gateway.entry_point.getByteArray()    # arrives as a Python bytes object
    intArray = np.frombuffer(byteArray, dtype="<i4")  # little-endian int32, as written by Java
    intArray = intArray.reshape((iMax, jMax))
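
The decoding step can be tried without a JVM. The following is only a sketch: `decode_int_matrix` is a hypothetical helper name, and `struct.pack` stands in for the bytes that the Java side would send. It also shows that big-endian data (Java's default order) or 16-bit shorts can be read by changing the dtype string, and that a copy is needed if the result must be writable, because `np.frombuffer` on a bytes object returns a read-only array.

```python
import struct

import numpy as np


def decode_int_matrix(raw, rows, cols, dtype="<i4"):
    """Interpret a flat bytes object as a (rows, cols) integer array.

    Hypothetical helper, not part of py4j or NumPy.
    """
    flat = np.frombuffer(raw, dtype=dtype)   # read-only view over the bytes
    if flat.size != rows * cols:
        raise ValueError("byte length does not match the requested shape")
    return flat.reshape((rows, cols)).copy()  # copy to get a writable array


values = [1, -2, 300, 70000, -5, 6]

# Stand-ins for what the Java side would send
little_endian = struct.pack("<6i", *values)   # buffer order set to LITTLE_ENDIAN
big_endian = struct.pack(">6i", *values)      # Java's default byte order
shorts = struct.pack("<4h", 7, -8, 9, -10)    # 16-bit values, 2 bytes each

a = decode_int_matrix(little_endian, 2, 3)
b = decode_int_matrix(big_endian, 2, 3, dtype=">i4")
s = decode_int_matrix(shorts, 2, 2, dtype="<i2")
```

Spelling out the byte order in the dtype (`"<i4"` or `">i4"`) keeps the result independent of the byte order of the machine running Python.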

I had a similar issue when trying to plot spectral vectors (Java arrays) that I got from the Java side via py4j. Here, each Java array is converted to a Python list with the list() function. This might give some clues as to how to use it to fill NumPy arrays:

    import matplotlib.pyplot as mp

    vectors = space.getVectorsAsArray()   # Java array (MxN)
    wvl = space.getAverageWavelengths()   # Java array (N)

    wavelengths = list(wvl)

    # One curve per row; successive plot() calls draw on the same axes
    for dataset in vectors:
        mp.plot(wavelengths, list(dataset))

Whether this is faster than the nested for loops you used I cannot say, but it also does the trick:

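    import numpy
    from numpy import array

    x = array(wavelengths)
    v = array([list(row) for row in vectors])  # convert each Java row to a list; shape MxN

    # Transpose to NxM so each column is one curve (rot90 would also reverse each curve)
    mp.plot(x, v.T)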
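
A small sketch of how the orientation of the 2-D array matters when plotting it against a shared x axis. The values are made up for illustration. Both the transpose and `numpy.rot90` turn an MxN array into an NxM one, but `rot90` also reverses the order of the samples within each curve, so the transpose is the safer choice here.

```python
import numpy as np

# Two curves (M = 2) with three samples each (N = 3); made-up values
v = np.array([[1, 2, 3],
              [10, 20, 30]])

transposed = v.T       # shape (3, 2): column j is row j of v, in the same order
rotated = np.rot90(v)  # shape (3, 2): column j is row j of v, reversed

first_curve_transposed = transposed[:, 0].tolist()
first_curve_rotated = rotated[:, 0].tolist()
```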
