
How do I specialize a trait function?

I have a trait for BLAS functionality:

pub trait Blas {
    fn gemv<F>(&self, trans: Transpose,
               cols: usize, rows: usize, matrix: &[F], matrix_factor: F,
               vector: &[F], vector_inc: usize, vector_factor: F,
               result: &mut [F], result_inc: usize) -> Result<(), Error>;
    ...
}

Now I want to make a type which implements this trait:

pub struct CudaBlas {
    ...
}

impl Blas for CudaBlas {
    ...
}

The problem is that I need separate specialisations for gemv<f32> and gemv<f64>: each one should call a dedicated shared library function. I didn't manage to express that without the compiler complaining. How can I achieve that?

UPDATE:

I tried the method proposed by Jonas Tepe, and it doesn't seem to work. Here is a reduced example:

trait Trait<T> {
    fn func(&self, arg: T);
}

struct Struct {
    field: usize,
}

impl Trait<f32> for Struct {
    fn func(&self, arg: f32) {
        println!("32bits: {}", arg);
    }
}

impl Trait<f64> for Struct {
    fn func(&self, arg: f64) {
        println!("64bits: {}", arg);
    }
}

struct Struct2<T> {
    field2: T,
}

// yes, I plan to use my CudaBlas inside some generic NeuralNet<T>
impl<T> Struct2<T> {
    fn func2(&self, arg: T) {
        let s = Struct{field: 1};
        s.func(arg);
    }
}

fn main() {
    let s32 = Struct2::<f32>{field2: 1f32};
    let s64 = Struct2::<f64>{field2: 2f64};
    s32.func2(1f32);
    s64.func2(1f64);
}

I get:

error: the trait Trait<T> is not implemented for the type Struct [E0277]

Making Struct generic doesn't solve the problem either (the compiler complains that func is not found for type Struct<T>). I'm amazed at how restrictive Rust generics are.
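For reference, a minimal sketch of the reduced example that does compile: the error arises because impl<T> Struct2<T> promises func2 for every T, while Struct only implements Trait<f32> and Trait<f64>. Stating Struct: Trait<T> as a bound on the impl block (the same idea as the method-level where clause in the self-answer further down, moved to the impl) fixes it. As an assumption for easy checking, func returns a String instead of printing.

```rust
#![allow(dead_code)]

// Same trait as in the question, but returning a String so results are checkable.
trait Trait<T> {
    fn func(&self, arg: T) -> String;
}

struct Struct {
    field: usize,
}

impl Trait<f32> for Struct {
    fn func(&self, arg: f32) -> String {
        format!("32bits: {}", arg)
    }
}

impl Trait<f64> for Struct {
    fn func(&self, arg: f64) -> String {
        format!("64bits: {}", arg)
    }
}

struct Struct2<T> {
    field2: T,
}

// The bound says: this impl only exists for those T for which Struct
// implements Trait<T>. Without it, `s.func(arg)` fails with E0277.
impl<T> Struct2<T>
where
    Struct: Trait<T>,
{
    fn func2(&self, arg: T) -> String {
        let s = Struct { field: 1 };
        s.func(arg)
    }
}

fn main() {
    let s32 = Struct2::<f32> { field2: 1f32 };
    let s64 = Struct2::<f64> { field2: 2f64 };
    println!("{}", s32.func2(1f32)); // 32bits: 1
    println!("{}", s64.func2(1f64)); // 64bits: 1
}
```

With this bound, calling func2 on, say, a Struct2<i32> is rejected at the call site rather than inside the generic impl.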

One solution is to make your trait Blas generic over the floating-point type and then provide two separate implementations of this trait for your CudaBlas struct:

pub trait Blas<F> {
    fn gemv(&self, trans: Transpose,
            cols: usize, rows: usize, matrix: &[F], matrix_factor: F,
            vector: &[F], vector_inc: usize, vector_factor: F,
            result: &mut [F], result_inc: usize) -> Result<(), Error>;
    ...
}

impl Blas<f32> for CudaBlas {
    fn gemv(&self, trans: Transpose,
            cols: usize, rows: usize, matrix: &[f32], matrix_factor: f32,
            vector: &[f32], vector_inc: usize, vector_factor: f32,
            result: &mut [f32], result_inc: usize) -> Result<(), Error> {
        // implement f32-specific functionality here,
        // e.g. call the library's single-precision gemv
        unimplemented!()
    }
}

impl Blas<f64> for CudaBlas {
    fn gemv(&self, trans: Transpose,
            cols: usize, rows: usize, matrix: &[f64], matrix_factor: f64,
            vector: &[f64], vector_inc: usize, vector_factor: f64,
            result: &mut [f64], result_inc: usize) -> Result<(), Error> {
        // implement f64-specific functionality here,
        // e.g. call the library's double-precision gemv
        unimplemented!()
    }
}

After that, you can call gemv() on your CudaBlas with either f32 or f64 arguments, and the compiler picks the matching type-specific implementation.
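To show how each impl can forward to its own dedicated routine, here is a minimal runnable sketch under several assumptions: the gemv signature is reduced (no transpose, factors or increments; result = matrix * vector with a row-major matrix), Error is a hypothetical stand-in, and backend_sgemv / backend_dgemv are hypothetical naive Rust loops standing in for the shared library's single- and double-precision functions.

```rust
#![allow(dead_code)]

// Hypothetical stand-in for the Error type elided in the question.
#[derive(Debug, PartialEq)]
pub enum Error {
    DimensionMismatch,
}

// Reduced version of the trait (assumption): result = matrix * vector,
// where `matrix` is stored row-major with `rows` rows and `cols` columns.
pub trait Blas<F> {
    fn gemv(&self, rows: usize, cols: usize,
            matrix: &[F], vector: &[F], result: &mut [F]) -> Result<(), Error>;
}

pub struct CudaBlas;

// Hypothetical stand-ins for the library's precision-specific routines
// (naive CPU loops, not CUDA).
fn backend_sgemv(cols: usize, a: &[f32], x: &[f32], y: &mut [f32]) {
    for (r, out) in y.iter_mut().enumerate() {
        *out = (0..cols).map(|c| a[r * cols + c] * x[c]).sum();
    }
}

fn backend_dgemv(cols: usize, a: &[f64], x: &[f64], y: &mut [f64]) {
    for (r, out) in y.iter_mut().enumerate() {
        *out = (0..cols).map(|c| a[r * cols + c] * x[c]).sum();
    }
}

// Shared length check so neither backend indexes out of bounds.
fn check_dims(rows: usize, cols: usize, a: usize, x: usize, y: usize) -> Result<(), Error> {
    if a == rows * cols && x == cols && y == rows {
        Ok(())
    } else {
        Err(Error::DimensionMismatch)
    }
}

impl Blas<f32> for CudaBlas {
    fn gemv(&self, rows: usize, cols: usize,
            matrix: &[f32], vector: &[f32], result: &mut [f32]) -> Result<(), Error> {
        check_dims(rows, cols, matrix.len(), vector.len(), result.len())?;
        backend_sgemv(cols, matrix, vector, result); // single-precision routine
        Ok(())
    }
}

impl Blas<f64> for CudaBlas {
    fn gemv(&self, rows: usize, cols: usize,
            matrix: &[f64], vector: &[f64], result: &mut [f64]) -> Result<(), Error> {
        check_dims(rows, cols, matrix.len(), vector.len(), result.len())?;
        backend_dgemv(cols, matrix, vector, result); // double-precision routine
        Ok(())
    }
}

fn main() {
    let blas = CudaBlas;

    let a32: [f32; 4] = [1.0, 2.0, 3.0, 4.0]; // 2 x 2, row-major
    let x32: [f32; 2] = [1.0, 1.0];
    let mut y32 = [0.0f32; 2];
    blas.gemv(2, 2, &a32[..], &x32[..], &mut y32[..]).unwrap(); // uses Blas<f32>
    println!("{:?}", y32);

    let a64: [f64; 4] = [1.0, 2.0, 3.0, 4.0];
    let x64: [f64; 2] = [1.0, 1.0];
    let mut y64 = [0.0f64; 2];
    blas.gemv(2, 2, &a64[..], &x64[..], &mut y64[..]).unwrap(); // uses Blas<f64>
    println!("{:?}", y64);
}
```

The impl is chosen from the slice element type at each call. In real code the two backend functions would be replaced by extern "C" declarations of the library's entry points; the dispatch structure stays the same.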

All I needed was to add the bound where CudaBlas: Blas<T> to the method:

#![allow(dead_code, unused_variables)]

trait Blas<T> {
    fn gemv(&self, arg: T);
}

struct CudaBlas {
    field: usize,
}

impl Blas<f32> for CudaBlas {
    fn gemv(&self, arg: f32) {
        println!("f32");
    }
}

impl Blas<f64> for CudaBlas {
    fn gemv(&self, arg: f64) {
        println!("f64");
    }
}

struct NeuralNet<T> {
    field: T,
}

impl<T> NeuralNet<T> {
    fn process(&self, arg: T) where CudaBlas: Blas<T> {
        let cblas = CudaBlas{field:0};
        cblas.gemv(arg);
    }
}

fn main() {
    let nn = NeuralNet{field:0f64};
    nn.process(12f64);
}
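Stable Rust has no specialization of generic impls, so another option, if you would rather keep the single generic method fn gemv<F> from the original trait, is to move the per-precision code onto a helper trait implemented for f32 and f64 and bound F by it. A minimal sketch: BlasScalar and gemv_on are hypothetical names, and the methods return a label instead of calling a library so the dispatch is visible.

```rust
#![allow(dead_code)]

struct CudaBlas {
    field: usize,
}

// Hypothetical helper trait: each supported scalar type carries its own gemv.
trait BlasScalar: Copy {
    fn gemv_on(blas: &CudaBlas, arg: Self) -> &'static str;
}

impl BlasScalar for f32 {
    fn gemv_on(_blas: &CudaBlas, _arg: f32) -> &'static str {
        // the single-precision library call would go here
        "f32"
    }
}

impl BlasScalar for f64 {
    fn gemv_on(_blas: &CudaBlas, _arg: f64) -> &'static str {
        // the double-precision library call would go here
        "f64"
    }
}

// The trait keeps one generic method, as in the original question.
trait Blas {
    fn gemv<F: BlasScalar>(&self, arg: F) -> &'static str;
}

impl Blas for CudaBlas {
    fn gemv<F: BlasScalar>(&self, arg: F) -> &'static str {
        F::gemv_on(self, arg) // dispatch on the scalar type
    }
}

struct NeuralNet<T> {
    field: T,
}

impl<T: BlasScalar> NeuralNet<T> {
    fn process(&self, arg: T) -> &'static str {
        let cblas = CudaBlas { field: 0 };
        cblas.gemv(arg)
    }
}

fn main() {
    let nn32 = NeuralNet { field: 0f32 };
    let nn64 = NeuralNet { field: 0f64 };
    println!("{}", nn32.process(12f32)); // f32
    println!("{}", nn64.process(12f64)); // f64
}
```

The trade-off is that BlasScalar is tied to CudaBlas here; with several backends it would need one method per backend, so the Blas<T> approach above may scale better.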

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.
