1use std::{
2    collections::VecDeque,
3    io::{Read, Write},
4};
5
6use tracing::warn;
7
8use crate::{pdu::reader::PDU_HEADER_SIZE, read_pdu, Pdu};
9
/// A writer which accumulates bytes in an internal buffer
/// and sends them to the underlying stream as P-Data PDUs.
///
/// The buffer always starts with a 12-byte P-Data header template,
/// which is completed right before each PDU is written out.
/// The value is marked `#[must_use]` because the last fragment
/// is only emitted on `finish` (or when the writer is dropped).
#[must_use]
pub struct PDataWriter<W>
where
    W: Write,
{
    /// the PDU under construction: 12 header bytes followed by payload
    buffer: Vec<u8>,
    /// the destination of the encoded PDUs
    stream: W,
    /// the maximum number of payload bytes carried by a single PDU
    max_data_len: u32,
}
59
60impl<W> PDataWriter<W>
61where
62    W: Write,
63{
64    pub(crate) fn new(stream: W, presentation_context_id: u8, max_pdu_length: u32) -> Self {
68        let max_data_length = calculate_max_data_len_single(max_pdu_length);
69        let mut buffer = Vec::with_capacity((max_data_length + PDU_HEADER_SIZE) as usize);
70        buffer.extend([
72            0x04,
74            0x00,
75            0xFF,
77            0xFF,
78            0xFF,
79            0xFF,
80            0xFF,
82            0xFF,
83            0xFF,
84            0xFF,
85            presentation_context_id,
87            0xFF,
89        ]);
90
91        PDataWriter {
92            stream,
93            max_data_len: max_data_length,
94            buffer,
95        }
96    }
97
98    pub fn finish(mut self) -> std::io::Result<()> {
103        self.finish_impl()?;
104        Ok(())
105    }
106
107    fn setup_pdata_header(&mut self, is_last: bool) {
109        let data_len = (self.buffer.len() - 12) as u32;
110
111        let pdu_len = data_len + 4 + 2;
113        let pdu_len_bytes = pdu_len.to_be_bytes();
114
115        self.buffer[2] = pdu_len_bytes[0];
116        self.buffer[3] = pdu_len_bytes[1];
117        self.buffer[4] = pdu_len_bytes[2];
118        self.buffer[5] = pdu_len_bytes[3];
119
120        let pdv_data_len = data_len + 2;
122        let data_len_bytes = pdv_data_len.to_be_bytes();
123
124        self.buffer[6] = data_len_bytes[0];
125        self.buffer[7] = data_len_bytes[1];
126        self.buffer[8] = data_len_bytes[2];
127        self.buffer[9] = data_len_bytes[3];
128
129        self.buffer[11] = if is_last { 0x02 } else { 0x00 };
131    }
132
133    fn finish_impl(&mut self) -> std::io::Result<()> {
134        if !self.buffer.is_empty() {
135            self.setup_pdata_header(true);
137            self.stream.write_all(&self.buffer[..])?;
138            self.buffer.clear();
141        }
142        Ok(())
143    }
144
145    fn dispatch_pdu(&mut self) -> std::io::Result<()> {
150        debug_assert!(self.buffer.len() >= 12);
151        self.setup_pdata_header(false);
153        self.stream.write_all(&self.buffer)?;
154
155        self.buffer.truncate(12);
157
158        Ok(())
159    }
160}
161
162impl<W> Write for PDataWriter<W>
163where
164    W: Write,
165{
166    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
167        let total_len = self.max_data_len as usize + 12;
168        if self.buffer.len() + buf.len() <= total_len {
169            self.buffer.extend(buf);
171            Ok(buf.len())
172        } else {
173            let buf = &buf[..total_len - self.buffer.len()];
176            self.buffer.extend(buf);
177            debug_assert_eq!(self.buffer.len(), total_len);
178            self.dispatch_pdu()?;
179            Ok(buf.len())
180        }
181    }
182
183    fn flush(&mut self) -> std::io::Result<()> {
184        Ok(())
186    }
187}
188
189impl<W> Drop for PDataWriter<W>
194where
195    W: Write,
196{
197    fn drop(&mut self) {
198        let _ = self.finish_impl();
199    }
200}
201
/// A reader which exposes the payload of incoming P-Data PDUs
/// as a continuous stream of bytes.
#[must_use]
pub struct PDataReader<R> {
    /// payload bytes received but not yet handed out to the caller
    buffer: VecDeque<u8>,
    /// the source of PDUs
    stream: R,
    /// the presentation context id of the first value received,
    /// `None` until then; later mismatches are only logged
    presentation_context_id: Option<u8>,
    /// the limit passed on to `read_pdu` when fetching a PDU
    max_data_length: u32,
    /// whether no more PDUs should be fetched from the stream
    last_pdu: bool,
}
243
244impl<R> PDataReader<R>
245where
246    R: Read,
247{
248    pub fn new(stream: R, max_data_length: u32) -> Self {
249        PDataReader {
250            buffer: VecDeque::with_capacity(max_data_length as usize),
251            stream,
252            presentation_context_id: None,
253            max_data_length,
254            last_pdu: false,
255        }
256    }
257
258    pub fn stop_receiving(&mut self) -> std::io::Result<()> {
264        self.last_pdu = true;
265        Ok(())
266    }
267}
268
269impl<R> Read for PDataReader<R>
270where
271    R: Read,
272{
273    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
274        if self.buffer.is_empty() {
275            if self.last_pdu {
276                return Ok(0);
278            }
279
280            let pdu = read_pdu(&mut self.stream, self.max_data_length, false)
281                .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
282
283            match pdu {
284                Pdu::PData { data } => {
285                    for pdata_value in data {
286                        self.presentation_context_id = match self.presentation_context_id {
287                            None => Some(pdata_value.presentation_context_id),
288                            Some(cid) if cid == pdata_value.presentation_context_id => Some(cid),
289                            Some(cid) => {
290                                warn!("Received PData value of presentation context {}, but should be {}", pdata_value.presentation_context_id, cid);
291                                Some(cid)
292                            }
293                        };
294                        self.buffer.extend(pdata_value.data);
295                        self.last_pdu = pdata_value.is_last;
296                    }
297                }
298                _ => {
299                    return Err(std::io::Error::new(
300                        std::io::ErrorKind::UnexpectedEof,
301                        "Unexpected PDU type",
302                    ))
303                }
304            }
305        }
306        Read::read(&mut self.buffer, buf)
307    }
308}
309
/// Determine the maximum number of payload bytes
/// which fit in a single P-Data PDU with the given PDU length.
///
/// Subtracts the 4-byte presentation data value length field
/// and the 2 bytes taken by the presentation context id
/// and the message control header.
/// Lengths too small to hold those 6 bytes yield 0
/// instead of underflowing.
#[inline]
fn calculate_max_data_len_single(pdu_len: u32) -> u32 {
    pdu_len.saturating_sub(4 + 2)
}
319
#[cfg(test)]
mod tests {
    use std::collections::VecDeque;
    use std::io::{Read, Write};

    use crate::pdu::reader::{read_pdu, MINIMUM_PDU_SIZE, PDU_HEADER_SIZE};
    use crate::pdu::Pdu;
    use crate::pdu::{PDataValue, PDataValueType};
    use crate::write_pdu;

    use super::{PDataReader, PDataWriter};

    /// A payload smaller than one PDU is sent as a single P-Data PDU.
    #[test]
    fn test_write_pdata_and_finish() {
        let presentation_context_id = 12;
        let payload: Vec<u8> = (0..64).collect();

        let mut sink = Vec::new();
        {
            let mut writer = PDataWriter::new(&mut sink, presentation_context_id, MINIMUM_PDU_SIZE);
            writer.write_all(&payload).unwrap();
            writer.finish().unwrap();
        }

        let mut remaining = &sink[..];
        let same_pdu = read_pdu(&mut remaining, MINIMUM_PDU_SIZE, true).unwrap();

        // the PDU read back must hold the same data value
        match same_pdu {
            Pdu::PData { data: values } => {
                let value = &values[0];

                assert_eq!(value.value_type, PDataValueType::Data);
                assert_eq!(value.presentation_context_id, presentation_context_id);
                assert_eq!(value.data.len(), 64);
                assert_eq!(value.data, payload);
            }
            pdu => panic!("Expected PData, got {:?}", pdu),
        }

        // nothing else was written
        assert_eq!(remaining.len(), 0);
    }

    /// A payload larger than one PDU is split over several P-Data PDUs.
    #[test]
    fn test_write_large_pdata_and_finish() {
        let presentation_context_id = 32;

        let my_data: Vec<_> = (0..9000).map(|x: u32| x as u8).collect();

        let mut sink = Vec::new();
        {
            let mut writer = PDataWriter::new(&mut sink, presentation_context_id, MINIMUM_PDU_SIZE);
            writer.write_all(&my_data).unwrap();
            writer.finish().unwrap();
        }

        let mut remaining = &sink[..];
        let first = read_pdu(&mut remaining, MINIMUM_PDU_SIZE, true).unwrap();
        let second = read_pdu(&mut remaining, MINIMUM_PDU_SIZE, true).unwrap();
        let third = read_pdu(&mut remaining, MINIMUM_PDU_SIZE, true).unwrap();

        match (first, second, third) {
            (
                Pdu::PData { data: values_1 },
                Pdu::PData { data: values_2 },
                Pdu::PData { data: values_3 },
            ) => {
                // one data value per PDU
                assert_eq!(values_1.len(), 1);
                let value_1 = &values_1[0];
                assert_eq!(values_2.len(), 1);
                let value_2 = &values_2[0];
                assert_eq!(values_3.len(), 1);
                let value_3 = &values_3[0];

                assert_eq!(value_1.value_type, PDataValueType::Data);
                assert_eq!(value_2.value_type, PDataValueType::Data);
                assert_eq!(value_1.presentation_context_id, presentation_context_id);
                assert_eq!(value_2.presentation_context_id, presentation_context_id);

                // the first two PDUs are filled to the limit
                let full_len = (MINIMUM_PDU_SIZE - PDU_HEADER_SIZE) as usize;
                assert_eq!(value_1.data.len(), full_len);
                assert_eq!(value_2.data.len(), full_len);
                assert_eq!(value_3.data.len(), 820);

                // the first PDU holds the start of the payload
                assert_eq!(
                    &value_1.data[..],
                    (0..MINIMUM_PDU_SIZE - PDU_HEADER_SIZE)
                        .map(|x| x as u8)
                        .collect::<Vec<_>>()
                );
                assert_eq!(
                    value_1.data.len() + value_2.data.len() + value_3.data.len(),
                    9000
                );

                // the fragments put together give back the payload
                let all_data: Vec<u8> =
                    [&value_1.data[..], &value_2.data[..], &value_3.data[..]].concat();
                assert_eq!(all_data, my_data);
            }
            x => panic!("Expected 3 PDatas, got {:?}", x),
        }

        // nothing else was written
        assert_eq!(remaining.len(), 0);
    }

    /// A payload spread over several P-Data PDUs is read back as one stream.
    #[test]
    fn test_read_large_pdata_and_finish() {
        let presentation_context_id = 32;

        let my_data: Vec<_> = (0..9000).map(|x: u32| x as u8).collect();

        // wrap a fragment of the payload in a P-Data PDU with a single value
        let as_pdata = |bytes: &[u8], is_last: bool| Pdu::PData {
            data: vec![PDataValue {
                value_type: PDataValueType::Data,
                data: bytes.to_vec(),
                presentation_context_id,
                is_last,
            }],
        };
        let pdus = vec![
            as_pdata(&my_data[0..3000], false),
            as_pdata(&my_data[3000..6000], false),
            as_pdata(&my_data[6000..], true),
        ];

        let mut pdu_stream = VecDeque::new();
        for pdu in &pdus {
            write_pdu(&mut pdu_stream, pdu).unwrap();
        }

        let mut received = Vec::new();
        {
            let mut reader = PDataReader::new(&mut pdu_stream, MINIMUM_PDU_SIZE);
            reader.read_to_end(&mut received).unwrap();
        }
        assert_eq!(received, my_data);
    }
}