1use std::{
8    borrow::Cow,
9    convert::TryInto,
10    io::Write,
11    net::{TcpStream, ToSocketAddrs}, time::Duration,
12};
13
14use crate::{
15    pdu::{
16        reader::{read_pdu, DEFAULT_MAX_PDU, MAXIMUM_PDU_SIZE},
17        writer::write_pdu,
18        AbortRQSource, AssociationAC, AssociationRJ, AssociationRQ, Pdu,
19        PresentationContextProposed, PresentationContextResult, PresentationContextResultReason,
20        UserIdentity, UserIdentityType, UserVariableItem,
21    },
22    AeAddr, IMPLEMENTATION_CLASS_UID, IMPLEMENTATION_VERSION_NAME,
23};
24use snafu::{ensure, Backtrace, ResultExt, Snafu};
25
26use super::{
27    pdata::{PDataReader, PDataWriter},
28    uid::trim_uid,
29};
30
31#[derive(Debug, Snafu)]
32#[non_exhaustive]
33pub enum Error {
34    MissingAbstractSyntax { backtrace: Backtrace },
36
37    Connect {
39        source: std::io::Error,
40        backtrace: Backtrace,
41    },
42    
43    SetReadTimeout{
45        source: std::io::Error,
46        backtrace: Backtrace,
47    },
48
49    SetWriteTimeout{
51        source: std::io::Error,
52        backtrace: Backtrace,
53    },
54
55    SendRequest {
57        #[snafu(backtrace)]
58        source: crate::pdu::writer::Error,
59    },
60
61    ReceiveResponse {
63        #[snafu(backtrace)]
64        source: crate::pdu::reader::Error,
65    },
66
67    #[snafu(display("unexpected response from server `{:?}`", pdu))]
68    #[non_exhaustive]
69    UnexpectedResponse {
70        pdu: Box<Pdu>,
72    },
73
74    #[snafu(display("unknown response from server `{:?}`", pdu))]
75    #[non_exhaustive]
76    UnknownResponse {
77        pdu: Box<Pdu>,
79    },
80
81    #[snafu(display("protocol version mismatch: expected {}, got {}", expected, got))]
82    ProtocolVersionMismatch {
83        expected: u16,
84        got: u16,
85        backtrace: Backtrace,
86    },
87
88    #[snafu(display("association rejected by the server: {}", association_rj.source))]
89    Rejected {
90        association_rj: AssociationRJ,
91        backtrace: Backtrace,
92    },
93
94    NoAcceptedPresentationContexts { backtrace: Backtrace },
96
97    #[non_exhaustive]
99    Send {
100        #[snafu(backtrace)]
101        source: crate::pdu::writer::Error,
102    },
103
104    #[non_exhaustive]
106    WireSend {
107        source: std::io::Error,
108        backtrace: Backtrace,
109    },
110
111    #[snafu(display(
112        "PDU is too large ({} bytes) to be sent to the remote application entity",
113        length
114    ))]
115    #[non_exhaustive]
116    SendTooLongPdu { length: usize, backtrace: Backtrace },
117
118    #[non_exhaustive]
120    Receive {
121        #[snafu(backtrace)]
122        source: crate::pdu::reader::Error,
123    },
124}
125
126pub type Result<T, E = Error> = std::result::Result<T, E>;
127
/// Builder-style set of options for requesting a DICOM association
/// as a service class user (client).
///
/// `Debug` is implemented manually so that credentials (password,
/// Kerberos service ticket, SAML assertion, JWT) are redacted when
/// the options are logged.
#[derive(Clone)]
pub struct ClientAssociationOptions<'a> {
    /// AE title of this node (default `THIS-SCU`).
    calling_ae_title: Cow<'a, str>,
    /// AE title of the remote node; `None` means "use the address' AE title
    /// if any, otherwise `ANY-SCP`".
    called_ae_title: Option<Cow<'a, str>>,
    /// Application context name UID sent in the association request.
    application_context_name: Cow<'a, str>,
    /// Presentation contexts to propose, as (abstract syntax UID, transfer syntax UIDs).
    presentation_contexts: Vec<(Cow<'a, str>, Vec<Cow<'a, str>>)>,
    /// Protocol version requested; the acceptor's response must match it.
    protocol_version: u16,
    /// Maximum PDU length this node advertises and accepts when reading.
    max_pdu_length: u32,
    /// Strict mode flag, forwarded to the PDU reader.
    strict: bool,
    /// User name for user identity negotiation.
    username: Option<Cow<'a, str>>,
    /// Password for user identity negotiation (only used with a user name).
    password: Option<Cow<'a, str>>,
    /// Kerberos service ticket for user identity negotiation.
    kerberos_service_ticket: Option<Cow<'a, str>>,
    /// SAML assertion for user identity negotiation.
    saml_assertion: Option<Cow<'a, str>>,
    /// JSON web token for user identity negotiation.
    jwt: Option<Cow<'a, str>>,
    /// Read timeout applied to the TCP socket; `None` means blocking reads.
    read_timeout: Option<Duration>,
    /// Write timeout applied to the TCP socket; `None` means blocking writes.
    write_timeout: Option<Duration>,
}

impl<'a> std::fmt::Debug for ClientAssociationOptions<'a> {
    /// Formats all options, replacing secret credential values with `<redacted>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keeps the presence information (Some/None) but never the secret itself.
        fn redacted<T>(secret: &Option<T>) -> Option<&'static str> {
            secret.as_ref().map(|_| "<redacted>")
        }

        f.debug_struct("ClientAssociationOptions")
            .field("calling_ae_title", &self.calling_ae_title)
            .field("called_ae_title", &self.called_ae_title)
            .field("application_context_name", &self.application_context_name)
            .field("presentation_contexts", &self.presentation_contexts)
            .field("protocol_version", &self.protocol_version)
            .field("max_pdu_length", &self.max_pdu_length)
            .field("strict", &self.strict)
            .field("username", &self.username)
            .field("password", &redacted(&self.password))
            .field(
                "kerberos_service_ticket",
                &redacted(&self.kerberos_service_ticket),
            )
            .field("saml_assertion", &redacted(&self.saml_assertion))
            .field("jwt", &redacted(&self.jwt))
            .field("read_timeout", &self.read_timeout)
            .field("write_timeout", &self.write_timeout)
            .finish()
    }
}
197
198impl<'a> Default for ClientAssociationOptions<'a> {
199    fn default() -> Self {
200        ClientAssociationOptions {
201            calling_ae_title: "THIS-SCU".into(),
203            called_ae_title: None,
205            application_context_name: "1.2.840.10008.3.1.1.1".into(),
207            presentation_contexts: Vec::new(),
209            protocol_version: 1,
210            max_pdu_length: crate::pdu::reader::DEFAULT_MAX_PDU,
211            strict: true,
212            username: None,
213            password: None,
214            kerberos_service_ticket: None,
215            saml_assertion: None,
216            jwt: None,
217            read_timeout: None,
218            write_timeout: None,
219        }
220    }
221}
222
223impl<'a> ClientAssociationOptions<'a> {
224    pub fn new() -> Self {
226        Self::default()
227    }
228
229    pub fn calling_ae_title<T>(mut self, calling_ae_title: T) -> Self
234    where
235        T: Into<Cow<'a, str>>,
236    {
237        self.calling_ae_title = calling_ae_title.into();
238        self
239    }
240
241    pub fn called_ae_title<T>(mut self, called_ae_title: T) -> Self
248    where
249        T: Into<Cow<'a, str>>,
250    {
251        let cae = called_ae_title.into();
252        if cae.is_empty() {
253            self.called_ae_title = None;
254        } else {
255            self.called_ae_title = Some(cae);
256        }
257        self
258    }
259
260    pub fn with_presentation_context<T>(
263        mut self,
264        abstract_syntax_uid: T,
265        transfer_syntax_uids: Vec<T>,
266    ) -> Self
267    where
268        T: Into<Cow<'a, str>>,
269    {
270        let transfer_syntaxes: Vec<Cow<'a, str>> = transfer_syntax_uids
271            .into_iter()
272            .map(|t| trim_uid(t.into()))
273            .collect();
274        self.presentation_contexts
275            .push((trim_uid(abstract_syntax_uid.into()), transfer_syntaxes));
276        self
277    }
278
279    pub fn with_abstract_syntax<T>(self, abstract_syntax_uid: T) -> Self
283    where
284        T: Into<Cow<'a, str>>,
285    {
286        let default_transfer_syntaxes: Vec<Cow<'a, str>> =
287            vec!["1.2.840.10008.1.2.1".into(), "1.2.840.10008.1.2".into()];
288        self.with_presentation_context(abstract_syntax_uid.into(), default_transfer_syntaxes)
289    }
290
291    pub fn max_pdu_length(mut self, value: u32) -> Self {
294        self.max_pdu_length = value;
295        self
296    }
297
298    pub fn strict(mut self, strict: bool) -> Self {
302        self.strict = strict;
303        self
304    }
305
306    pub fn username<T>(mut self, username: T) -> Self
308    where
309        T: Into<Cow<'a, str>>,
310    {
311        let username = username.into();
312        if username.is_empty() {
313            self.username = None;
314        } else {
315            self.username = Some(username);
316            self.saml_assertion = None;
317            self.jwt = None;
318            self.kerberos_service_ticket = None;
319        }
320        self
321    }
322
323    pub fn password<T>(mut self, password: T) -> Self
325    where
326        T: Into<Cow<'a, str>>,
327    {
328        let password = password.into();
329        if password.is_empty() {
330            self.password = None;
331        } else {
332            self.password = Some(password);
333            self.saml_assertion = None;
334            self.jwt = None;
335            self.kerberos_service_ticket = None;
336        }
337        self
338    }
339
340    pub fn username_password<T, U>(mut self, username: T, password: U) -> Self
342    where
343        T: Into<Cow<'a, str>>,
344        U: Into<Cow<'a, str>>,
345    {
346        let username = username.into();
347        let password = password.into();
348        if username.is_empty() {
349            self.username = None;
350            self.password = None;
351        } else {
352            self.username = Some(username);
353            self.password = Some(password);
354            self.saml_assertion = None;
355            self.jwt = None;
356            self.kerberos_service_ticket = None;
357        }
358        self
359    }
360
361    pub fn kerberos_service_ticket<T>(mut self, kerberos_service_ticket: T) -> Self
363    where
364        T: Into<Cow<'a, str>>,
365    {
366        let kerberos_service_ticket = kerberos_service_ticket.into();
367        if kerberos_service_ticket.is_empty() {
368            self.kerberos_service_ticket = None;
369        } else {
370            self.kerberos_service_ticket = Some(kerberos_service_ticket);
371            self.username = None;
372            self.password = None;
373            self.saml_assertion = None;
374            self.jwt = None;
375        }
376        self
377    }
378
379    pub fn saml_assertion<T>(mut self, saml_assertion: T) -> Self
381    where
382        T: Into<Cow<'a, str>>,
383    {
384        let saml_assertion = saml_assertion.into();
385        if saml_assertion.is_empty() {
386            self.saml_assertion = None;
387        } else {
388            self.saml_assertion = Some(saml_assertion);
389            self.username = None;
390            self.password = None;
391            self.jwt = None;
392            self.kerberos_service_ticket = None;
393        }
394        self
395    }
396
397    pub fn jwt<T>(mut self, jwt: T) -> Self
399    where
400        T: Into<Cow<'a, str>>,
401    {
402        let jwt = jwt.into();
403        if jwt.is_empty() {
404            self.jwt = None;
405        } else {
406            self.jwt = Some(jwt);
407            self.username = None;
408            self.password = None;
409            self.saml_assertion = None;
410            self.kerberos_service_ticket = None;
411        }
412        self
413    }
414
415    pub fn establish<A: ToSocketAddrs>(self, address: A) -> Result<ClientAssociation> {
419        self.establish_impl(AeAddr::new_socket_addr(address))
420    }
421
422    pub fn establish_with(self, ae_address: &str) -> Result<ClientAssociation> {
446        match ae_address.try_into() {
447            Ok(ae_address) => self.establish_impl(ae_address),
448            Err(_) => self.establish_impl(AeAddr::new_socket_addr(ae_address)),
449        }
450    }
451
452    pub fn read_timeout(self, timeout: Duration) -> Self {
454        Self {
455            read_timeout: Some(timeout),
456            ..self
457        }
458    }
459
460    pub fn write_timeout(self, timeout: Duration) -> Self {
462        Self {
463            write_timeout: Some(timeout),
464            ..self
465        }
466    }
467
468    fn establish_impl<T>(self, ae_address: AeAddr<T>) -> Result<ClientAssociation>
469    where
470        T: ToSocketAddrs,
471    {
472        let ClientAssociationOptions {
473            calling_ae_title,
474            called_ae_title,
475            application_context_name,
476            presentation_contexts,
477            protocol_version,
478            max_pdu_length,
479            strict,
480            username,
481            password,
482            kerberos_service_ticket,
483            saml_assertion,
484            jwt,
485            read_timeout,
486            write_timeout
487        } = self;
488
489        ensure!(
492            !presentation_contexts.is_empty(),
493            MissingAbstractSyntaxSnafu
494        );
495
496        let called_ae_title: &str = match (&called_ae_title, ae_address.ae_title()) {
498            (Some(aec), Some(_)) => {
499                tracing::warn!(
500                    "Option `called_ae_title` overrides the AE title to `{}`",
501                    aec
502                );
503                aec
504            }
505            (Some(aec), None) => aec,
506            (None, Some(aec)) => aec,
507            (None, None) => "ANY-SCP",
508        };
509
510        let presentation_contexts: Vec<_> = presentation_contexts
511            .into_iter()
512            .enumerate()
513            .map(|(i, presentation_context)| PresentationContextProposed {
514                id: (2 * i + 1) as u8,
515                abstract_syntax: presentation_context.0.to_string(),
516                transfer_syntaxes: presentation_context
517                    .1
518                    .iter()
519                    .map(|uid| uid.to_string())
520                    .collect(),
521            })
522            .collect();
523
524        let mut user_variables = vec![
525            UserVariableItem::MaxLength(max_pdu_length),
526            UserVariableItem::ImplementationClassUID(IMPLEMENTATION_CLASS_UID.to_string()),
527            UserVariableItem::ImplementationVersionName(IMPLEMENTATION_VERSION_NAME.to_string()),
528        ];
529
530        if let Some(user_identity) = Self::determine_user_identity(
531            username,
532            password,
533            kerberos_service_ticket,
534            saml_assertion,
535            jwt,
536        ) {
537            user_variables.push(UserVariableItem::UserIdentityItem(user_identity));
538        }
539
540        let msg = Pdu::AssociationRQ(AssociationRQ {
541            protocol_version,
542            calling_ae_title: calling_ae_title.to_string(),
543            called_ae_title: called_ae_title.to_string(),
544            application_context_name: application_context_name.to_string(),
545            presentation_contexts,
546            user_variables,
547        });
548
549        let mut socket = std::net::TcpStream::connect(ae_address)
550            .context(ConnectSnafu)?;
551        socket.set_read_timeout(read_timeout)
552            .context(SetReadTimeoutSnafu)?;
553        socket.set_write_timeout(write_timeout)
554            .context(SetWriteTimeoutSnafu)?;
555        let mut buffer: Vec<u8> = Vec::with_capacity(max_pdu_length as usize);
556        write_pdu(&mut buffer, &msg).context(SendRequestSnafu)?;
559        socket.write_all(&buffer).context(WireSendSnafu)?;
560        buffer.clear();
561        let msg =
563            read_pdu(&mut socket, MAXIMUM_PDU_SIZE, self.strict).context(ReceiveResponseSnafu)?;
564
565        match msg {
566            Pdu::AssociationAC(AssociationAC {
567                protocol_version: protocol_version_scp,
568                application_context_name: _,
569                presentation_contexts: presentation_contexts_scp,
570                calling_ae_title: _,
571                called_ae_title: _,
572                user_variables,
573            }) => {
574                ensure!(
575                    protocol_version == protocol_version_scp,
576                    ProtocolVersionMismatchSnafu {
577                        expected: protocol_version,
578                        got: protocol_version_scp,
579                    }
580                );
581
582                let acceptor_max_pdu_length = user_variables
583                    .iter()
584                    .find_map(|item| match item {
585                        UserVariableItem::MaxLength(len) => Some(*len),
586                        _ => None,
587                    })
588                    .unwrap_or(DEFAULT_MAX_PDU);
589
590                let acceptor_max_pdu_length = if acceptor_max_pdu_length == 0 {
592                    MAXIMUM_PDU_SIZE
593                } else {
594                    acceptor_max_pdu_length
595                };
596
597                let presentation_contexts: Vec<_> = presentation_contexts_scp
598                    .into_iter()
599                    .filter(|c| c.reason == PresentationContextResultReason::Acceptance)
600                    .collect();
601                if presentation_contexts.is_empty() {
602                    let _ = write_pdu(
604                        &mut buffer,
605                        &Pdu::AbortRQ {
606                            source: AbortRQSource::ServiceUser,
607                        },
608                    );
609                    let _ = socket.write_all(&buffer);
610                    buffer.clear();
611                    return NoAcceptedPresentationContextsSnafu.fail();
612                }
613                Ok(ClientAssociation {
614                    presentation_contexts,
615                    requestor_max_pdu_length: max_pdu_length,
616                    acceptor_max_pdu_length,
617                    socket,
618                    buffer,
619                    strict,
620                })
621            }
622            Pdu::AssociationRJ(association_rj) => RejectedSnafu { association_rj }.fail(),
623            pdu @ Pdu::AbortRQ { .. }
624            | pdu @ Pdu::ReleaseRQ { .. }
625            | pdu @ Pdu::AssociationRQ { .. }
626            | pdu @ Pdu::PData { .. }
627            | pdu @ Pdu::ReleaseRP { .. } => {
628                let _ = write_pdu(
630                    &mut buffer,
631                    &Pdu::AbortRQ {
632                        source: AbortRQSource::ServiceUser,
633                    },
634                );
635                let _ = socket.write_all(&buffer);
636                UnexpectedResponseSnafu { pdu }.fail()
637            }
638            pdu @ Pdu::Unknown { .. } => {
639                let _ = write_pdu(
641                    &mut buffer,
642                    &Pdu::AbortRQ {
643                        source: AbortRQSource::ServiceUser,
644                    },
645                );
646                let _ = socket.write_all(&buffer);
647                UnknownResponseSnafu { pdu }.fail()
648            }
649        }
650    }
651
652    fn determine_user_identity<T>(
653        username: Option<T>,
654        password: Option<T>,
655        kerberos_service_ticket: Option<T>,
656        saml_assertion: Option<T>,
657        jwt: Option<T>,
658    ) -> Option<UserIdentity>
659    where
660        T: Into<Cow<'a, str>>,
661    {
662        if let Some(username) = username {
663            if let Some(password) = password {
664                return Some(UserIdentity::new(
665                    false,
666                    UserIdentityType::UsernamePassword,
667                    username.into().as_bytes().to_vec(),
668                    password.into().as_bytes().to_vec(),
669                ));
670            } else {
671                return Some(UserIdentity::new(
672                    false,
673                    UserIdentityType::Username,
674                    username.into().as_bytes().to_vec(),
675                    vec![],
676                ));
677            }
678        }
679
680        if let Some(kerberos_service_ticket) = kerberos_service_ticket {
681            return Some(UserIdentity::new(
682                false,
683                UserIdentityType::KerberosServiceTicket,
684                kerberos_service_ticket.into().as_bytes().to_vec(),
685                vec![],
686            ));
687        }
688
689        if let Some(saml_assertion) = saml_assertion {
690            return Some(UserIdentity::new(
691                false,
692                UserIdentityType::SamlAssertion,
693                saml_assertion.into().as_bytes().to_vec(),
694                vec![],
695            ));
696        }
697
698        if let Some(jwt) = jwt {
699            return Some(UserIdentity::new(
700                false,
701                UserIdentityType::Jwt,
702                jwt.into().as_bytes().to_vec(),
703                vec![],
704            ));
705        }
706
707        None
708    }
709}
710
711#[derive(Debug)]
725pub struct ClientAssociation {
726    presentation_contexts: Vec<PresentationContextResult>,
729    requestor_max_pdu_length: u32,
731    acceptor_max_pdu_length: u32,
733    socket: TcpStream,
735    buffer: Vec<u8>,
737    strict: bool,
739}
740
741impl ClientAssociation {
742    pub fn presentation_contexts(&self) -> &[PresentationContextResult] {
744        &self.presentation_contexts
745    }
746
747    pub fn acceptor_max_pdu_length(&self) -> u32 {
750        self.acceptor_max_pdu_length
751    }
752
753    pub fn requestor_max_pdu_length(&self) -> u32 {
760        self.requestor_max_pdu_length
761    }
762
763    pub fn send(&mut self, msg: &Pdu) -> Result<()> {
765        self.buffer.clear();
766        write_pdu(&mut self.buffer, msg).context(SendSnafu)?;
767        if self.buffer.len() > self.acceptor_max_pdu_length as usize {
768            return SendTooLongPduSnafu {
769                length: self.buffer.len(),
770            }
771            .fail();
772        }
773        self.socket.write_all(&self.buffer).context(WireSendSnafu)
774    }
775
776    pub fn receive(&mut self) -> Result<Pdu> {
778        read_pdu(&mut self.socket, self.requestor_max_pdu_length, self.strict).context(ReceiveSnafu)
779    }
780
781    pub fn release(mut self) -> Result<()> {
784        let out = self.release_impl();
785        let _ = self.socket.shutdown(std::net::Shutdown::Both);
786        out
787    }
788
789    pub fn abort(mut self) -> Result<()> {
792        let pdu = Pdu::AbortRQ {
793            source: AbortRQSource::ServiceUser,
794        };
795        let out = self.send(&pdu);
796        let _ = self.socket.shutdown(std::net::Shutdown::Both);
797        out
798    }
799
800    pub fn inner_stream(&mut self) -> &mut TcpStream {
810        &mut self.socket
811    }
812
813    pub fn send_pdata(&mut self, presentation_context_id: u8) -> PDataWriter<&mut TcpStream> {
819        PDataWriter::new(
820            &mut self.socket,
821            presentation_context_id,
822            self.acceptor_max_pdu_length,
823        )
824    }
825
826    pub fn receive_pdata(&mut self) -> PDataReader<&mut TcpStream> {
832        PDataReader::new(&mut self.socket, self.requestor_max_pdu_length)
833    }
834
835    fn release_impl(&mut self) -> Result<()> {
841        let pdu = Pdu::ReleaseRQ;
842        self.send(&pdu)?;
843        let pdu = read_pdu(&mut self.socket, self.requestor_max_pdu_length, self.strict)
844            .context(ReceiveSnafu)?;
845
846        match pdu {
847            Pdu::ReleaseRP => {}
848            pdu @ Pdu::AbortRQ { .. }
849            | pdu @ Pdu::AssociationAC { .. }
850            | pdu @ Pdu::AssociationRJ { .. }
851            | pdu @ Pdu::AssociationRQ { .. }
852            | pdu @ Pdu::PData { .. }
853            | pdu @ Pdu::ReleaseRQ { .. } => return UnexpectedResponseSnafu { pdu }.fail(),
854            pdu @ Pdu::Unknown { .. } => return UnknownResponseSnafu { pdu }.fail(),
855        }
856        Ok(())
857    }
858}
859
860impl Drop for ClientAssociation {
862    fn drop(&mut self) {
863        let _ = self.release_impl();
864        let _ = self.socket.shutdown(std::net::Shutdown::Both);
865    }
866}