pyo3/conversions/std/
array.rs

1use crate::instance::Bound;
2use crate::types::any::PyAnyMethods;
3use crate::types::PySequence;
4use crate::{
5    err::DowncastError, ffi, FromPyObject, IntoPy, Py, PyAny, PyObject, PyResult, Python,
6    ToPyObject,
7};
8use crate::{exceptions, PyErr};
9
10impl<T, const N: usize> IntoPy<PyObject> for [T; N]
11where
12    T: IntoPy<PyObject>,
13{
14    fn into_py(self, py: Python<'_>) -> PyObject {
15        unsafe {
16            let len = N as ffi::Py_ssize_t;
17
18            let ptr = ffi::PyList_New(len);
19
20            // We create the  `Py` pointer here for two reasons:
21            // - panics if the ptr is null
22            // - its Drop cleans up the list if user code panics.
23            let list: Py<PyAny> = Py::from_owned_ptr(py, ptr);
24
25            for (i, obj) in (0..len).zip(self) {
26                let obj = obj.into_py(py).into_ptr();
27
28                #[cfg(not(Py_LIMITED_API))]
29                ffi::PyList_SET_ITEM(ptr, i, obj);
30                #[cfg(Py_LIMITED_API)]
31                ffi::PyList_SetItem(ptr, i, obj);
32            }
33
34            list
35        }
36    }
37}
38
39impl<T, const N: usize> ToPyObject for [T; N]
40where
41    T: ToPyObject,
42{
43    fn to_object(&self, py: Python<'_>) -> PyObject {
44        self.as_ref().to_object(py)
45    }
46}
47
48impl<'py, T, const N: usize> FromPyObject<'py> for [T; N]
49where
50    T: FromPyObject<'py>,
51{
52    fn extract_bound(obj: &Bound<'py, PyAny>) -> PyResult<Self> {
53        create_array_from_obj(obj)
54    }
55}
56
57fn create_array_from_obj<'py, T, const N: usize>(obj: &Bound<'py, PyAny>) -> PyResult<[T; N]>
58where
59    T: FromPyObject<'py>,
60{
61    // Types that pass `PySequence_Check` usually implement enough of the sequence protocol
62    // to support this function and if not, we will only fail extraction safely.
63    let seq = unsafe {
64        if ffi::PySequence_Check(obj.as_ptr()) != 0 {
65            obj.downcast_unchecked::<PySequence>()
66        } else {
67            return Err(DowncastError::new(obj, "Sequence").into());
68        }
69    };
70    let seq_len = seq.len()?;
71    if seq_len != N {
72        return Err(invalid_sequence_length(N, seq_len));
73    }
74    array_try_from_fn(|idx| seq.get_item(idx).and_then(|any| any.extract()))
75}
76
// TODO use std::array::try_from_fn, if that stabilises:
// (https://github.com/rust-lang/rust/issues/89379)
/// Builds a `[T; N]` by calling `cb(0)`, `cb(1)`, ..., `cb(N - 1)` in order.
///
/// Returns `Ok` with the filled array when every call succeeds. On the first
/// `Err` the error is returned immediately, `cb` is not called again, and the
/// elements produced so far are dropped. The same cleanup happens if `cb`
/// panics.
fn array_try_from_fn<E, F, T, const N: usize>(mut cb: F) -> Result<[T; N], E>
where
    F: FnMut(usize) -> Result<T, E>,
{
    // Helper to safely create arrays since the standard library doesn't
    // provide one yet. Shouldn't be necessary in the future.
    //
    // On drop it destroys the first `initialized` elements starting at `dst`,
    // which is exactly the prefix written before an early return or a panic.
    struct ArrayGuard<T, const N: usize> {
        dst: *mut T,
        initialized: usize,
    }

    impl<T, const N: usize> Drop for ArrayGuard<T, N> {
        fn drop(&mut self) {
            debug_assert!(self.initialized <= N);
            let initialized_part = core::ptr::slice_from_raw_parts_mut(self.dst, self.initialized);
            // SAFETY: the first `initialized` slots behind `dst` have been written
            // and ownership of them has not been handed to anyone else.
            unsafe {
                core::ptr::drop_in_place(initialized_part);
            }
        }
    }

    // [MaybeUninit<T>; N] would be "nicer" but is actually difficult to create - there are nightly
    // APIs which would make this easier.
    let mut array: core::mem::MaybeUninit<[T; N]> = core::mem::MaybeUninit::uninit();

    // Take the element pointer exactly once and share it between the guard and
    // the writes below. Borrowing `array` mutably a second time would
    // invalidate the pointer already held by the guard under the
    // Stacked/Tree Borrows aliasing models, so the guard would drop elements
    // through a stale pointer on the error and panic paths.
    let base: *mut T = array.as_mut_ptr().cast::<T>();
    let mut guard: ArrayGuard<T, N> = ArrayGuard {
        dst: base,
        initialized: 0,
    };

    for i in 0..N {
        let value = cb(i)?;
        // SAFETY: `i < N`, so `base.add(i)` stays inside the array's storage,
        // and the slot is still uninitialised, so nothing is overwritten.
        unsafe {
            core::ptr::write(base.add(i), value);
        }
        guard.initialized += 1;
    }

    // Every slot is initialised: ownership moves to the returned array, so the
    // guard must not drop the elements.
    core::mem::forget(guard);
    // SAFETY: all `N` elements were written by the loop above.
    Ok(unsafe { array.assume_init() })
}
118
119fn invalid_sequence_length(expected: usize, actual: usize) -> PyErr {
120    exceptions::PyValueError::new_err(format!(
121        "expected a sequence of length {} (got {})",
122        expected, actual
123    ))
124}
125
#[cfg(test)]
mod tests {
    use std::{
        panic,
        sync::atomic::{AtomicUsize, Ordering},
    };

    use crate::types::any::PyAnyMethods;
    use crate::{types::PyList, IntoPy, PyResult, Python, ToPyObject};

    /// A panic in the callback must drop exactly the elements created before it.
    #[test]
    fn array_try_from_fn() {
        static DROP_COUNTER: AtomicUsize = AtomicUsize::new(0);
        struct CountDrop;
        impl Drop for CountDrop {
            fn drop(&mut self) {
                DROP_COUNTER.fetch_add(1, Ordering::SeqCst);
            }
        }
        let _ = catch_unwind_silent(move || {
            let _: Result<[CountDrop; 4], ()> = super::array_try_from_fn(|idx| {
                #[allow(clippy::manual_assert)]
                if idx == 2 {
                    panic!("peek a boo");
                }
                Ok(CountDrop)
            });
        });
        assert_eq!(DROP_COUNTER.load(Ordering::SeqCst), 2);
    }

    /// An `Err` from the callback must be returned as-is and must drop exactly
    /// the elements created before it.
    ///
    /// The element type is deliberately not zero-sized, so the cleanup really
    /// reads the partially initialised array storage.
    #[test]
    fn array_try_from_fn_err_drops_initialized_prefix() {
        use std::{cell::Cell, rc::Rc};

        struct Tracked {
            drops: Rc<Cell<usize>>,
            _payload: String,
        }
        impl Drop for Tracked {
            fn drop(&mut self) {
                self.drops.set(self.drops.get() + 1);
            }
        }

        let drops = Rc::new(Cell::new(0));
        let result: Result<[Tracked; 4], &str> = super::array_try_from_fn(|idx| {
            if idx == 2 {
                Err("stop")
            } else {
                Ok(Tracked {
                    drops: Rc::clone(&drops),
                    _payload: idx.to_string(),
                })
            }
        });
        assert_eq!(result.err(), Some("stop"));
        assert_eq!(drops.get(), 2);
    }

    /// Extraction works for a 33-element array.
    #[test]
    fn test_extract_bytearray_to_array() {
        Python::with_gil(|py| {
            let v: [u8; 33] = py
                .eval_bound(
                    "bytearray(b'abcabcabcabcabcabcabcabcabcabcabc')",
                    None,
                    None,
                )
                .unwrap()
                .extract()
                .unwrap();
            assert!(&v == b"abcabcabcabcabcabcabcabcabcabcabc");
        })
    }

    /// Extraction works for a 3-element array.
    #[test]
    fn test_extract_small_bytearray_to_array() {
        Python::with_gil(|py| {
            let v: [u8; 3] = py
                .eval_bound("bytearray(b'abc')", None, None)
                .unwrap()
                .extract()
                .unwrap();
            assert!(&v == b"abc");
        });
    }

    /// `to_object` on an array yields a list holding the same values in order.
    #[test]
    fn test_topyobject_array_conversion() {
        Python::with_gil(|py| {
            let array: [f32; 4] = [0.0, -16.0, 16.0, 42.0];
            let pyobject = array.to_object(py);
            let pylist = pyobject.downcast_bound::<PyList>(py).unwrap();
            assert_eq!(pylist.get_item(0).unwrap().extract::<f32>().unwrap(), 0.0);
            assert_eq!(pylist.get_item(1).unwrap().extract::<f32>().unwrap(), -16.0);
            assert_eq!(pylist.get_item(2).unwrap().extract::<f32>().unwrap(), 16.0);
            assert_eq!(pylist.get_item(3).unwrap().extract::<f32>().unwrap(), 42.0);
        });
    }

    /// A sequence of the wrong length is rejected with a descriptive `ValueError`.
    #[test]
    fn test_extract_invalid_sequence_length() {
        Python::with_gil(|py| {
            let v: PyResult<[u8; 3]> = py
                .eval_bound("bytearray(b'abcdefg')", None, None)
                .unwrap()
                .extract();
            assert_eq!(
                v.unwrap_err().to_string(),
                "ValueError: expected a sequence of length 3 (got 7)"
            );
        })
    }

    /// `into_py` on an array yields a list holding the same values in order.
    #[test]
    fn test_intopy_array_conversion() {
        Python::with_gil(|py| {
            let array: [f32; 4] = [0.0, -16.0, 16.0, 42.0];
            let pyobject = array.into_py(py);
            let pylist = pyobject.downcast_bound::<PyList>(py).unwrap();
            assert_eq!(pylist.get_item(0).unwrap().extract::<f32>().unwrap(), 0.0);
            assert_eq!(pylist.get_item(1).unwrap().extract::<f32>().unwrap(), -16.0);
            assert_eq!(pylist.get_item(2).unwrap().extract::<f32>().unwrap(), 16.0);
            assert_eq!(pylist.get_item(3).unwrap().extract::<f32>().unwrap(), 42.0);
        });
    }

    /// An integer extracts as `i32` but not as a one-element array.
    #[test]
    fn test_extract_non_iterable_to_array() {
        Python::with_gil(|py| {
            let v = py.eval_bound("42", None, None).unwrap();
            v.extract::<i32>().unwrap();
            v.extract::<[i32; 1]>().unwrap_err();
        });
    }

    /// Arrays of `#[pyclass]` values convert into a list of instances of that class.
    #[cfg(feature = "macros")]
    #[test]
    fn test_pyclass_intopy_array_conversion() {
        #[crate::pyclass(crate = "crate")]
        struct Foo;

        Python::with_gil(|py| {
            let array: [Foo; 8] = [Foo, Foo, Foo, Foo, Foo, Foo, Foo, Foo];
            let pyobject = array.into_py(py);
            let list = pyobject.downcast_bound::<PyList>(py).unwrap();
            let _bound = list.get_item(4).unwrap().downcast::<Foo>().unwrap();
        });
    }

    /// Runs `f` under `catch_unwind` with a no-op panic hook installed, then
    /// puts the previous hook back.
    // https://stackoverflow.com/a/59211505
    fn catch_unwind_silent<F, R>(f: F) -> std::thread::Result<R>
    where
        F: FnOnce() -> R + panic::UnwindSafe,
    {
        let prev_hook = panic::take_hook();
        panic::set_hook(Box::new(|_| {}));
        let result = panic::catch_unwind(f);
        panic::set_hook(prev_hook);
        result
    }
}