
Rust: implement a trait for multiple types

I am trying to implement a trait for multiple types at once. The only way I have found to avoid code duplication is to convert all the types into one common struct and implement the trait for that struct, as below.

trait Increment {
    fn increment(&self) -> Option<String>;
}

struct NumberWrapper {
    number: String,
}

impl Increment for NumberWrapper {
    fn increment(&self) -> Option<String> {
        let num: Result<u64, _> = self.number.parse();
        match num {
            Err(_) => None,
            Ok(x) => Some((x + 1).to_string())
        }
    }
}

impl<T> From<T> for NumberWrapper where T: ToString {
    fn from(input: T) -> NumberWrapper {
        NumberWrapper { number: input.to_string() }
    }
}

fn main() {
    let number_u8: u8 = 10;
    println!("number_u8 is: {}", NumberWrapper::from(number_u8).increment().unwrap());
    let number_u16: u16 = 10;
    println!("number_u16 is: {}", NumberWrapper::from(number_u16).increment().unwrap());
    let number_u32: u32 = 10;
    println!("number_u32 is: {}", NumberWrapper::from(number_u32).increment().unwrap());
    let number_u64: u64 = 10;
    println!("number_u64 is: {}", NumberWrapper::from(number_u64).increment().unwrap());
}

Is there any other way to do the same?

Two somewhat more elegant ways to do this come to mind. First, I'm guessing you'd rather have your trait look something like this:

trait Increment {
    // It would probably be better to take `self` by value if you
    // just want this for numeric types which are cheaply copied,
    // but I'll leave it for generality.
    fn increment(&self) -> Option<Self> where Self: Sized;
}

I'll assume this signature from here on (correct me if that's not what you want).
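For comparison, here is a minimal sketch of the by-value variant that the comment in the trait suggests. It assumes only cheap-to-copy numeric primitives will implement the trait and, to keep it short, implements it for u32 alone:

```rust
// Sketch of the by-value signature suggested in the trait's comment.
// Assumption: only Copy numeric primitives implement this trait.
trait Increment: Sized {
    fn increment(self) -> Option<Self>;
}

impl Increment for u32 {
    fn increment(self) -> Option<Self> {
        // checked_add returns None instead of overflowing.
        self.checked_add(1)
    }
}

fn main() {
    assert_eq!(41u32.increment(), Some(42));
    assert_eq!(u32::MAX.increment(), None);
    // Taking self by value lets results be chained without dereferencing.
    assert_eq!(41u32.increment().and_then(|n| n.increment()), Some(43));
}
```

The &self version keeps working on values you only hold by reference, which is the generality the comment trades against.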

The first way uses a pretty simple macro:

macro_rules! impl_increment {
    ($($t:ty),*) => {
        $(
            impl Increment for $t {
                fn increment(&self) -> Option<Self> {
                    self.checked_add(1)
                }
            }
        )*
    }
}

The macro has a single rule that accepts any number of comma-separated types and, for each one, implements increment using the checked_add method that Rust's integer primitives provide. You call it like this:

// This will create an impl block for each of these types:
impl_increment!{u8, u16, u32, u64, i8, i16, i32, i64}

fn main() {
    let x = 41u32;
    assert_eq!(x.increment(), Some(42));
    let y = -60_000i64;
    assert_eq!(y.increment(), Some(-59_999));
    let z = 255u8;
    assert_eq!(z.increment(), None);
}
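To make the $( ... )* repetition concrete: for each type in the list, the macro stamps out one copy of the impl block. Written out by hand, impl_increment!{u8, u16} corresponds roughly to the following (the exact expansion can be inspected with a tool such as cargo expand):

```rust
// Hand-written approximation of what impl_increment!{u8, u16} generates:
// one impl block per type in the comma-separated list.
trait Increment {
    fn increment(&self) -> Option<Self>
    where
        Self: Sized;
}

impl Increment for u8 {
    fn increment(&self) -> Option<Self> {
        // Auto-deref finds the inherent u8::checked_add.
        self.checked_add(1)
    }
}

impl Increment for u16 {
    fn increment(&self) -> Option<Self> {
        self.checked_add(1)
    }
}

fn main() {
    assert_eq!(254u8.increment(), Some(255));
    assert_eq!(255u8.increment(), None);
    assert_eq!(255u16.increment(), Some(256));
}
```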

Alternatively, you can stick closer to what you were already doing and convert to and from a common type, this time using the Into<u64> and TryFrom<u64> traits, which u8, u16, u32, and u64 itself all implement:

use std::convert::TryFrom;

impl<T> Increment for T
where T: Copy + Into<u64> + TryFrom<u64>
{
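    fn increment(&self) -> Option<Self> {
        let padded: u64 = (*self).into();
        // checked_add avoids overflow at u64::MAX; try_from then rejects
        // results that don't fit back into Self (e.g. 256 for u8).
        let next = padded.checked_add(1)?;
        Self::try_from(next).ok()
    }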
}

fn main() {
    let x = 41u32;
    assert_eq!(x.increment(), Some(42));
    let y = 60_000u64;
    assert_eq!(y.increment(), Some(60_001));
    let z = 255u8;
    assert_eq!(z.increment(), None);
}

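This adds a widening conversion and a fallible narrowing conversion to every call (the optimizer may remove much of that, but it isn't guaranteed), and it doesn't generalize as nicely: signed integer types such as i32 don't implement Into<u64>, so they're excluded. So I'd go with the macro route.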
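If you do want the conversion route to cover signed types too, one option (my own variation, not part of the answer) is to widen through i128 instead of u64. Every integer type from u8 to u64 and from i8 to i128 converts losslessly into i128; u128, usize, and isize don't implement Into<i128> and would still need separate handling. A sketch:

```rust
use std::convert::TryFrom;

trait Increment {
    fn increment(&self) -> Option<Self>
    where
        Self: Sized;
}

// Assumption: i128 as the common type, so signed and unsigned integers
// up to 64 bits (plus i128 itself) share one blanket impl.
impl<T> Increment for T
where
    T: Copy + Into<i128> + TryFrom<i128>,
{
    fn increment(&self) -> Option<Self> {
        let wide: i128 = (*self).into();
        // checked_add guards i128::MAX; try_from rejects values that
        // don't fit back into the original type.
        let next = wide.checked_add(1)?;
        T::try_from(next).ok()
    }
}

fn main() {
    assert_eq!(41u32.increment(), Some(42));
    let y = -60_000i64;
    assert_eq!(y.increment(), Some(-59_999));
    assert_eq!(255u8.increment(), None);
    assert_eq!(u64::MAX.increment(), None);
}
```

This keeps a single blanket impl, but it still can't reach floats or the pointer-sized types, which the macro handles by simply listing them.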

Blanket implementations can be used to implement traits for all types which satisfy some other trait(s). I'm not sure exactly what the trait in your example is meant to describe, but I hope the following example illustrates the idea.

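// One comes from the external num crate (add num as a dependency in Cargo.toml).
use std::ops::Add;
use num::traits::One;

trait Increment {
    fn increment(&self) -> Option<String>;
}

impl<T> Increment for T
    where T: Add + Copy + One,
          <T as Add>::Output: ToString,
{
    fn increment(&self) -> Option<String> {
        // Plain `+` is unchecked: incrementing 255u8 panics in debug builds
        // (and wraps by default in release), so this never returns None.
        Some((*self + One::one()).to_string())
    }
}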

fn main() {
    let number_u8: u8 = 10;
    println!("number_u8 is: {}", number_u8.increment().unwrap());
    let number_u16: u16 = 10;
    println!("number_u16 is: {}", number_u16.increment().unwrap());
    let number_u32: u32 = 10;
    println!("number_u32 is: {}", number_u32.increment().unwrap());
    let number_u64: u64 = 10;
    println!("number_u64 is: {}", number_u64.increment().unwrap());
}
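If you'd rather not pull in the num crate just for One, a std-only sketch of the same blanket-impl idea can build the value 1 with From<u8> instead. That substitution is my assumption, not part of the answer: it covers u8 and the wider primitives that implement From<u8> (for example u16, u32, u64, i32, f64), but not i8. It uses plain +, so overflow still panics in debug builds.

```rust
use std::ops::Add;

trait Increment {
    fn increment(&self) -> Option<String>;
}

// Assumption: "one" is obtained via From<u8> rather than num::traits::One,
// so no external crate is needed.
impl<T> Increment for T
where
    T: Copy + From<u8> + Add<Output = T> + ToString,
{
    fn increment(&self) -> Option<String> {
        // Unchecked addition, same as the num-based version.
        Some((*self + T::from(1u8)).to_string())
    }
}

fn main() {
    let number_u8: u8 = 10;
    println!("number_u8 is: {}", number_u8.increment().unwrap());
    let number_f64: f64 = 1.5;
    println!("number_f64 is: {}", number_f64.increment().unwrap());
    assert_eq!(10u64.increment(), Some("11".to_string()));
}
```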
