
Rust returning iterator with trait bound on its items

I am trying to write a common interface for different types of matrices that provides a way to mutably iterate over their rows and modify them. I have the following matrix types:

struct NdArrayMatrix {
    matrix: Array2<f32>,
}

struct ByteMatrix<'a> {
    data: &'a mut [u8],
    rows: usize,
    cols: usize,
}

The first one is just a RAM-stored matrix, and the second is memory-mapped using the MMap library, but for convenience I omit those details. First, I made a trait so that both of them can be modified through the same interface:

trait ReadWrite
{
    fn rw_read(&self, i: usize, j: usize) -> f32;
    fn rw_write(&mut self, i: usize, j: usize, val: f32);
}
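For the memory-mapped ByteMatrix, one possible ReadWrite implementation is sketched below. It assumes (the question does not say) that the bytes hold f32 values in row-major order and native byte order; the offset helper and the ELEM constant are hypothetical, and a plain Vec stands in for the mapped region. Going through from_ne_bytes/to_ne_bytes copies the four bytes, so the code does not depend on the alignment of the mapped buffer.

```rust
trait ReadWrite {
    fn rw_read(&self, i: usize, j: usize) -> f32;
    fn rw_write(&mut self, i: usize, j: usize, val: f32);
}

struct ByteMatrix<'a> {
    data: &'a mut [u8],
    rows: usize,
    cols: usize,
}

// Size in bytes of one stored element (assumed to be an f32).
const ELEM: usize = std::mem::size_of::<f32>();

impl<'a> ByteMatrix<'a> {
    // Hypothetical helper: byte offset of element (i, j),
    // assuming row-major storage of f32 values.
    fn offset(&self, i: usize, j: usize) -> usize {
        assert!(i < self.rows && j < self.cols, "index out of bounds");
        (i * self.cols + j) * ELEM
    }
}

impl<'a> ReadWrite for ByteMatrix<'a> {
    fn rw_read(&self, i: usize, j: usize) -> f32 {
        let start = self.offset(i, j);
        // Copy the bytes out, so the alignment of `data` does not matter.
        let mut bytes = [0u8; ELEM];
        bytes.copy_from_slice(&self.data[start..start + ELEM]);
        f32::from_ne_bytes(bytes)
    }

    fn rw_write(&mut self, i: usize, j: usize, val: f32) {
        let start = self.offset(i, j);
        self.data[start..start + ELEM].copy_from_slice(&val.to_ne_bytes());
    }
}

fn main() {
    // A plain Vec stands in for the memory-mapped region here.
    let mut buf = vec![0u8; 2 * 3 * ELEM];
    let mut m = ByteMatrix { data: &mut buf[..], rows: 2, cols: 3 };
    m.rw_write(1, 2, 3.5);
    println!("{}", m.rw_read(1, 2));
}
```

The bounds check on i and j happens before any byte is touched, so an out-of-range column panics instead of silently reading or writing the next row.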

Then, I created a trait that produces a rayon::iter::IndexedParallelIterator from both of these:

trait Sliceable<'a>
{
    type Output: IndexedParallelIterator;

    fn rows_par_iter(&'a mut self) -> Self::Output;
}

Up to this point, everything works fine. But when I want to use these in a generic context, such as:

fn slice_and_write<'a, T>(matrix: T)
where
    T: Sliceable<'a>,
{
    matrix
        .rows_par_iter()
        .map(|mut row| {
            row.rw_write(...);
        })
        ...
}

I run into problems. It is obvious that row, in this case, doesn't implement ReadWrite, so no surprise there. So what I tried to do is create an iterator trait based on IndexedParallelIterator:

trait RwIterator: IndexedParallelIterator {
    type Item: ReadWrite;
}

and modify Sliceable:

trait Sliceable<'a>
{
    type Output: RwIterator;

    fn rows_par_iter(&'a mut self) -> Self::Output;
}

Running this, I get the error:

   |  row.rw_write(...);
   |      ^^^^^^^^ method not found in `<<T as Sliceable<'a>>::Output as ParallelIterator>::Item`

Which is, again, fairly obvious. I suspect that map only requires the trait bound ParallelIterator and hence can't take advantage of the trait RwIterator.

My question is: is there any way around this problem, or an alternative way of doing this?

EDIT: Here is a minimal reproducible example, using only one of the matrix structures.

use ndarray::Array2;
use rayon::prelude::*;
use ndarray::Axis;
use ndarray::parallel::Parallel;
use ndarray::Dim;
use ndarray::iter::AxisIterMut;
use rayon::iter::ParallelIterator;
use ndarray::ViewRepr;
use ndarray::ArrayBase;

struct NdArrayMatrix {
    matrix: Array2<f32>,
}

impl NdArrayMatrix {
    pub fn new() -> Self {
        let matrix = Array2::zeros((10, 10));
        
        Self {
            matrix,
        }
    }
}

trait ReadWrite
{
    fn rw_read(&self, i: usize, j: usize) -> f32;
    fn rw_write(&mut self, i: usize, j: usize, val: f32);
}

impl ReadWrite for NdArrayMatrix {
    fn rw_read(&self, i: usize, j: usize) -> f32 {
        self.matrix[[i, j]]
    }

    fn rw_write(&mut self, i: usize, j: usize, val: f32) {
        self.matrix[[i, j]] = val;
    }
}

impl ReadWrite for ArrayBase<ViewRepr<&mut f32>, Dim<[usize; 1]>> {
    fn rw_read(&self, i: usize, j: usize) -> f32 {
        self[j]
    }

    fn rw_write(&mut self, i: usize, j: usize, val: f32) {
        self[j] = val;
    }
}

trait RwIterator: IndexedParallelIterator {
    type Item: ReadWrite;
}

impl<'a> RwIterator for Parallel<AxisIterMut<'a, f32, Dim<[usize; 1]>>> {
    type Item =  ArrayBase<ViewRepr<&'a mut f32>, Dim<[usize; 1]>>;
}

trait Sliceable<'a>
{
    type Output: RwIterator;

    fn rows_par_iter(&'a mut self) -> Self::Output;
}

impl<'a> Sliceable<'a> for NdArrayMatrix {
    type Output = Parallel<AxisIterMut<'a, f32, Dim<[usize; 1]>>>;

    fn rows_par_iter(&'a mut self) -> Self::Output {
        self.matrix
            .axis_iter_mut(Axis(0))
            .into_par_iter()
    }
}

fn main() {
    let mut matrix: NdArrayMatrix = NdArrayMatrix::new();

    test(matrix);
}

fn test<'a, T> (matrix: T)
where T: Sliceable<'a> + ReadWrite
{
    matrix.rows_par_iter()
        .map(|mut row| {
            row.rw_write(0, 0, 0.0);
        }).count();
}

Your code is 90% there.

The problem you are facing is that RwIterator::Item is ReadWrite, but nowhere does your code require that RwIterator::Item be the same type as the ParallelIterator::Item of the same object. They are two separate associated types that merely share the name Item, and the closure passed to map receives ParallelIterator::Item, for which no ReadWrite bound is known.

To fix this, you can state the equality explicitly in the supertrait bound:

trait RwIterator: IndexedParallelIterator<Item = <Self as RwIterator>::Item> {
    type Item: ReadWrite;
}

impl<'a> RwIterator for Parallel<AxisIterMut<'a, f32, Dim<[usize; 1]>>> {
    type Item = ArrayBase<ViewRepr<&'a mut f32>, Dim<[usize; 1]>>;
}

With that, the compiler knows that the two Item types are identical, and therefore that every item the iterator yields implements ReadWrite.
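The same trick can be tried without rayon or ndarray. The sketch below swaps in std's Iterator for IndexedParallelIterator and uses slice::ChunksMut as the row iterator, with a ReadWrite impl for &mut [f32] rows; fill_first_column is a hypothetical helper. It is the supertrait's Item = <Self as RwIterator>::Item that lets the generic function call rw_write on each row.

```rust
use std::slice::ChunksMut;

trait ReadWrite {
    fn rw_read(&self, i: usize, j: usize) -> f32;
    fn rw_write(&mut self, i: usize, j: usize, val: f32);
}

// A row is a mutable slice; the row index `i` is ignored, as in the question.
impl<'b> ReadWrite for &'b mut [f32] {
    fn rw_read(&self, _i: usize, j: usize) -> f32 {
        self[j]
    }

    fn rw_write(&mut self, _i: usize, j: usize, val: f32) {
        self[j] = val;
    }
}

// The supertrait bound ties Iterator::Item to RwIterator::Item.
trait RwIterator: Iterator<Item = <Self as RwIterator>::Item> {
    type Item: ReadWrite;
}

impl<'a> RwIterator for ChunksMut<'a, f32> {
    type Item = &'a mut [f32];
}

// Generic code: the compiler can now see that every row implements ReadWrite.
fn fill_first_column<I: RwIterator>(rows: I, val: f32) {
    for mut row in rows {
        row.rw_write(0, 0, val);
    }
}

fn main() {
    // A 2 x 3 matrix stored row-major; chunks_mut(3) yields its rows.
    let mut data = vec![1.0f32; 6];
    fill_first_column(data.chunks_mut(3), 9.0);
    println!("{:?}", data);
}
```

Removing the Item = <Self as RwIterator>::Item part of the supertrait should bring back the "method not found" error from the question, which is a quick way to see that this line is what connects the two associated types.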

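Some other minor adjustments were necessary as well: test now takes mut matrix, uses the higher-ranked bound for<'a> T: Sliceable<'a> instead of a lifetime parameter, and no longer requires T: ReadWrite. Here is a version that compiles: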

use ndarray::iter::AxisIterMut;
use ndarray::parallel::Parallel;
use ndarray::Array2;
use ndarray::ArrayBase;
use ndarray::Axis;
use ndarray::Dim;
use ndarray::ViewRepr;
use rayon::iter::ParallelIterator;
use rayon::prelude::*;

struct NdArrayMatrix {
    matrix: Array2<f32>,
}

impl NdArrayMatrix {
    pub fn new() -> Self {
        let matrix = Array2::zeros((10, 10));

        Self { matrix }
    }
}

trait ReadWrite {
    fn rw_read(&self, i: usize, j: usize) -> f32;
    fn rw_write(&mut self, i: usize, j: usize, val: f32);
}

impl ReadWrite for NdArrayMatrix {
    fn rw_read(&self, i: usize, j: usize) -> f32 {
        self.matrix[[i, j]]
    }

    fn rw_write(&mut self, i: usize, j: usize, val: f32) {
        self.matrix[[i, j]] = val;
    }
}

impl ReadWrite for ArrayBase<ViewRepr<&mut f32>, Dim<[usize; 1]>> {
    fn rw_read(&self, i: usize, j: usize) -> f32 {
        self[j]
    }

    fn rw_write(&mut self, i: usize, j: usize, val: f32) {
        self[j] = val;
    }
}

trait RwIterator: IndexedParallelIterator<Item = <Self as RwIterator>::Item> {
    type Item: ReadWrite;
}

impl<'a> RwIterator for Parallel<AxisIterMut<'a, f32, Dim<[usize; 1]>>> {
    type Item = ArrayBase<ViewRepr<&'a mut f32>, Dim<[usize; 1]>>;
}

trait Sliceable<'a> {
    type Output: RwIterator;

    fn rows_par_iter(&'a mut self) -> Self::Output;
}

impl<'a> Sliceable<'a> for NdArrayMatrix {
    type Output = Parallel<AxisIterMut<'a, f32, Dim<[usize; 1]>>>;

    fn rows_par_iter(&'a mut self) -> Self::Output {
        self.matrix.axis_iter_mut(Axis(0)).into_par_iter()
    }
}

fn main() {
    let matrix: NdArrayMatrix = NdArrayMatrix::new();

    test(matrix);
}

fn test<T>(mut matrix: T)
where
    for<'a> T: Sliceable<'a>,
{
    matrix
        .rows_par_iter()
        .map(|mut row| {
            row.rw_write(0, 0, 0.0);
        })
        .count();
}

A Little Digression

All of this is only necessary because a trait bound cannot yet be placed directly on an associated type inside another bound (for example, on the Item of ParallelIterator).

This might change if RFC 2289 (associated type bounds) gets stabilized at some point. (It has since been stabilized, in Rust 1.79.)

Then, you might be able to delete the RwIterator trait completely and specify the bound like this instead:

type Output: ParallelIterator<Item: ReadWrite>;

No guarantees on that one, though. I haven't gotten it to work with the nightly compiler yet.
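Assuming Rust 1.79 or newer, where associated type bounds (RFC 2289) are stable, the idea can be sketched without the RwIterator helper trait. As in the other std-only sketches, Iterator stands in for ParallelIterator, a hypothetical row-major VecMatrix stands in for NdArrayMatrix, rows are &mut [f32] slices, and zero_first_column is a hypothetical helper.

```rust
use std::slice::ChunksMut;

trait ReadWrite {
    fn rw_read(&self, i: usize, j: usize) -> f32;
    fn rw_write(&mut self, i: usize, j: usize, val: f32);
}

// A row is a mutable slice; the row index `i` is ignored.
impl<'b> ReadWrite for &'b mut [f32] {
    fn rw_read(&self, _i: usize, j: usize) -> f32 {
        self[j]
    }

    fn rw_write(&mut self, _i: usize, j: usize, val: f32) {
        self[j] = val;
    }
}

// Hypothetical stand-in for NdArrayMatrix: row-major data plus column count.
struct VecMatrix {
    data: Vec<f32>,
    cols: usize,
}

trait Sliceable<'a> {
    // Associated type bound (Rust 1.79+): every item of Output implements
    // ReadWrite, and generic code using Sliceable may rely on that.
    type Output: Iterator<Item: ReadWrite>;

    fn rows_iter(&'a mut self) -> Self::Output;
}

impl<'a> Sliceable<'a> for VecMatrix {
    type Output = ChunksMut<'a, f32>;

    fn rows_iter(&'a mut self) -> Self::Output {
        self.data.chunks_mut(self.cols)
    }
}

fn zero_first_column<'a, T: Sliceable<'a>>(matrix: &'a mut T) {
    for mut row in matrix.rows_iter() {
        row.rw_write(0, 0, 0.0);
    }
}

fn main() {
    let mut m = VecMatrix { data: vec![1.0; 6], cols: 3 };
    zero_first_column(&mut m);
    println!("{:?}", m.data);
}
```

The sketch uses a plain lifetime parameter on zero_first_column rather than the for<'a> bound from the answer, only to keep it short; the swap to rayon's ParallelIterator and ndarray's row iterator is not shown here.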

You can already emulate that behaviour to some extent with some boilerplate code:

use ndarray::iter::AxisIterMut;
use ndarray::parallel::Parallel;
use ndarray::Array2;
use ndarray::ArrayBase;
use ndarray::Axis;
use ndarray::Dim;
use ndarray::ViewRepr;
use rayon::iter::ParallelIterator;
use rayon::prelude::*;

struct NdArrayMatrix {
    matrix: Array2<f32>,
}

impl NdArrayMatrix {
    pub fn new() -> Self {
        let matrix = Array2::zeros((10, 10));

        Self { matrix }
    }
}

trait ReadWrite {
    fn rw_read(&self, i: usize, j: usize) -> f32;
    fn rw_write(&mut self, i: usize, j: usize, val: f32);
}

impl ReadWrite for NdArrayMatrix {
    fn rw_read(&self, i: usize, j: usize) -> f32 {
        self.matrix[[i, j]]
    }

    fn rw_write(&mut self, i: usize, j: usize, val: f32) {
        self.matrix[[i, j]] = val;
    }
}

impl ReadWrite for ArrayBase<ViewRepr<&mut f32>, Dim<[usize; 1]>> {
    fn rw_read(&self, i: usize, j: usize) -> f32 {
        self[j]
    }

    fn rw_write(&mut self, i: usize, j: usize, val: f32) {
        self[j] = val;
    }
}

trait Sliceable<'a> {
    type Item: ReadWrite;
    type Output: ParallelIterator<Item = Self::Item>;

    fn rows_par_iter(&'a mut self) -> Self::Output;
}

impl<'a> Sliceable<'a> for NdArrayMatrix {
    // Sadly needs explicit type annotation for `Item`
    type Item = ArrayBase<ViewRepr<&'a mut f32>, Dim<[usize; 1]>>;
    type Output = Parallel<AxisIterMut<'a, f32, Dim<[usize; 1]>>>;

    fn rows_par_iter(&'a mut self) -> Self::Output {
        self.matrix.axis_iter_mut(Axis(0)).into_par_iter()
    }
}

fn main() {
    let matrix: NdArrayMatrix = NdArrayMatrix::new();

    test(matrix);
}

fn test<T>(mut matrix: T)
where
    for<'a> T: Sliceable<'a>,
{
    matrix
        .rows_par_iter()
        .map(|mut row| {
            row.rw_write(0, 0, 0.0);
        })
        .count();
}
