How can I pass an &mut ndarray::Array to a function and perform element-wise arithmetic with it?

Motivation

I am trying to write my first real program in Rust for a school project (not a requirement; I have just been fascinated by Rust and decided to take the plunge).

The project is a simple simulation of a robot's decisions based on some sensor data, some probabilities, prediction of future rewards, and some other stuff. The program consists of a main loop where lots of math takes place at each time step for some time horizon into the future. The data that gets carried to each subsequent time step is represented by a matrix Y that consists of two columns of linear coefficients (which are modified at each time step) of a set of linear constraints (where more constraints/rows of coefficients are added to the set at each time step).

Since the program will require lots of element-wise matrix operations and I'm well experienced in NumPy, the ndarray crate seemed like a perfect fit for the job. My thought process for the program was to make a mutable 2D array for Y that would get modified with each loop iteration, rather than allocating a new array every time. It has since dawned on me that the number of rows will grow an unknown amount with each iteration as well, so maybe this approach wasn't the greatest idea, but my question on the error I'm getting stands regardless.

Question

My question is this: if I want to modify an array at each iteration of a loop by passing a reference to the array into several functions that will modify its data, how can I also use the same array in basic element-wise arithmetic operations?

Here is a bare-bones example of my code to demonstrate:

extern crate ndarray;

use ndarray::prelude::*;

fn main() {
    let pz = array![[0.7, 0.3], [0.3, 0.7]]; // measurement probabilities

    let mut Y = Array2::<f64>::zeros((1, 2));

    for i in 1..10 {
        do_some_maths(&mut Y, pz);
        // other functions that will modify Y
    }
    
    println!("Result: {}", Y);
}

fn do_some_maths(Y: &mut Array2<f64>, pz: Array2<f64>) {

    let Yp = Y * pz.slice(s![.., 0]);  // <-- this is the problem

    // do lots of matrix math with Yp
    // ...
    // then modify Y's data using Yp (hence Y needs to be &mut)
}

Which gives the following compiling error:

error[E0369]: binary operation `*` cannot be applied to type `&mut ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 2]>>`
  --> src/main2.rs:21:16
   |
21 |     let Yp = Y * pz.slice(s![.., 0]);  // <-- this is the problem
   |              - ^ ------------------- ndarray::ArrayBase<ndarray::ViewRepr<&f64>, _>
   |              |
   |              &mut ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 2]>>
   |
   = note: an implementation of `std::ops::Mul` might be missing for `&mut ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 2]>>`

I have spent many hours trying to understand

  1. what is the correct approach to my use case, and
  2. why the code I have written doesn't work.

I read several questions on this site that were somewhat related, but none of them really went into the case of dealing with an Array reference as a function parameter and performing a binary operation on it.

I've studied hard the first 5 chapters of the Rust book and dived deep into the documentation of ndarray, and I still can't find answers. ndarray's documentation of ArrayBase contains the following explanation, which I don't fully understand:

Binary Operations on Two Arrays

Let A be an array or view of any kind. Let B be an array with owned storage (either Array or ArcArray). Let C be an array with mutable data (either Array, ArcArray or ArrayViewMut). The following combinations of operands are supported for an arbitrary binary operator denoted by @ (it can be +, -, *, / and so on).

  • &A @ &A which produces a new Array
  • B @ A which consumes B, updates it with the result, and returns it
  • B @ &A which consumes B, updates it with the result, and returns it
  • C @= &A which performs an arithmetic operation in place

Given this description, and searching through the many trait implementations for Add, Mul, etc., it seems to me that a mutable ndarray::Array cannot be an operand in a binary operation, except in the case of compound assignment.

Is that true, or am I missing something here? I don't want to simply memorize this little tidbit and move on; I really want to understand what is actually going on here, and where my understanding is lacking. Please help me to wrap my C++/Python trained brain around this. :)

You've answered your own question: the multiplication you are trying to perform has the form &mut C @ A (a mutable reference on the left, a view passed by value on the right), which isn't one of the four combinations supported by ndarray. Also, you are passing pz by value to the function. It is moved in the first iteration of the loop and isn't available for the remaining iterations, so that won't compile either.

This works:

extern crate ndarray;
use ndarray::prelude::*;

fn main() {
    let pz = array![[0.7, 0.3], [0.3, 0.7]];
    let mut y = Array2::<f64>::zeros((1, 2));

    for _ in 1..10 {
        do_some_maths(&mut y, &pz);
    }

    println!("Result: {}", y);
}

fn do_some_maths(y: &mut Array2<f64>, pz: &Array2<f64>) {
    *y *= &pz.slice(s![.., 0]);
}

A mutable reference is 'more powerful' than an immutable one, and you can always reborrow a mutable reference as an immutable one (by writing &*y), so that is not a problem.

As edwardw points out, you probably do not want to consume your array pz on every loop (and the compiler won't let you do that anyway). Indeed, if you think about the signature of your do_some_maths function, what you have is:

  • A mutable array that you want to modify
  • An immutable one that you only read from

So it makes sense to have the signature being:

fn do_some_maths(y: &mut Array2<f64>, pz: &Array2<f64>) {
   ...
}

Now, the ndarray crate lets you:

  • Modify values in-place OR
  • create new ones for your operations

Generally, it is quite sensible about its inputs, taking references whenever possible so as not to consume your input arrays. That means a good deal of (de)referencing might be needed, which you should feel free to use. In NumPy, pretty much everything is a reference all the time, so you don't have to worry about it, but the logic is the same.

If you want to create Yp from Y, you can do that by allocating a new Yp value:

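fn do_some_maths(y: &mut Array2<f64>, pz: &Array2<f64>) {
    // yp is a new Array2<f64>; `&*y` reborrows y immutably so that the
    // supported `&A @ &A` form applies (assumes y and pz have the same shape)
    let yp: Array2<f64> = &*y * pz;
    // We may want to modify `y` now
    y.scaled_add(-2.3, &yp);
    *y *= pz;
}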

The various operations made here are:

  • &Array2 * &Array2 -> Array2, allocating a new array
  • scaled_add(self: &mut Array2, f64, &Array2) -> (), modifying the array in-place
  • in-place element-wise operation Array2 *= &Array2 (written *y *= pz when y is a &mut Array2)

In general, try to use references (mutable or not) as much as possible, except if you know that the input should be consumed.

To clarify the parallel with numpy: numpy arrays are essentially all references. Rust gives you the granularity to either pass in values directly (which are therefore consumed - think of them as use-once, then they are destroyed) or references (either mutable or immutable depending on whether you need to mutate them). Numpy uses essentially mutable references everywhere (except if you explicitly switch the WRITEABLE flag).
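The three ways of passing data can be sketched without ndarray at all. The example below uses a plain Vec<f64>, and the function names (consume, total, scale) are invented for illustration.

```rust
/// Takes the vector by value: ownership moves into the function,
/// so the caller cannot use its binding afterwards.
fn consume(v: Vec<f64>) -> f64 {
    v.into_iter().sum()
}

/// Takes a shared reference: read-only access, the caller keeps ownership.
fn total(v: &[f64]) -> f64 {
    v.iter().sum()
}

/// Takes a mutable reference: modifies the caller's data in place.
fn scale(v: &mut [f64], factor: f64) {
    for x in v.iter_mut() {
        *x *= factor;
    }
}

fn main() {
    let mut v = vec![1.0, 2.0, 3.0];

    scale(&mut v, 2.0); // v is modified in place
    println!("total = {}", total(&v)); // v is only read here

    // v is moved into `consume`; using v after this line would not compile.
    println!("consumed = {}", consume(v));
}
```

Calling consume inside a loop without cloning v would fail to compile for the same reason that passing pz by value did in the question.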
