簡體   English   中英

使用 PyO3 從 Python 將列表列表作為參數傳遞給 Rust

[英]Pass list of lists as argument to Rust from Python using PyO3

我正在嘗試使用 Py03 將列表列表從 Python 傳遞到 Rust。 我試圖將其傳遞給的 function 具有以下簽名:

pub fn k_nearest_neighbours(k: usize, x: &[[f32; 2]], y: &[[f32; 3]]) -> Vec<Option<f32>> 

我正在為預先存在的庫編寫綁定,因此我無法更改原始代碼。 我目前的做事方式是這樣的:

// This is example code == DOES NOT WORK
#[pyfunction] // make a new function within a new library with pyfunction macro
fn k_nearest_neighbours(k: usize, x: Vec<Vec<f32>>, y: Vec<f32>) -> Vec<Option<f32>> {
    // reformat input where necessary
    let x_slice = x.as_slice();
    // return original lib's function return
    classification::k_nearest_neighbours(k, x_slice, y)
}

x.as_slice() function幾乎可以滿足我的需要,它給了我一片向量&[Vec<f32>] ,而不是一片片&[[f32; 3]] &[[f32; 3]]

我希望能夠運行這個 Python 代碼:

from rust_code import k_nearest_neighbours as knn  # this is the Rust function compiled with PyO3

X = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [0.06, 7.0]]
train = [
        [0.0, 0.0, 0.0],
        [0.5, 0.5, 0.0],
        [3.0, 3.0, 1.0],
        [4.0, 3.0, 1.0],
    ]

k = 2
y_true = [0, 1, 1, 1]
y_test = knn(k, X, train)
assert(y_true == y_test)

查看k_nearest_neighbours的簽名表明它期望[f32; 2] [f32; 2][f32; 3] [f32; 3]是 arrays,而不是切片(例如&[f32]

Arrays 在編譯時具有靜態已知大小,而切片是動態大小的。 向量也是如此,您無法控制示例中內部向量的長度。 因此,您最終會得到從輸入向量到預期 arrays 的錯誤轉換。

您可以使用TryFrom從切片轉換為數組,即:

use std::convert::TryFrom;
fn main() {
    let x = vec![vec![3.5, 3.4, 3.6]];
    let x: Result<Vec<[f32; 3]>, _> = x.into_iter().map(TryFrom::try_from).collect::<Result<Vec<_>, _>>();
}

綜上所述,您的 function 將需要在輸入不正確時返回錯誤,並且您需要使用 arrays 創建一個新向量,您可以將其傳遞給您的 ZC1C425268E68385D1AB5074C17A94F14

#[pyfunction] // make a new function within a new library with pyfunction macro
fn k_nearest_neighbours(k: usize, x: Vec<Vec<f32>>, y: Vec<f32>) -> PyResult<Vec<Option<f32>>> {
    let x = x.into_iter().map(TryFrom::try_from).collect::<Result<Vec<_>, _>>();
    let y = y.into_iter().map(TryFrom::try_from).collect::<Result<Vec<_>, _>>();
    // Error handling is missing here, you'll need to look into PyO3's documentation for that
    ...
    // return original lib's function return
    Ok(classification::k_nearest_neighbours(k, &x, &y))
}

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM