1use crate::instance::Bound;
2use crate::types::any::PyAnyMethods;
3use crate::types::PySequence;
4use crate::{
5 err::DowncastError, ffi, FromPyObject, IntoPy, Py, PyAny, PyObject, PyResult, Python,
6 ToPyObject,
7};
8use crate::{exceptions, PyErr};
9
impl<T, const N: usize> IntoPy<PyObject> for [T; N]
where
    T: IntoPy<PyObject>,
{
    /// Converts the owned array into a new Python `list` of length `N`, moving each
    /// element into the list through its own `IntoPy<PyObject>` conversion.
    fn into_py(self, py: Python<'_>) -> PyObject {
        unsafe {
            // List length as the C API's signed size type.
            // NOTE(review): plain `as` cast; `N` is the length of an in-memory array,
            // so overflow of `Py_ssize_t` is presumably impossible in practice.
            let len = N as ffi::Py_ssize_t;

            // Allocate a list with `len` slots; every slot is filled by the loop below.
            let ptr = ffi::PyList_New(len);

            // Wrap the raw pointer in an owning `Py` *before* any element is converted.
            // Assuming `Py` releases its reference on drop, this keeps the list from
            // leaking if an element's `into_py` panics mid-loop.
            // NOTE(review): presumably `from_owned_ptr` also rejects a null `ptr`
            // (allocation failure) — confirm.
            let list: Py<PyAny> = Py::from_owned_ptr(py, ptr);

            // Pair each element with its index, already typed as `Py_ssize_t`.
            for (i, obj) in (0..len).zip(self) {
                // `into_ptr` yields an owned reference; the set-item call below takes
                // ownership of ("steals") it, so no decref happens here.
                let obj = obj.into_py(py).into_ptr();

                // The unchecked `PyList_SET_ITEM` is not available under the limited
                // API, so the checked `PyList_SetItem` is used there instead. Its
                // status return is ignored; `i` is always in range for this fresh list.
                #[cfg(not(Py_LIMITED_API))]
                ffi::PyList_SET_ITEM(ptr, i, obj);
                #[cfg(Py_LIMITED_API)]
                ffi::PyList_SetItem(ptr, i, obj);
            }

            list
        }
    }
}
38
39impl<T, const N: usize> ToPyObject for [T; N]
40where
41 T: ToPyObject,
42{
43 fn to_object(&self, py: Python<'_>) -> PyObject {
44 self.as_ref().to_object(py)
45 }
46}
47
48impl<'py, T, const N: usize> FromPyObject<'py> for [T; N]
49where
50 T: FromPyObject<'py>,
51{
52 fn extract_bound(obj: &Bound<'py, PyAny>) -> PyResult<Self> {
53 create_array_from_obj(obj)
54 }
55}
56
57fn create_array_from_obj<'py, T, const N: usize>(obj: &Bound<'py, PyAny>) -> PyResult<[T; N]>
58where
59 T: FromPyObject<'py>,
60{
61 let seq = unsafe {
64 if ffi::PySequence_Check(obj.as_ptr()) != 0 {
65 obj.downcast_unchecked::<PySequence>()
66 } else {
67 return Err(DowncastError::new(obj, "Sequence").into());
68 }
69 };
70 let seq_len = seq.len()?;
71 if seq_len != N {
72 return Err(invalid_sequence_length(N, seq_len));
73 }
74 array_try_from_fn(|idx| seq.get_item(idx).and_then(|any| any.extract()))
75}
76
/// Builds `[T; N]` by calling `cb(0)`, `cb(1)`, ..., `cb(N - 1)` in order.
///
/// Stops at the first `Err` and returns it. Every element produced before that point is
/// dropped, both on an `Err` and if `cb` panics, so nothing leaks and nothing is
/// dropped twice.
fn array_try_from_fn<E, F, T, const N: usize>(mut cb: F) -> Result<[T; N], E>
where
    F: FnMut(usize) -> Result<T, E>,
{
    /// Tracks how many leading slots of the uninitialized array have been written.
    /// Dropping it drops exactly those slots; it is forgotten once the array is full.
    struct ArrayGuard<T, const N: usize> {
        dst: *mut T,
        initialized: usize,
    }

    impl<T, const N: usize> Drop for ArrayGuard<T, N> {
        fn drop(&mut self) {
            debug_assert!(self.initialized <= N);
            // SAFETY: slots `0..initialized` were written by the fill loop and have not
            // been moved out, so dropping them in place is sound.
            unsafe {
                core::ptr::drop_in_place(core::ptr::slice_from_raw_parts_mut(
                    self.dst,
                    self.initialized,
                ));
            }
        }
    }

    let mut storage = core::mem::MaybeUninit::<[T; N]>::uninit();
    let mut guard = ArrayGuard::<T, N> {
        dst: storage.as_mut_ptr().cast::<T>(),
        initialized: 0,
    };

    while guard.initialized < N {
        // Produce the value first; a `?` or panic here leaves the guard to clean up.
        let next = cb(guard.initialized)?;
        // SAFETY: `initialized < N`, so this slot is in bounds and still uninitialized.
        unsafe { guard.dst.add(guard.initialized).write(next) };
        guard.initialized += 1;
    }

    // Ownership of all N elements now passes to the returned array.
    core::mem::forget(guard);
    // SAFETY: the loop above wrote every one of the N slots.
    Ok(unsafe { storage.assume_init() })
}
118
119fn invalid_sequence_length(expected: usize, actual: usize) -> PyErr {
120 exceptions::PyValueError::new_err(format!(
121 "expected a sequence of length {} (got {})",
122 expected, actual
123 ))
124}
125
#[cfg(test)]
mod tests {
    use std::{
        panic,
        sync::atomic::{AtomicUsize, Ordering},
    };

    use crate::types::any::PyAnyMethods;
    use crate::{types::PyList, IntoPy, PyResult, Python, ToPyObject};

    /// A panic inside the element callback must drop exactly the elements that were
    /// already produced (indices 0 and 1) and nothing else.
    #[test]
    fn array_try_from_fn() {
        static DROP_COUNTER: AtomicUsize = AtomicUsize::new(0);
        struct CountDrop;
        impl Drop for CountDrop {
            fn drop(&mut self) {
                DROP_COUNTER.fetch_add(1, Ordering::SeqCst);
            }
        }
        let _ = catch_unwind_silent(move || {
            let _: Result<[CountDrop; 4], ()> = super::array_try_from_fn(|idx| {
                #[allow(clippy::manual_assert)]
                if idx == 2 {
                    panic!("peek a boo");
                }
                Ok(CountDrop)
            });
        });
        assert_eq!(DROP_COUNTER.load(Ordering::SeqCst), 2);
    }

    /// An `Err` from the element callback must also drop the already-produced prefix,
    /// and the error itself must be returned unchanged.
    #[test]
    fn array_try_from_fn_error() {
        static DROP_COUNTER: AtomicUsize = AtomicUsize::new(0);
        struct CountDrop;
        impl Drop for CountDrop {
            fn drop(&mut self) {
                DROP_COUNTER.fetch_add(1, Ordering::SeqCst);
            }
        }
        let result: Result<[CountDrop; 4], usize> = super::array_try_from_fn(|idx| {
            if idx == 2 {
                Err(idx)
            } else {
                Ok(CountDrop)
            }
        });
        assert!(matches!(result, Err(2)));
        assert_eq!(DROP_COUNTER.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_extract_bytearray_to_array() {
        Python::with_gil(|py| {
            let v: [u8; 33] = py
                .eval_bound(
                    "bytearray(b'abcabcabcabcabcabcabcabcabcabcabc')",
                    None,
                    None,
                )
                .unwrap()
                .extract()
                .unwrap();
            assert_eq!(&v, b"abcabcabcabcabcabcabcabcabcabcabc");
        })
    }

    #[test]
    fn test_extract_small_bytearray_to_array() {
        Python::with_gil(|py| {
            let v: [u8; 3] = py
                .eval_bound("bytearray(b'abc')", None, None)
                .unwrap()
                .extract()
                .unwrap();
            assert_eq!(&v, b"abc");
        });
    }

    #[test]
    fn test_topyobject_array_conversion() {
        Python::with_gil(|py| {
            let array: [f32; 4] = [0.0, -16.0, 16.0, 42.0];
            let pyobject = array.to_object(py);
            let pylist = pyobject.downcast_bound::<PyList>(py).unwrap();
            assert_eq!(pylist.get_item(0).unwrap().extract::<f32>().unwrap(), 0.0);
            assert_eq!(pylist.get_item(1).unwrap().extract::<f32>().unwrap(), -16.0);
            assert_eq!(pylist.get_item(2).unwrap().extract::<f32>().unwrap(), 16.0);
            assert_eq!(pylist.get_item(3).unwrap().extract::<f32>().unwrap(), 42.0);
        });
    }

    /// A sequence longer than the target array is rejected with a `ValueError` that
    /// reports both the expected and the actual length.
    #[test]
    fn test_extract_invalid_sequence_length() {
        Python::with_gil(|py| {
            let v: PyResult<[u8; 3]> = py
                .eval_bound("bytearray(b'abcdefg')", None, None)
                .unwrap()
                .extract();
            assert_eq!(
                v.unwrap_err().to_string(),
                "ValueError: expected a sequence of length 3 (got 7)"
            );
        })
    }

    #[test]
    fn test_intopy_array_conversion() {
        Python::with_gil(|py| {
            let array: [f32; 4] = [0.0, -16.0, 16.0, 42.0];
            let pyobject = array.into_py(py);
            let pylist = pyobject.downcast_bound::<PyList>(py).unwrap();
            assert_eq!(pylist.get_item(0).unwrap().extract::<f32>().unwrap(), 0.0);
            assert_eq!(pylist.get_item(1).unwrap().extract::<f32>().unwrap(), -16.0);
            assert_eq!(pylist.get_item(2).unwrap().extract::<f32>().unwrap(), 16.0);
            assert_eq!(pylist.get_item(3).unwrap().extract::<f32>().unwrap(), 42.0);
        });
    }

    /// An int extracts fine as `i32` but must fail the sequence check for arrays.
    #[test]
    fn test_extract_non_iterable_to_array() {
        Python::with_gil(|py| {
            let v = py.eval_bound("42", None, None).unwrap();
            v.extract::<i32>().unwrap();
            v.extract::<[i32; 1]>().unwrap_err();
        });
    }

    #[cfg(feature = "macros")]
    #[test]
    fn test_pyclass_intopy_array_conversion() {
        #[crate::pyclass(crate = "crate")]
        struct Foo;

        Python::with_gil(|py| {
            let array: [Foo; 8] = [Foo, Foo, Foo, Foo, Foo, Foo, Foo, Foo];
            let pyobject = array.into_py(py);
            let list = pyobject.downcast_bound::<PyList>(py).unwrap();
            let _bound = list.get_item(4).unwrap().downcast::<Foo>().unwrap();
        });
    }

    /// Runs `f` and catches any panic, with the panic hook swapped for a no-op while `f`
    /// runs so the expected panic message does not clutter test output.
    ///
    /// Note: the panic hook is process-global, so panics from other tests that run
    /// concurrently during this window are silenced as well.
    fn catch_unwind_silent<F, R>(f: F) -> std::thread::Result<R>
    where
        F: FnOnce() -> R + panic::UnwindSafe,
    {
        let prev_hook = panic::take_hook();
        panic::set_hook(Box::new(|_| {}));
        let result = panic::catch_unwind(f);
        panic::set_hook(prev_hook);
        result
    }
}