簡體   English   中英

將Java數組快速轉換為NumPy數組(Py4J)

[英]Fast conversion of Java array to NumPy array (Py4J)

有一些很好的例子如何將NumPy數組轉換為Java數組,但反之亦然 - 如何將數據從Java對象轉換回NumPy數組。 我有一個像這樣的Python腳本:

    from py4j.java_gateway import JavaGateway
    gateway = JavaGateway()            # connect to the JVM
    my_java = gateway.jvm.JavaClass();  # my Java object
    ....
    int_array=my_java.doSomething(int_array); # do something

    my_numpy=np.zeros((size_y,size_x));
    for jj in range(size_y):
        for ii in range(size_x):
            my_numpy[jj,ii]=int_array[jj][ii];

my_numpy是Numpy數組, int_array是整數的Java數組 - int[ ][ ]種數組。 在Python腳本中初始化為:

    int_class=gateway.jvm.int       # make int class
    double_class=gateway.jvm.double # make double class

    int_array = gateway.new_array(int_class,size_y,size_x)
    double_array = gateway.new_array(double_class,size_y,size_x)

雖然它可以正常工作,但它不是最快的方式而且工作速度相當慢 - 對於~1000x1000陣列,轉換時間超過5分鍾。

有沒有辦法如何在合理的時間內完成這個?

如果我嘗試:

    test=np.array(int_array)

我明白了:

    ValueError: invalid __array_struct__

我遇到了類似的問題,發現一個解決方案比我測試的情況快了大約220倍:為了將一個1628x120的短整數數組從Java轉移到Numpy,運行時間從11秒減少到0.05秒。 感謝這個相關的StackOverflow問題 ,我開始研究py4j字節數組 ,結果發現py4j有效地將Java字節數組轉換為Python字節對象,反之亦然(通過值傳遞,而不是通過引用)。 這是一種相當迂回的做事方式,但並不太難。

因此,如果你想傳遞一個帶有維度iMax x jMax的整數數組intArray (為了這個例子,我假設它們都作為實例變量存儲在你的對象中),你可以先編寫一個Java函數來轉換它。像這樣的字節[]:

public byte[] getByteArray() {
    // Set up a ByteBuffer called intBuffer
    ByteBuffer intBuffer = ByteBuffer.allocate(4*iMax*jMax); // 4 bytes in an int
    intBuffer.order(ByteOrder.LITTLE_ENDIAN); // Java's default is big-endian

    // Copy ints from intArray into intBuffer as bytes
    for (int i = 0; i < iMax; i++) {
        for (int j = 0; j < jMax; j++){
            intBuffer.putInt(intArray[i][j]);
        }
    }

    // Convert the ByteBuffer to a byte array and return it
    byte[] byteArray = intBuffer.array();
    return byteArray;
}

然后,您可以編寫Python 3代碼來接收字節數組並將其轉換為正確形狀的numpy數組:

byteArray = gateway.entry_point.getByteArray()
intArray = np.frombuffer(byteArray, dtype=np.int32)
intArray = intArray.reshape((iMax, jMax))

我有一個類似的問題,只是試圖繪制我從Java端通過py4j獲得的光譜矢量(Java數組)。 這里,通過list()函數實現從Java Array到Python列表的轉換。 這可能會提供一些線索,如何使用它來填充NumPy數組......

vectors = space.getVectorsAsArray(); # Java array (MxN)
wvl = space.getAverageWavelengths(); # Java array (N)

wavelengths = list(wvl)

import matplotlib.pyplot as mp
mp.hold
for i, dataset in enumerate(vectors):
    mp.plot(wavelengths, list(dataset))

這是否比你使用的嵌套for循環更快我不能說,但它也可以解決問題:

import numpy
from numpy  import array
x = array(wavelengths)
v = array(list(vectors))

mp.plot(x, numpy.rot90(v))

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM