Lifetime parameter problem in custom iterator over mutable references

I'd like to implement a custom Iterator like the one below, but I cannot solve a lifetime problem with the references.

use itertools::Product;
use std::ops::Range;
struct Iter2DMut<'a, T: 'a> {
    data: &'a mut [T],
    shape: (usize, usize),
    idx_iter: Product<Range<usize>, Range<usize>>,
}

impl<'a, T: 'a> Iterator for Iter2DMut<'a, T> {
    type Item = &'a mut T;
    fn next(&mut self) -> Option<Self::Item> {
        if let Some((i, j)) = self.idx_iter.next() {
            Some(&mut self.data[i + self.shape.0 * j])
        } else {
            None
        }
    }
}

Compiling it produces the following error message:

error[E0495]: cannot infer an appropriate lifetime for lifetime parameter in function call due to conflicting requirements
  --> src/main.rs:13:23
   |
13 |             Some(&mut self.data[i + self.shape.0 * j])
   |                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   |

Based on the author's clarification in the comments, I'm assuming that the goal here is to iterate over a rectangular submatrix of a matrix. For example, given a matrix

100  200  300  400  500  600
110  210  310  410  510  610
120  220  320  420  520  620
130  230  330  430  530  630

as represented by a slice in row-major order

[100, 200, 300, 400, 500, 600, 110, ..., 530, 630]

we want to iterate over a submatrix such as

210  310  410  510
220  320  420  520

again in row-major order, so the elements we would get would be, in order,

210, 310, 410, 510, 220, 320, 420, 520
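To make the index arithmetic concrete, here is a small sketch that lists the flat, row-major indices covered by such a submatrix. The function name submatrix_indices and the (rows, columns) convention for its arguments are assumptions made for this illustration; they are not part of the question.

```rust
/// Returns the flat (row-major) indices of a submatrix.
/// `full_shape`, `sub_shape` and `offset` are all (rows, columns) pairs.
/// Hypothetical helper, for illustration only.
fn submatrix_indices(
    full_shape: (usize, usize),
    sub_shape: (usize, usize),
    offset: (usize, usize),
) -> Vec<usize> {
    // The submatrix must lie entirely inside the full matrix.
    assert!(offset.0 + sub_shape.0 <= full_shape.0);
    assert!(offset.1 + sub_shape.1 <= full_shape.1);

    let mut indices = Vec::with_capacity(sub_shape.0 * sub_shape.1);
    for i in 0..sub_shape.0 {
        for j in 0..sub_shape.1 {
            // Row-major layout: flat index = row * number_of_columns + column.
            indices.push((offset.0 + i) * full_shape.1 + (offset.1 + j));
        }
    }
    indices
}
```

For the 4 x 6 matrix above, a 2 x 4 submatrix starting at row 1, column 1 covers flat indices 7 to 10 and 13 to 16, which are exactly the positions of 210, ..., 510 and 220, ..., 520 in the slice.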

In this situation, it is possible to solve this problem relatively efficiently using safe Rust. The trick is to use the split_at_mut method of the slice in the data field of Iter2DMut, in order to peel off one mutable reference at a time as needed. As the iteration proceeds, the data field is updated to a smaller and smaller slice, so that it no longer encompasses elements which have already been iterated over. This is necessary because, at any given iteration, Rust would not allow us to produce a mutable reference to an element while also retaining a mutable slice containing that element. By updating the slice, we can ensure that it is always disjoint from the mutable references which have been produced by all previous calls to next(), satisfying the Rust borrow checker. Here is how this can be done:

use itertools::{Itertools, Product};
use std::ops::Range;
use std::mem;

struct Iter2DMut<'a, T: 'a> {
    data: &'a mut [T],
    full_shape: (usize, usize),
    sub_shape: (usize, usize),
    idx_iter: Product<Range<usize>, Range<usize>>,
}

impl<'a, T> Iter2DMut<'a, T> {
    fn new(
        data: &'a mut [T],
        full_shape: (usize, usize),
        sub_shape: (usize, usize),
        offset: (usize, usize),
    ) -> Self {
        assert!(full_shape.0 * full_shape.1 == data.len());
        assert!(offset.0 + sub_shape.0 <= full_shape.0);
        assert!(offset.1 + sub_shape.1 <= full_shape.1);
        Iter2DMut {
            data: &mut data[offset.0 * full_shape.1 + offset.1 ..],
            full_shape,
            sub_shape,
            idx_iter: (0..sub_shape.0).cartesian_product(0..sub_shape.1)
        }
    }
}
impl<'a, T: 'a> Iterator for Iter2DMut<'a, T> {
    type Item = &'a mut T;

    fn next(&mut self) -> Option<Self::Item> {
        if let Some((_, j)) = self.idx_iter.next() {
            let mut data: &'a mut [T] = &mut [];
            mem::swap(&mut self.data, &mut data);
            let (first, rest) = data.split_at_mut(1);
            data = rest;
            if j == self.sub_shape.1 - 1 {
                let n_skip = self.full_shape.1 - self.sub_shape.1;
                let (_, rest) = data.split_at_mut(n_skip);
                data = rest;
            }
            self.data = data;
            Some(&mut first[0])
        } else {
            None
        }
    }
}
fn main() {
    let mut v: Vec<usize> = vec![
        100, 200, 300, 400, 500, 600,
        110, 210, 310, 410, 510, 610,
        120, 220, 320, 420, 520, 620,
        130, 230, 330, 430, 530, 630,
    ];
    for x in Iter2DMut::new(&mut v, (4, 6), (2, 4), (1, 1)) {
        println!("{}", x);
    }
}

There's one other trick here worth noting: we use mem::swap to move the data field out of the Iter2DMut in order to call split_at_mut on it. We temporarily swap in a dummy value, &mut []; this is necessary because Rust won't allow us to move a value out of a (mutably) borrowed struct, even temporarily, without putting something back in at the same time. On the other hand, if we hadn't moved data out but had simply called split_at_mut directly, as in self.data.split_at_mut(1), it would have failed the borrow checker, because we would then be borrowing self.data through the &mut self reference passed to the next method. That borrow only lives as long as the &mut self reference, which is not necessarily as long as the 'a lifetime that we need.
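As a side note, the same move-out-and-put-back step can also be written with std::mem::take, which leaves the type's default value (an empty slice, for &mut [T]) behind. The following is a minimal sketch for the simpler one-dimensional case; SafeIterMut and double_all are hypothetical names used only for this illustration.

```rust
use std::mem;

// Hypothetical 1D iterator over mutable references, written in safe Rust.
struct SafeIterMut<'a, T> {
    data: &'a mut [T],
}

impl<'a, T> Iterator for SafeIterMut<'a, T> {
    type Item = &'a mut T;

    fn next(&mut self) -> Option<Self::Item> {
        // Move the slice out of `self`, leaving an empty slice in its place.
        // The moved-out value keeps the full 'a lifetime.
        let data = mem::take(&mut self.data);
        // Peel off the first element; returns None once the slice is empty.
        let (first, rest) = data.split_first_mut()?;
        // Keep only the part that has not been handed out yet.
        self.data = rest;
        Some(first)
    }
}

// Example use: double every element in place.
fn double_all(values: &mut [i32]) {
    let iter = SafeIterMut { data: values };
    for x in iter {
        *x *= 2;
    }
}
```

Because each returned reference is split off from the remaining slice, references obtained from successive calls never overlap, and they may outlive the iterator itself.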

Edit: This is a more general explanation of the problem of creating an iterator over mutable references. Brent's answer shows how to use a function from std to take care of the unsafe pointer manipulation for you, to solve this specific problem.


Iterating over mutable references requires unsafe code somewhere, either in your own code or inside a library function that you call. To see why, consider a simpler example:

struct MyIterMut<'a, T: 'a> {
    data: &'a mut [T],
    index: usize,
}

impl<'a, T: 'a> Iterator for MyIterMut<'a, T> {
    type Item = &'a mut T;
    fn next(&mut self) -> Option<Self::Item> {
        unimplemented!()
    }
}

fn main() {
    let mut data = vec![1, 2, 3, 4];

    let a;
    let b;
    {
        let mut iter = MyIterMut { data: &mut data, index: 0 };
        a = iter.next();
        b = iter.next();
    }

    // a and b are usable after the iterator is dropped, as long as data is still around
    println!("{:?}, {:?}", a, b);
}

A user of this iterator is allowed to use values from the iterator after it is dropped, as long as the original data is still live. This is expressed in the type of next which, with explicit lifetimes added, is:

fn next<'n>(&'n mut self) -> Option<&'a mut T>

There is no relationship between 'n and 'a, so the code that uses the iterator is free to use them without constraint. This is what you want.

Suppose we implement next() like this:

fn next(&mut self) -> Option<&'a mut T> {
    Some(&mut self.data[0])
}

This implementation is Bad and causes the same error that you are seeing in your code. If the compiler allowed it, the main function above would have two variables, a and b, which both contain mutable references to the same data. This is Undefined Behaviour, and the borrow checker prevents it from happening.

The way it is prevented is by noting that you are borrowing from self, which has a lifetime that is unrelated to the lifetime of the data. The compiler has no way of knowing if next will be called multiple times or what the caller will do with the data. It only knows that there isn't enough information to decide if it's safe.

But, you may argue, you don't need to borrow the whole of self; you only need to borrow that single item from the slice. Unfortunately, when you borrow an element through an index expression such as self.data[0], the borrow covers the entire slice. There is no way to express in the types that this call to next() will borrow index 0, and the next will borrow index 1, etc.

Given that you know that your implementation is only going to borrow each index once, you can use raw pointers and just tell the borrow-checker that you know what you are doing:

impl<'a, T: 'a> Iterator for MyIterMut<'a, T> {
    type Item = &'a mut T;
    fn next(&mut self) -> Option<Self::Item> {
        if self.index < self.data.len() {
            let index = self.index;
            self.index += 1;
            let ptr = self.data.as_mut_ptr();
            Some(unsafe { &mut *ptr.add(index) })
        } else {
            None
        }
    }
}

Since the iterator takes a &mut reference to data, it is not possible to construct multiple instances of it. If it were possible then there would still be a possibility of Undefined Behaviour, but the Rust borrow-checker takes care of this for us.


Whenever using unsafe code, you need to be extremely careful of how you enforce any invariants that you have assumed.

In your case, you will also need to ensure that it is not possible to create an invalid pointer due to the shape not matching the size of the data. You should probably panic! if that happens, which is always preferable to Undefined Behaviour.
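One way to enforce that invariant is to validate the shape once, in a constructor, so that every later index computation can rely on it. The sketch below is only an illustration under assumed names: Grid2DMut and shape_matches are hypothetical and do not appear in the question.

```rust
// Hypothetical wrapper that checks the shape against the data length up front.
struct Grid2DMut<'a, T> {
    data: &'a mut [T],
    shape: (usize, usize),
}

/// Returns true only if rows * columns equals `len` without overflowing.
fn shape_matches(shape: (usize, usize), len: usize) -> bool {
    // checked_mul yields None on overflow, which never equals Some(len).
    shape.0.checked_mul(shape.1) == Some(len)
}

impl<'a, T> Grid2DMut<'a, T> {
    fn new(data: &'a mut [T], shape: (usize, usize)) -> Self {
        // Panic early instead of risking an out-of-bounds pointer later.
        assert!(
            shape_matches(shape, data.len()),
            "shape {:?} does not match data length {}",
            shape,
            data.len()
        );
        Grid2DMut { data, shape }
    }
}
```

With the check in the constructor, any unsafe indexing that comes later only has to argue that each index stays below rows * columns.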

I hope the length of this answer communicates that you shouldn't go into this lightly. Always prefer to use safe functions from std or popular third party crates if they are available; your unsafe code will not receive the same level of peer review and testing that the Rust standard library gets.
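As an illustration of leaning on std, the submatrix iteration can probably be assembled from existing slice adapters with no unsafe code of your own. The following is a rough sketch, not a drop-in replacement for Iter2DMut: the function name submatrix_mut is hypothetical, and the arguments are assumed to be (rows, columns) pairs over a row-major slice.

```rust
/// Iterates mutably over a rectangular submatrix of a row-major slice.
/// Hypothetical helper built only from safe std slice methods.
fn submatrix_mut<'a, T>(
    data: &'a mut [T],
    full_shape: (usize, usize),
    sub_shape: (usize, usize),
    offset: (usize, usize),
) -> impl Iterator<Item = &'a mut T> {
    // chunks_mut panics on a chunk size of zero, so rule that out first.
    assert!(full_shape.1 > 0, "matrix must have at least one column");
    assert_eq!(full_shape.0 * full_shape.1, data.len());
    assert!(offset.0 + sub_shape.0 <= full_shape.0);
    assert!(offset.1 + sub_shape.1 <= full_shape.1);

    data.chunks_mut(full_shape.1) // one chunk per matrix row
        .skip(offset.0) // skip the rows above the submatrix
        .take(sub_shape.0) // keep only the submatrix rows
        // within each row, keep only the submatrix columns
        .flat_map(move |row| row[offset.1..offset.1 + sub_shape.1].iter_mut())
}
```

Called as submatrix_mut(&mut v, (4, 6), (2, 4), (1, 1)) on the example matrix, this visits 210, 310, 410, 510, 220, 320, 420, 520 in that order.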
