Fast conversion of Java array to NumPy array (Py4J)
There are some nice examples of how to convert a NumPy array to a Java array, but not the other way round: how to convert data from a Java object back to a NumPy array. I have a Python script like this:
from py4j.java_gateway import JavaGateway
import numpy as np
gateway = JavaGateway()  # connect to the JVM
my_java = gateway.jvm.JavaClass()  # my Java object
# ....
int_array = my_java.doSomething(int_array)  # do something
my_numpy = np.zeros((size_y, size_x))
# copy element by element; each int_array[jj][ii] lookup is a separate call into the JVM
for jj in range(size_y):
    for ii in range(size_x):
        my_numpy[jj, ii] = int_array[jj][ii]
my_numpy is the NumPy array, and int_array is the Java array of integers (an int[][] array). It is initialized in the Python script as:
int_class=gateway.jvm.int # make int class
double_class=gateway.jvm.double # make double class
int_array = gateway.new_array(int_class,size_y,size_x)
double_array = gateway.new_array(double_class,size_y,size_x)
Although it works, it is not the fastest way and is rather slow: for a ~1000x1000 array, the conversion took more than 5 minutes.
Is there any way to do this in a reasonable time?
If I try:
test=np.array(int_array)
I get:
ValueError: invalid __array_struct__
I had a similar problem and found a solution that was around 220 times faster in the case I tested. Transferring a 1628x120 array of short integers from Java to NumPy took 11 seconds before and 0.05 seconds after. Thanks to this related StackOverflow question, I started looking into py4j byte arrays. It turns out that py4j efficiently converts Java byte arrays to Python bytes objects and vice versa, passing them by value rather than by reference. It's a fairly roundabout way of doing things, but not too difficult.
Suppose you want to transfer an integer array intArray with dimensions iMax x jMax. For the sake of the example, I assume these are all stored as instance variables in your object. You can first write a Java method that converts the array to a byte[], like so:
// Requires: import java.nio.ByteBuffer; import java.nio.ByteOrder;
public byte[] getByteArray() {
    // Set up a ByteBuffer called intBuffer
    ByteBuffer intBuffer = ByteBuffer.allocate(4*iMax*jMax); // 4 bytes in an int
    intBuffer.order(ByteOrder.LITTLE_ENDIAN); // ByteBuffer's default order is big-endian
    // Copy ints from intArray into intBuffer as bytes, row by row
    for (int i = 0; i < iMax; i++) {
        for (int j = 0; j < jMax; j++) {
            intBuffer.putInt(intArray[i][j]);
        }
    }
    // Convert the ByteBuffer to a byte array and return it
    byte[] byteArray = intBuffer.array();
    return byteArray;
}
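The intBuffer.order(ByteOrder.LITTLE_ENDIAN) call matters because the Python side has to decode the bytes in the same order they were written. The sketch below uses Python's struct module as a stand-in for ByteBuffer.putInt, so no JVM is involved. It shows what happens when big-endian bytes (ByteBuffer's default) are decoded with a matching and with a mismatched dtype.

```python
import struct

import numpy as np

values = [1, 2, 300, -5]

# Stand-ins for what ByteBuffer.putInt would produce:
# '>' = big-endian (ByteBuffer's default), '<' = little-endian (after .order(LITTLE_ENDIAN))
big_endian_bytes = struct.pack(">4i", *values)
little_endian_bytes = struct.pack("<4i", *values)

# The dtype's byte-order prefix must match the order the bytes were written in
print(np.frombuffer(little_endian_bytes, dtype="<i4").tolist())  # [1, 2, 300, -5]
print(np.frombuffer(big_endian_bytes, dtype=">i4").tolist())     # [1, 2, 300, -5]

# A mismatch raises no error; it silently yields wrong numbers
print(np.frombuffer(big_endian_bytes, dtype="<i4").tolist())
# [16777216, 33554432, 738263040, -67108865]
```

If you prefer to leave the Java buffer in its default big-endian order, decoding with dtype=">i4" on the Python side works just as well.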
Then you can write Python 3 code that receives the byte array and converts it to a NumPy array of the correct shape:
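import numpy as np
byteArray = gateway.entry_point.getByteArray()    # arrives as a Python bytes object
intArray = np.frombuffer(byteArray, dtype="<i4")  # little-endian int32, matching ByteOrder.LITTLE_ENDIAN
intArray = intArray.reshape((iMax, jMax))         # iMax and jMax must also be known on the Python side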
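The same approach should carry over to the other primitive types in this thread. The 1628x120 test above used short, and the question also creates double arrays. On the Java side you would call putShort or putDouble instead of putInt and size the buffer accordingly. On the Python side, only the dtype changes.

Below is a minimal sketch of a hypothetical helper, from_java_bytes, with an assumed mapping from Java primitive names to little-endian NumPy dtypes. The bytes are built with struct instead of coming from a real py4j call.

```python
import struct

import numpy as np

# Hypothetical mapping: Java primitive -> little-endian NumPy dtype of the same size
JAVA_TO_NUMPY = {
    "byte": "i1",
    "short": "<i2",
    "int": "<i4",
    "long": "<i8",
    "float": "<f4",
    "double": "<f8",
}


def from_java_bytes(buf, java_type, shape):
    """Decode a little-endian, row-major byte[] (received as Python bytes) into an array."""
    dtype = np.dtype(JAVA_TO_NUMPY[java_type])
    expected = dtype.itemsize * int(np.prod(shape))
    if len(buf) != expected:
        raise ValueError(f"expected {expected} bytes for {shape} {java_type}, got {len(buf)}")
    # frombuffer shares memory with the immutable bytes object, so its result is read-only;
    # .copy() returns an ordinary writable array
    return np.frombuffer(buf, dtype=dtype).reshape(shape).copy()


# Simulate a 2x3 short[][] written row by row (outer loop over rows, as in getByteArray)
rows = [[1, -2, 3], [400, 500, -600]]
payload = struct.pack("<6h", *[v for row in rows for v in row])
shorts = from_java_bytes(payload, "short", (2, 3))
print(shorts.dtype, shorts.shape)  # int16 (2, 3)
```

The Java loop writes row by row, so NumPy's default C-order reshape lines up with intArray[i][j]. If the Java side iterated over columns first, you would reshape to (jMax, iMax) and transpose instead.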
I had a similar issue while trying to plot spectral vectors (Java arrays) that I got from the Java side via py4j. Here, the conversion from the Java array to a Python list is done with the list() function. This might give some clues on how to use it to fill NumPy arrays.
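import matplotlib.pyplot as mp
vectors = space.getVectorsAsArray()  # Java array (MxN)
wvl = space.getAverageWavelengths()  # Java array (N)
wavelengths = list(wvl)              # copy the Java array into a Python list
# no mp.hold needed: matplotlib overlays plots on the same axes by default (hold() was removed in matplotlib 3.0)
for dataset in vectors:
    mp.plot(wavelengths, list(dataset))
mp.show()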
I can't say whether this is faster than your nested for loops, but it also does the trick:
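import numpy as np
x = np.array(wavelengths)
# convert each Java row to a list first; passing py4j JavaArray objects straight to np.array
# can fail with "ValueError: invalid __array_struct__", as in the question
v = np.array([list(row) for row in vectors])  # shape (M, N)
# plot(x, y) draws one line per column of y, so transpose to (N, M);
# np.rot90 would give the same shape but reverse the wavelength order
mp.plot(x, v.T)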
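On the final plotting step: matplotlib's plot(x, y) with a 2-D y draws one line per column, so the (M, N) array of spectra needs to become (N, M). The small check below uses made-up numbers. It shows that np.rot90 produces that shape but also reverses the order along the wavelength axis, whereas a plain transpose keeps each value aligned with its wavelength.

```python
import numpy as np

# Made-up data: M=2 spectra sampled at N=3 wavelengths
v = np.array([[10, 11, 12],
              [20, 21, 22]])

rotated = np.rot90(v)  # shape (3, 2), but the wavelength axis comes out reversed
transposed = v.T       # shape (3, 2), row k still belongs to wavelength k

print(rotated[:, 0].tolist())     # [12, 11, 10]
print(transposed[:, 0].tolist())  # [10, 11, 12]
```

Plotting x against the rotated array would therefore pair the first wavelength with the last sample of each spectrum.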