
Accessing global state from a callback in Rust

I am learning Rust and have been rewriting a project I originally wrote in C#, and I got stuck trying to access global state from a callback.

Is there a simple way to do that in Rust? Keep in mind that I can't add a new parameter to the callback.

For example:

use std::collections::HashMap;
use std::time::Instant;
use lib::bar;

struct Struct2 {
    foo: bar,
    word: String,
}

struct GlobalState {
    running: bool,
    now: Instant,
    map: HashMap<String, Struct2>,
    hook_id: usize,
    current_id: String,
}

impl GlobalState {
    fn init() -> Self {
        let hook_id = unsafe { set_ext_hook(HOOK_ID, system_hook) };
        // Omitted initialization code.
    }
    // Omitted state mutation functions.
}

unsafe extern "system" fn system_hook(ext_param1: usize, ext_param2: usize) -> isize {
    // Use global state here.
}

I tried crates such as lazy_static and once_cell, but they didn't work because the external struct I use (lib::bar in this example) "cannot be sent between threads safely".

My code is single-threaded so far (I plan to run the program's GUI on a separate thread once I implement it).

Any help is appreciated, thanks.

You seem to be dealing with data that is neither Send nor Sync, so Rust won't allow you to place it in a global, even inside a mutex. It's not clear from the question whether this is because lib::bar is genuinely thread-unsafe, or just an unintended consequence of its use of raw pointers under the hood. It is also unclear whether you are in a position to modify lib::bar to make its types Send and Sync.
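If you can establish that lib::bar's types really are safe to move between threads (for example, the raw pointer is uniquely owned and the library has no thread affinity), a common escape hatch is a newtype with an `unsafe impl Send`, which then lets the value live in a `Mutex` inside a static. The sketch below assumes exactly that; `Bar`, `SendBar`, and `bar()` are hypothetical stand-ins, and the `unsafe impl` is a promise the compiler cannot check, so it is only sound if that assumption actually holds.

```rust
use std::sync::{Mutex, OnceLock};

// Hypothetical stand-in for `lib::bar`: holding a raw pointer makes it
// neither Send nor Sync.
pub struct Bar(*mut i32);

impl Bar {
    pub fn new() -> Self {
        Bar(Box::into_raw(Box::new(0)))
    }
    pub fn set(&mut self, v: i32) {
        unsafe { *self.0 = v };
    }
    pub fn get(&self) -> i32 {
        unsafe { *self.0 }
    }
}

impl Drop for Bar {
    fn drop(&mut self) {
        // Free the allocation made in `new`.
        unsafe { drop(Box::from_raw(self.0)) };
    }
}

// ASSUMPTION: moving a `Bar` to another thread is sound (uniquely owned
// pointer, no thread affinity). The compiler cannot verify this.
struct SendBar(Bar);
unsafe impl Send for SendBar {}

// `Mutex<T>` is `Sync` whenever `T: Send`, so this static now compiles.
static BAR: OnceLock<Mutex<SendBar>> = OnceLock::new();

// Lazily creates the shared value on first access.
fn bar() -> &'static Mutex<SendBar> {
    BAR.get_or_init(|| Mutex::new(SendBar(Bar::new())))
}

fn main() {
    bar().lock().unwrap().0.set(42);
    println!("{}", bar().lock().unwrap().0.get());
}
```

If the assumption does not hold, this is undefined behavior waiting to happen, which is why the thread-local approach is the conservative choice.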

Assuming most conservatively that lib::bar cannot be changed, and taking into account that your program is single-threaded, your only safe option is to create a thread-local state:

use std::cell::RefCell;
use std::thread_local;

struct Foo(*const i32); // a non-Send/Sync type

struct GlobalState {
    foo: Foo,
    data: String,
    mutable_data: RefCell<String>,
}

thread_local! {
    static STATE: GlobalState = GlobalState {
        foo: Foo(std::ptr::null()),
        data: "bla".to_string(),
        mutable_data: RefCell::new("".to_string()),
    };
}

You can access that state (and modify its interior-mutable pieces) from any function:

fn main() {
    STATE.with(|state| {
        assert_eq!(state.foo.0, std::ptr::null());
        assert_eq!(state.data, "bla");
        assert_eq!(state.mutable_data.borrow().as_str(), "");
        state.mutable_data.borrow_mut().push_str("xyzzy");
    });
    STATE.with(|state| {
        assert_eq!(state.mutable_data.borrow().as_str(), "xyzzy");
    });
}

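Applied to the question's setup, the hook can reach the state through the thread-local, so its signature stays unchanged. This is a sketch under simplifying assumptions: the state is trimmed to a few fields (the non-Send `lib::bar` value and the `Instant` are left out), the whole struct is wrapped in a `RefCell` so any field can be mutated, and `main` calls the hook directly as a stand-in for the external library. It only works if the library invokes the hook on the same thread that owns the state.

```rust
use std::cell::RefCell;
use std::collections::HashMap;

// Trimmed-down version of the question's GlobalState.
struct GlobalState {
    running: bool,
    map: HashMap<String, String>,
    current_id: String,
}

thread_local! {
    // The whole struct sits in a RefCell so any field can be mutated.
    static STATE: RefCell<GlobalState> = RefCell::new(GlobalState {
        running: true,
        map: HashMap::new(),
        current_id: String::new(),
    });
}

// Same signature as the question's hook: no extra parameter is needed,
// because the state is reached through the thread-local.
unsafe extern "system" fn system_hook(ext_param1: usize, ext_param2: usize) -> isize {
    STATE.with(|cell| {
        let mut state = cell.borrow_mut();
        if !state.running {
            return 0;
        }
        state.current_id = format!("{}:{}", ext_param1, ext_param2);
        let id = state.current_id.clone();
        state.map.insert(id, "seen".to_string());
        // Report how many distinct ids have been recorded so far.
        state.map.len() as isize
    })
}

fn main() {
    // Stand-in for the external library invoking the hook.
    let n = unsafe { system_hook(1, 2) };
    println!("entries after first call: {}", n);
}
```

Keep each `borrow_mut` short: if the hook can be re-entered while a borrow is still live, `RefCell` panics rather than allowing aliased mutation.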

Note that if you try to access the "global" state from different threads, each will get its own copy of the state:

fn main() {
    STATE.with(|state| {
        state.mutable_data.borrow_mut().push_str("xyzzy");
    });
    std::thread::spawn(|| {
        STATE.with(|state| {
            // change to "xyzzy" happened on the other copy
            assert_eq!(state.mutable_data.borrow().as_str(), "");
        })
    })
    .join()
    .unwrap();
}


One option is "thread confinement" of your variable. This means that all access to the variable happens on one thread. Typically you create a dedicated thread for this and create a proxy for your variable that is shared between other threads and is responsible for getting messages to and from the confining thread.

In Rust, this kind of inter-thread communication is usually done with channels. I'll show a cut-down version of your code, where lib::bar (here lib::Bar) simply wraps an i32 pointer. Raw pointers implement neither Send nor Sync, so they are a pretty good stand-in for your API.

The code is fairly verbose, and I've cheated by not implementing error handling on the send and recv calls, which you should definitely do. Despite the verbosity, adding new functionality is pretty simple: it mostly consists of adding a variant to the Message and Reply enums and copying the existing pattern.

use lazy_static::lazy_static;
use std::sync::mpsc::sync_channel;

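pub mod lib {
    // Stand-in for the real library type: it owns a heap-allocated i32
    // through a raw pointer, which makes it neither Send nor Sync.
    pub struct Bar(*mut i32);
    impl Bar {
        pub fn new() -> Self {
            Bar(Box::into_raw(Box::new(0)))
        }
        pub fn set(&mut self, v: i32) {
            unsafe { *self.0 = v };
        }
        pub fn get(&self) -> i32 {
            unsafe { *self.0 }
        }
    }
    impl Drop for Bar {
        fn drop(&mut self) {
            // Free the allocation made in `new`.
            unsafe { drop(Box::from_raw(self.0)) };
        }
    }
}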

enum Message {
    Set(i32),
    Get,
    Shutdown,
}

enum Reply {
    Set,
    Get(i32),
    Shutdown,
}

fn confinement_thread(
    receiver: std::sync::mpsc::Receiver<(Message, std::sync::mpsc::SyncSender<Reply>)>,
) {
    // Create the confined state
    let mut bar = lib::Bar::new();

    // Handle messages and forward them
    loop {
        let (mesg, reply_channel) = receiver.recv().unwrap();
        match mesg {
            Message::Set(v) => {
                eprintln!("    worker: setting value to {}", v);
                bar.set(v);
                reply_channel.send(Reply::Set).unwrap();
            }
            Message::Get => {
                let v = bar.get();
                eprintln!("    worker: getting value = {}", v);
                reply_channel.send(Reply::Get(v)).unwrap();
            }
            Message::Shutdown => {
                eprintln!("    worker: shutting down");
                reply_channel.send(Reply::Shutdown).unwrap();
                break;
            }
        }
    }
}

// This can be cloned happily
// and supports Send+Sync
#[derive(Clone)]
struct GlobalProxy {
    channel: std::sync::mpsc::SyncSender<(Message, std::sync::mpsc::SyncSender<Reply>)>,
}

impl GlobalProxy {
    pub fn set(&self, v: i32) {
        eprintln!("  proxy: setting value to {}", v);
        let (a, b) = sync_channel(0);
        self.channel.send((Message::Set(v), a)).unwrap();
        let m = b.recv().unwrap();
        assert!(matches!(m, Reply::Set));
    }

    pub fn get(&self) -> i32 {
        eprintln!("  proxy: getting value");
        let (a, b) = sync_channel(0);
        self.channel.send((Message::Get, a)).unwrap();
        let m = b.recv().unwrap();
        if let Reply::Get(v) = m {
            eprintln!("  proxy: got value={}", v);
            v
        } else {
            unreachable!();
        }
    }

    pub fn die(&self) {
        eprintln!("Telling worker thread to shut down");
        let (a, b) = sync_channel(0);
        self.channel.send((Message::Shutdown, a)).unwrap();
        let m = b.recv().unwrap();
        assert!(matches!(m, Reply::Shutdown));
    }
}

lazy_static! {
    static ref G: GlobalProxy = {
        // Create com channels
        let (to_global, from_world) = sync_channel(0);
        // Keep one end for the proxy,
        let global = GlobalProxy{ channel: to_global};
        // The other goes to the worker thread
        std::thread::spawn(|| {confinement_thread(from_world)});
        global
    };
}

pub fn main() {
    eprintln!("global.get() = {}", G.get());
    eprintln!("global.set(10)");
    G.set(10);
    eprintln!("global.get() = {}", G.get());

    G.die()
}


There are probably a lot of opportunities to make this less verbose using macros, but I find this version more instructive.

Another improvement would be to put the reply channel inside each Message variant, instead of sending it alongside the message in a tuple, which would allow us to remove the Reply enum.
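A sketch of that variant, assuming the same request/response protocol as above: each `Message` variant carries a reply sender typed for its own answer, so the `Reply` enum and the `matches!`/`unreachable!` checks disappear. The `Rc<Cell<i32>>` stands in for `lib::Bar` (like `Bar`, `Rc` is not `Send`), `spawn_proxy` is a hypothetical helper, and failed sends and receives still just panic via `expect`.

```rust
use std::cell::Cell;
use std::rc::Rc;
use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
use std::thread;

// Each request carries a reply sender typed for the answer it expects,
// so no separate `Reply` enum is needed.
enum Message {
    Set(i32, SyncSender<()>),
    Get(SyncSender<i32>),
}

fn confinement_thread(receiver: Receiver<Message>) {
    // Stand-in for `lib::Bar`: `Rc` is not `Send`, so this value can only
    // ever be used on the thread that created it.
    let bar = Rc::new(Cell::new(0));
    // `recv` returns `Err` once every sender is dropped, ending the loop.
    while let Ok(msg) = receiver.recv() {
        match msg {
            Message::Set(v, reply) => {
                bar.set(v);
                let _ = reply.send(());
            }
            Message::Get(reply) => {
                let _ = reply.send(bar.get());
            }
        }
    }
}

// Cloneable handle that other threads can use.
#[derive(Clone)]
struct GlobalProxy {
    channel: SyncSender<Message>,
}

impl GlobalProxy {
    fn set(&self, v: i32) {
        let (tx, rx) = sync_channel(0);
        self.channel.send(Message::Set(v, tx)).expect("worker stopped");
        rx.recv().expect("worker stopped");
    }

    fn get(&self) -> i32 {
        let (tx, rx) = sync_channel(0);
        self.channel.send(Message::Get(tx)).expect("worker stopped");
        rx.recv().expect("worker stopped")
    }
}

fn spawn_proxy() -> GlobalProxy {
    let (tx, rx) = sync_channel(0);
    thread::spawn(move || confinement_thread(rx));
    GlobalProxy { channel: tx }
}

fn main() {
    let proxy = spawn_proxy();
    proxy.set(10);
    println!("value = {}", proxy.get());
}
```

Because `Message` is `Send`, the proxy remains shareable between threads, so it could be stored in a `lazy_static!` global the same way as before.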

In some cases, it may be possible to remove the Message enum entirely by passing the confinement thread a function to run instead of a message. Something like:

impl GlobalProxy {
    fn run_confined(&self, f: Box<dyn Fn(&lib::Bar) + Send + Sync>) {
        // ...
    }
}

But handling functions with return values in a nice way is tricky.
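One possible way to handle return values, sketched under simplifying assumptions: send boxed `FnOnce` jobs instead of messages, and let each job capture a one-shot reply channel for its result, so `run_confined` can be generic over the return type. `Job`, this generic `run_confined` signature, and the `Bar` stand-in are illustrative rather than taken from the answer above. `FnOnce` is used because each job runs exactly once and consumes the caller's closure.

```rust
use std::marker::PhantomData;
use std::sync::mpsc::{sync_channel, SyncSender};
use std::thread;

// Stand-in for `lib::Bar`: the raw-pointer marker makes it neither Send nor Sync.
struct Bar {
    value: i32,
    _not_send: PhantomData<*const ()>,
}

// A unit of work to run against the confined `Bar`.
type Job = Box<dyn FnOnce(&mut Bar) + Send>;

struct GlobalProxy {
    channel: SyncSender<Job>,
}

impl GlobalProxy {
    fn spawn() -> Self {
        let (tx, rx) = sync_channel::<Job>(0);
        thread::spawn(move || {
            // The `Bar` is created on, and never leaves, this thread.
            let mut bar = Bar { value: 0, _not_send: PhantomData };
            // Iteration ends once every sender has been dropped.
            for job in rx {
                job(&mut bar);
            }
        });
        GlobalProxy { channel: tx }
    }

    // Runs `f` on the confinement thread and returns its result. The result
    // travels back through a one-shot channel captured by the boxed job.
    fn run_confined<R, F>(&self, f: F) -> R
    where
        F: FnOnce(&mut Bar) -> R + Send + 'static,
        R: Send + 'static,
    {
        let (reply_tx, reply_rx) = sync_channel(1);
        let job: Job = Box::new(move |bar: &mut Bar| {
            let _ = reply_tx.send(f(bar));
        });
        self.channel.send(job).expect("confinement thread stopped");
        reply_rx.recv().expect("job panicked or thread stopped")
    }
}

fn main() {
    let proxy = GlobalProxy::spawn();
    proxy.run_confined(|bar| { bar.value = 10 });
    let doubled = proxy.run_confined(|bar| bar.value * 2);
    println!("doubled = {}", doubled);
}
```

If a job panics, the worker thread dies with it and the caller's `recv` fails, so a production version would want something more forgiving than `expect`.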

I have a worked example that doesn't use global state but instead accesses variables from the outer scope through a closure. Maybe it will be useful.

use std::collections::HashMap;

fn main() {
    let a = String::from("Hello world");
    let b = String::from("Another world");
    let mut keys: HashMap<String, String> = HashMap::new();

    // The closure borrows `b` immutably and `keys` mutably from the
    // enclosing scope, so it implements FnMut rather than Fn.
    let callback = |line: String| {
        keys.insert(line.clone(), line.clone());
        println!("{}", b);
        println!("{}", line);
        println!("{:?}", keys);
    };

    compute(a, callback)
}

// Accepts any closure that can be called mutably with a String.
fn compute<F>(a: String, mut f: F)
where
    F: FnMut(String),
{
    f(a)
}
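This works because `compute` is generic over `FnMut`. The hook in the question, by contrast, is a plain `extern "system"` function pointer, which cannot carry captured variables, so the closure technique does not carry over to it directly. A small sketch of the difference, with hypothetical `call_with_pointer` and `call_with_closure` standing in for the two kinds of callback slot (the pointer version is simplified to a Rust-ABI `fn`):

```rust
// A callback slot typed as a plain function pointer.
fn call_with_pointer(f: fn(i32) -> i32, x: i32) -> i32 {
    f(x)
}

// A generic callback slot accepts any closure, including capturing ones.
fn call_with_closure<F: FnMut(i32) -> i32>(mut f: F, x: i32) -> i32 {
    f(x)
}

fn main() {
    let offset = 10;

    // Capturing closure: fine for the generic slot.
    assert_eq!(call_with_closure(|x| x + offset, 1), 11);

    // Non-capturing closure: coerces to a function pointer.
    assert_eq!(call_with_pointer(|x| x * 2, 21), 42);

    // This would not compile: a closure that captures `offset`
    // cannot be coerced to `fn(i32) -> i32`.
    // call_with_pointer(|x| x + offset, 1);
}
```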

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.

 
© 2020-2024 STACKOOM.COM