
Writing a pyo3 function equivalent to a Python function that returns its input object

I am looking to write a Rust backend for my library, and I need to implement the equivalent of the following function in pyo3:

def f(x):
    return x

This should return the same object as the input, and the caller receiving the return value should hold a new reference to the input. If I were writing this in the C API, I would write it as:

#include <Python.h>

/* Return a new reference to x; the caller owns it. */
PyObject * f(PyObject * x) {
    Py_XINCREF(x);
    return x;
}
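For reference, the pure-Python identity function already has exactly these semantics. The sketch below checks them with sys.getrefcount, which is CPython-specific; the names identity and check_identity are illustrative only:

```python
import sys


def identity(x):
    # Returns the argument itself; no copy is made.
    return x


def check_identity():
    a = object()
    before = sys.getrefcount(a)
    b = identity(a)
    after = sys.getrefcount(a)
    # b is the same object, and binding it gave the caller one extra reference.
    return b is a, after - before


same, delta = check_identity()
print(same, delta)
```

Both counts are measured the same way, so only the difference matters: it should be exactly the one reference now held by b.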

In PyO3, I find it quite confusing to navigate the differences between PyObject, PyObjectRef, &PyObject, Py<PyObject>, and Py<&PyObject>.

The most naive version of this function is:

extern crate pyo3;

use pyo3::prelude::*;

#[pyfunction]
pub fn f(_py: Python, x: &PyObject) -> PyResult<&PyObject> {
    Ok(x)
}

Among other things, the lifetimes of x and the return value are not tied together, and I see no opportunity for pyo3 to increase the reference count of x. In fact, the compiler seems to agree with me:

error[E0106]: missing lifetime specifier
 --> src/lib.rs:4:49
  |
4 | pub fn f(_py: Python, x: &PyObject) -> PyResult<&PyObject> {
  |                                                 ^ expected lifetime parameter
  |
  = help: this function's return type contains a borrowed value, but the signature does not say whether it is borrowed from `_py` or `x`

There may be a way for me to manually increase the reference count using the _py parameter and use lifetime annotations to make the compiler happy, but my impression is that pyo3 intends to manage reference counts itself using object lifetimes.

What is the proper way to write this function? Should I be attempting to wrap it in a Py container?

A PyObject is a simple wrapper around a raw pointer:

pub struct PyObject(*mut ffi::PyObject);

It has multiple creation functions, each corresponding to different kinds of pointers that we might get from Python. Some of these, such as from_borrowed_ptr, call Py_INCREF on the passed-in pointer.
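The effect of that Py_INCREF can be observed from Python with ctypes, which exposes CPython's Py_IncRef and Py_DecRef functions. This is a CPython-only sketch of the bookkeeping involved, not pyo3's actual code: passing an object through ctypes only lends a borrowed pointer, Py_IncRef upgrades it to an extra owned reference, and Py_DecRef is the matching release.

```python
import ctypes
import sys

# CPython C API functions, reached through ctypes (CPython only).
Py_IncRef = ctypes.pythonapi.Py_IncRef
Py_IncRef.argtypes = [ctypes.py_object]
Py_IncRef.restype = None
Py_DecRef = ctypes.pythonapi.Py_DecRef
Py_DecRef.argtypes = [ctypes.py_object]
Py_DecRef.restype = None

obj = object()
base = sys.getrefcount(obj)

# ctypes lends obj as a borrowed pointer; Py_IncRef adds one owned reference,
# which is the step a constructor like from_borrowed_ptr performs.
Py_IncRef(obj)
after_incref = sys.getrefcount(obj)

# Whoever owns that reference must eventually release it exactly once.
Py_DecRef(obj)
after_decref = sys.getrefcount(obj)

print(after_incref - base, after_decref - base)
```

The count goes up by one for the new owned reference and returns to its starting value once that reference is released.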

Thus, it seems like we can accept a PyObject , so long as it was created in the "right" manner.

If we look at the macro expansion of this code:

#[pyfunction]
pub fn example(_py: Python, x: PyObject) -> PyObject {
    x
}

We can see this section of code that calls our function:

let mut _iter = _output.iter();
::pyo3::ObjectProtocol::extract(_iter.next().unwrap().unwrap()).and_then(
    |arg1| {
        ::pyo3::ReturnTypeIntoPyResult::return_type_into_py_result(example(
            _py, arg1,
        ))
    },
)

Our argument is created by a call to ObjectProtocol::extract, which in turn calls FromPyObject::extract. This is implemented for PyObject by calling from_borrowed_ptr.

Thus, using a bare PyObject as the argument type will correctly increment the reference count.

Likewise, when a PyObject is dropped in Rust, it will automatically decrease the reference count. When it is returned back to Python, ownership of that reference is transferred to the caller, and the interpreter releases it once the result is no longer needed.
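A rough Python-level analogue of this hand-off, assuming CPython's immediate deallocation when a reference count reaches zero: a returned object survives because its reference moves to the caller, and it is freed as soon as the caller releases the last reference. Payload, make_and_return, and check_ownership are hypothetical names used only for illustration:

```python
import weakref


class Payload:
    """Hypothetical stand-in for any Python object."""


def make_and_return():
    local = Payload()
    # The local name owns the only reference; returning it hands that
    # reference to the caller instead of releasing it.
    return local


def check_ownership():
    obj = make_and_return()
    watcher = weakref.ref(obj)  # a weak reference does not keep obj alive
    alive_while_owned = watcher() is obj
    del obj  # the caller releases its reference, the last one left
    freed_after_release = watcher() is None
    return alive_while_owned, freed_after_release


alive, freed = check_ownership()
print(alive, freed)
```

Here del plays the role that dropping a PyObject plays on the Rust side: releasing the final owned reference is what lets the object be deallocated.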


All investigation was done against commit ed273982 on the master branch, corresponding to v0.5.0-alpha.1.

According to the other answer, pyo3 takes care of building additional boilerplate around our functions in order to keep track of Python reference counting. In particular, the reference count is already incremented when the object is passed as an argument to the function. Nevertheless, the clone_ref method can be used to explicitly create a new reference to the same object, which also increments its reference count.

The output of the function must still be an owned Python object (a PyObject) rather than a Rust reference to one (which seems reasonable, as Python does not understand Rust references; pyo3 seems to ignore lifetime parameters in these functions).

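use pyo3::prelude::*;

#[pyfunction]
fn f(py: Python, x: PyObject) -> PyResult<PyObject> {
    // clone_ref takes a new reference to the same object; `x` itself is
    // dropped when the function returns, releasing the reference it held.
    Ok(x.clone_ref(py))
}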

From playing around with the function in Python land (AKA not a serious testbed), it at least seems to work as intended.

from dummypy import f

def get_object():
    return f("OK")

a = [1, 2, 3]

if True:
    b = f(a)
    assert b is a
    b[0] = 9001

print(a)

x = get_object()
print(x)
