1use std::{
8 borrow::Cow,
9 convert::TryInto,
10 io::Write,
11 net::{TcpStream, ToSocketAddrs}, time::Duration,
12};
13
14use crate::{
15 pdu::{
16 reader::{read_pdu, DEFAULT_MAX_PDU, MAXIMUM_PDU_SIZE},
17 writer::write_pdu,
18 AbortRQSource, AssociationAC, AssociationRJ, AssociationRQ, Pdu,
19 PresentationContextProposed, PresentationContextResult, PresentationContextResultReason,
20 UserIdentity, UserIdentityType, UserVariableItem,
21 },
22 AeAddr, IMPLEMENTATION_CLASS_UID, IMPLEMENTATION_VERSION_NAME,
23};
24use snafu::{ensure, Backtrace, ResultExt, Snafu};
25
26use super::{
27 pdata::{PDataReader, PDataWriter},
28 uid::trim_uid,
29};
30
31#[derive(Debug, Snafu)]
32#[non_exhaustive]
33pub enum Error {
34 MissingAbstractSyntax { backtrace: Backtrace },
36
37 Connect {
39 source: std::io::Error,
40 backtrace: Backtrace,
41 },
42
43 SetReadTimeout{
45 source: std::io::Error,
46 backtrace: Backtrace,
47 },
48
49 SetWriteTimeout{
51 source: std::io::Error,
52 backtrace: Backtrace,
53 },
54
55 SendRequest {
57 #[snafu(backtrace)]
58 source: crate::pdu::writer::Error,
59 },
60
61 ReceiveResponse {
63 #[snafu(backtrace)]
64 source: crate::pdu::reader::Error,
65 },
66
67 #[snafu(display("unexpected response from server `{:?}`", pdu))]
68 #[non_exhaustive]
69 UnexpectedResponse {
70 pdu: Box<Pdu>,
72 },
73
74 #[snafu(display("unknown response from server `{:?}`", pdu))]
75 #[non_exhaustive]
76 UnknownResponse {
77 pdu: Box<Pdu>,
79 },
80
81 #[snafu(display("protocol version mismatch: expected {}, got {}", expected, got))]
82 ProtocolVersionMismatch {
83 expected: u16,
84 got: u16,
85 backtrace: Backtrace,
86 },
87
88 #[snafu(display("association rejected by the server: {}", association_rj.source))]
89 Rejected {
90 association_rj: AssociationRJ,
91 backtrace: Backtrace,
92 },
93
94 NoAcceptedPresentationContexts { backtrace: Backtrace },
96
97 #[non_exhaustive]
99 Send {
100 #[snafu(backtrace)]
101 source: crate::pdu::writer::Error,
102 },
103
104 #[non_exhaustive]
106 WireSend {
107 source: std::io::Error,
108 backtrace: Backtrace,
109 },
110
111 #[snafu(display(
112 "PDU is too large ({} bytes) to be sent to the remote application entity",
113 length
114 ))]
115 #[non_exhaustive]
116 SendTooLongPdu { length: usize, backtrace: Backtrace },
117
118 #[non_exhaustive]
120 Receive {
121 #[snafu(backtrace)]
122 source: crate::pdu::reader::Error,
123 },
124}
125
126pub type Result<T, E = Error> = std::result::Result<T, E>;
127
/// Builder holding everything needed to request an association
/// with a remote application entity.
///
/// Values are set through the builder methods
/// and consumed when the association is established.
#[derive(Debug, Clone)]
pub struct ClientAssociationOptions<'a> {
    /// the AE title announced by this node in the association request
    calling_ae_title: Cow<'a, str>,
    /// the AE title of the target node, if explicitly chosen
    called_ae_title: Option<Cow<'a, str>>,
    /// the application context name sent in the association request
    application_context_name: Cow<'a, str>,
    /// the proposed presentation contexts,
    /// each an abstract syntax UID paired with its transfer syntax UIDs
    presentation_contexts: Vec<(Cow<'a, str>, Vec<Cow<'a, str>>)>,
    /// the protocol version proposed to (and expected back from) the server
    protocol_version: u16,
    /// the maximum PDU length announced to the server
    max_pdu_length: u32,
    /// flag forwarded to the PDU reader whenever a PDU is received
    strict: bool,
    /// user name for user identity negotiation
    username: Option<Cow<'a, str>>,
    /// password accompanying the user name
    password: Option<Cow<'a, str>>,
    /// Kerberos service ticket for user identity negotiation
    kerberos_service_ticket: Option<Cow<'a, str>>,
    /// SAML assertion for user identity negotiation
    saml_assertion: Option<Cow<'a, str>>,
    /// JSON web token for user identity negotiation
    jwt: Option<Cow<'a, str>>,
    /// read timeout applied to the socket once connected
    read_timeout: Option<Duration>,
    /// write timeout applied to the socket once connected
    write_timeout: Option<Duration>,
}
197
198impl<'a> Default for ClientAssociationOptions<'a> {
199 fn default() -> Self {
200 ClientAssociationOptions {
201 calling_ae_title: "THIS-SCU".into(),
203 called_ae_title: None,
205 application_context_name: "1.2.840.10008.3.1.1.1".into(),
207 presentation_contexts: Vec::new(),
209 protocol_version: 1,
210 max_pdu_length: crate::pdu::reader::DEFAULT_MAX_PDU,
211 strict: true,
212 username: None,
213 password: None,
214 kerberos_service_ticket: None,
215 saml_assertion: None,
216 jwt: None,
217 read_timeout: None,
218 write_timeout: None,
219 }
220 }
221}
222
223impl<'a> ClientAssociationOptions<'a> {
224 pub fn new() -> Self {
226 Self::default()
227 }
228
229 pub fn calling_ae_title<T>(mut self, calling_ae_title: T) -> Self
234 where
235 T: Into<Cow<'a, str>>,
236 {
237 self.calling_ae_title = calling_ae_title.into();
238 self
239 }
240
241 pub fn called_ae_title<T>(mut self, called_ae_title: T) -> Self
248 where
249 T: Into<Cow<'a, str>>,
250 {
251 let cae = called_ae_title.into();
252 if cae.is_empty() {
253 self.called_ae_title = None;
254 } else {
255 self.called_ae_title = Some(cae);
256 }
257 self
258 }
259
260 pub fn with_presentation_context<T>(
263 mut self,
264 abstract_syntax_uid: T,
265 transfer_syntax_uids: Vec<T>,
266 ) -> Self
267 where
268 T: Into<Cow<'a, str>>,
269 {
270 let transfer_syntaxes: Vec<Cow<'a, str>> = transfer_syntax_uids
271 .into_iter()
272 .map(|t| trim_uid(t.into()))
273 .collect();
274 self.presentation_contexts
275 .push((trim_uid(abstract_syntax_uid.into()), transfer_syntaxes));
276 self
277 }
278
279 pub fn with_abstract_syntax<T>(self, abstract_syntax_uid: T) -> Self
283 where
284 T: Into<Cow<'a, str>>,
285 {
286 let default_transfer_syntaxes: Vec<Cow<'a, str>> =
287 vec!["1.2.840.10008.1.2.1".into(), "1.2.840.10008.1.2".into()];
288 self.with_presentation_context(abstract_syntax_uid.into(), default_transfer_syntaxes)
289 }
290
291 pub fn max_pdu_length(mut self, value: u32) -> Self {
294 self.max_pdu_length = value;
295 self
296 }
297
298 pub fn strict(mut self, strict: bool) -> Self {
302 self.strict = strict;
303 self
304 }
305
306 pub fn username<T>(mut self, username: T) -> Self
308 where
309 T: Into<Cow<'a, str>>,
310 {
311 let username = username.into();
312 if username.is_empty() {
313 self.username = None;
314 } else {
315 self.username = Some(username);
316 self.saml_assertion = None;
317 self.jwt = None;
318 self.kerberos_service_ticket = None;
319 }
320 self
321 }
322
323 pub fn password<T>(mut self, password: T) -> Self
325 where
326 T: Into<Cow<'a, str>>,
327 {
328 let password = password.into();
329 if password.is_empty() {
330 self.password = None;
331 } else {
332 self.password = Some(password);
333 self.saml_assertion = None;
334 self.jwt = None;
335 self.kerberos_service_ticket = None;
336 }
337 self
338 }
339
340 pub fn username_password<T, U>(mut self, username: T, password: U) -> Self
342 where
343 T: Into<Cow<'a, str>>,
344 U: Into<Cow<'a, str>>,
345 {
346 let username = username.into();
347 let password = password.into();
348 if username.is_empty() {
349 self.username = None;
350 self.password = None;
351 } else {
352 self.username = Some(username);
353 self.password = Some(password);
354 self.saml_assertion = None;
355 self.jwt = None;
356 self.kerberos_service_ticket = None;
357 }
358 self
359 }
360
361 pub fn kerberos_service_ticket<T>(mut self, kerberos_service_ticket: T) -> Self
363 where
364 T: Into<Cow<'a, str>>,
365 {
366 let kerberos_service_ticket = kerberos_service_ticket.into();
367 if kerberos_service_ticket.is_empty() {
368 self.kerberos_service_ticket = None;
369 } else {
370 self.kerberos_service_ticket = Some(kerberos_service_ticket);
371 self.username = None;
372 self.password = None;
373 self.saml_assertion = None;
374 self.jwt = None;
375 }
376 self
377 }
378
379 pub fn saml_assertion<T>(mut self, saml_assertion: T) -> Self
381 where
382 T: Into<Cow<'a, str>>,
383 {
384 let saml_assertion = saml_assertion.into();
385 if saml_assertion.is_empty() {
386 self.saml_assertion = None;
387 } else {
388 self.saml_assertion = Some(saml_assertion);
389 self.username = None;
390 self.password = None;
391 self.jwt = None;
392 self.kerberos_service_ticket = None;
393 }
394 self
395 }
396
397 pub fn jwt<T>(mut self, jwt: T) -> Self
399 where
400 T: Into<Cow<'a, str>>,
401 {
402 let jwt = jwt.into();
403 if jwt.is_empty() {
404 self.jwt = None;
405 } else {
406 self.jwt = Some(jwt);
407 self.username = None;
408 self.password = None;
409 self.saml_assertion = None;
410 self.kerberos_service_ticket = None;
411 }
412 self
413 }
414
415 pub fn establish<A: ToSocketAddrs>(self, address: A) -> Result<ClientAssociation> {
419 self.establish_impl(AeAddr::new_socket_addr(address))
420 }
421
422 pub fn establish_with(self, ae_address: &str) -> Result<ClientAssociation> {
446 match ae_address.try_into() {
447 Ok(ae_address) => self.establish_impl(ae_address),
448 Err(_) => self.establish_impl(AeAddr::new_socket_addr(ae_address)),
449 }
450 }
451
452 pub fn read_timeout(self, timeout: Duration) -> Self {
454 Self {
455 read_timeout: Some(timeout),
456 ..self
457 }
458 }
459
460 pub fn write_timeout(self, timeout: Duration) -> Self {
462 Self {
463 write_timeout: Some(timeout),
464 ..self
465 }
466 }
467
468 fn establish_impl<T>(self, ae_address: AeAddr<T>) -> Result<ClientAssociation>
469 where
470 T: ToSocketAddrs,
471 {
472 let ClientAssociationOptions {
473 calling_ae_title,
474 called_ae_title,
475 application_context_name,
476 presentation_contexts,
477 protocol_version,
478 max_pdu_length,
479 strict,
480 username,
481 password,
482 kerberos_service_ticket,
483 saml_assertion,
484 jwt,
485 read_timeout,
486 write_timeout
487 } = self;
488
489 ensure!(
492 !presentation_contexts.is_empty(),
493 MissingAbstractSyntaxSnafu
494 );
495
496 let called_ae_title: &str = match (&called_ae_title, ae_address.ae_title()) {
498 (Some(aec), Some(_)) => {
499 tracing::warn!(
500 "Option `called_ae_title` overrides the AE title to `{}`",
501 aec
502 );
503 aec
504 }
505 (Some(aec), None) => aec,
506 (None, Some(aec)) => aec,
507 (None, None) => "ANY-SCP",
508 };
509
510 let presentation_contexts: Vec<_> = presentation_contexts
511 .into_iter()
512 .enumerate()
513 .map(|(i, presentation_context)| PresentationContextProposed {
514 id: (2 * i + 1) as u8,
515 abstract_syntax: presentation_context.0.to_string(),
516 transfer_syntaxes: presentation_context
517 .1
518 .iter()
519 .map(|uid| uid.to_string())
520 .collect(),
521 })
522 .collect();
523
524 let mut user_variables = vec![
525 UserVariableItem::MaxLength(max_pdu_length),
526 UserVariableItem::ImplementationClassUID(IMPLEMENTATION_CLASS_UID.to_string()),
527 UserVariableItem::ImplementationVersionName(IMPLEMENTATION_VERSION_NAME.to_string()),
528 ];
529
530 if let Some(user_identity) = Self::determine_user_identity(
531 username,
532 password,
533 kerberos_service_ticket,
534 saml_assertion,
535 jwt,
536 ) {
537 user_variables.push(UserVariableItem::UserIdentityItem(user_identity));
538 }
539
540 let msg = Pdu::AssociationRQ(AssociationRQ {
541 protocol_version,
542 calling_ae_title: calling_ae_title.to_string(),
543 called_ae_title: called_ae_title.to_string(),
544 application_context_name: application_context_name.to_string(),
545 presentation_contexts,
546 user_variables,
547 });
548
549 let mut socket = std::net::TcpStream::connect(ae_address)
550 .context(ConnectSnafu)?;
551 socket.set_read_timeout(read_timeout)
552 .context(SetReadTimeoutSnafu)?;
553 socket.set_write_timeout(write_timeout)
554 .context(SetWriteTimeoutSnafu)?;
555 let mut buffer: Vec<u8> = Vec::with_capacity(max_pdu_length as usize);
556 write_pdu(&mut buffer, &msg).context(SendRequestSnafu)?;
559 socket.write_all(&buffer).context(WireSendSnafu)?;
560 buffer.clear();
561 let msg =
563 read_pdu(&mut socket, MAXIMUM_PDU_SIZE, self.strict).context(ReceiveResponseSnafu)?;
564
565 match msg {
566 Pdu::AssociationAC(AssociationAC {
567 protocol_version: protocol_version_scp,
568 application_context_name: _,
569 presentation_contexts: presentation_contexts_scp,
570 calling_ae_title: _,
571 called_ae_title: _,
572 user_variables,
573 }) => {
574 ensure!(
575 protocol_version == protocol_version_scp,
576 ProtocolVersionMismatchSnafu {
577 expected: protocol_version,
578 got: protocol_version_scp,
579 }
580 );
581
582 let acceptor_max_pdu_length = user_variables
583 .iter()
584 .find_map(|item| match item {
585 UserVariableItem::MaxLength(len) => Some(*len),
586 _ => None,
587 })
588 .unwrap_or(DEFAULT_MAX_PDU);
589
590 let acceptor_max_pdu_length = if acceptor_max_pdu_length == 0 {
592 MAXIMUM_PDU_SIZE
593 } else {
594 acceptor_max_pdu_length
595 };
596
597 let presentation_contexts: Vec<_> = presentation_contexts_scp
598 .into_iter()
599 .filter(|c| c.reason == PresentationContextResultReason::Acceptance)
600 .collect();
601 if presentation_contexts.is_empty() {
602 let _ = write_pdu(
604 &mut buffer,
605 &Pdu::AbortRQ {
606 source: AbortRQSource::ServiceUser,
607 },
608 );
609 let _ = socket.write_all(&buffer);
610 buffer.clear();
611 return NoAcceptedPresentationContextsSnafu.fail();
612 }
613 Ok(ClientAssociation {
614 presentation_contexts,
615 requestor_max_pdu_length: max_pdu_length,
616 acceptor_max_pdu_length,
617 socket,
618 buffer,
619 strict,
620 })
621 }
622 Pdu::AssociationRJ(association_rj) => RejectedSnafu { association_rj }.fail(),
623 pdu @ Pdu::AbortRQ { .. }
624 | pdu @ Pdu::ReleaseRQ { .. }
625 | pdu @ Pdu::AssociationRQ { .. }
626 | pdu @ Pdu::PData { .. }
627 | pdu @ Pdu::ReleaseRP { .. } => {
628 let _ = write_pdu(
630 &mut buffer,
631 &Pdu::AbortRQ {
632 source: AbortRQSource::ServiceUser,
633 },
634 );
635 let _ = socket.write_all(&buffer);
636 UnexpectedResponseSnafu { pdu }.fail()
637 }
638 pdu @ Pdu::Unknown { .. } => {
639 let _ = write_pdu(
641 &mut buffer,
642 &Pdu::AbortRQ {
643 source: AbortRQSource::ServiceUser,
644 },
645 );
646 let _ = socket.write_all(&buffer);
647 UnknownResponseSnafu { pdu }.fail()
648 }
649 }
650 }
651
652 fn determine_user_identity<T>(
653 username: Option<T>,
654 password: Option<T>,
655 kerberos_service_ticket: Option<T>,
656 saml_assertion: Option<T>,
657 jwt: Option<T>,
658 ) -> Option<UserIdentity>
659 where
660 T: Into<Cow<'a, str>>,
661 {
662 if let Some(username) = username {
663 if let Some(password) = password {
664 return Some(UserIdentity::new(
665 false,
666 UserIdentityType::UsernamePassword,
667 username.into().as_bytes().to_vec(),
668 password.into().as_bytes().to_vec(),
669 ));
670 } else {
671 return Some(UserIdentity::new(
672 false,
673 UserIdentityType::Username,
674 username.into().as_bytes().to_vec(),
675 vec![],
676 ));
677 }
678 }
679
680 if let Some(kerberos_service_ticket) = kerberos_service_ticket {
681 return Some(UserIdentity::new(
682 false,
683 UserIdentityType::KerberosServiceTicket,
684 kerberos_service_ticket.into().as_bytes().to_vec(),
685 vec![],
686 ));
687 }
688
689 if let Some(saml_assertion) = saml_assertion {
690 return Some(UserIdentity::new(
691 false,
692 UserIdentityType::SamlAssertion,
693 saml_assertion.into().as_bytes().to_vec(),
694 vec![],
695 ));
696 }
697
698 if let Some(jwt) = jwt {
699 return Some(UserIdentity::new(
700 false,
701 UserIdentityType::Jwt,
702 jwt.into().as_bytes().to_vec(),
703 vec![],
704 ));
705 }
706
707 None
708 }
709}
710
711#[derive(Debug)]
725pub struct ClientAssociation {
726 presentation_contexts: Vec<PresentationContextResult>,
729 requestor_max_pdu_length: u32,
731 acceptor_max_pdu_length: u32,
733 socket: TcpStream,
735 buffer: Vec<u8>,
737 strict: bool,
739}
740
741impl ClientAssociation {
742 pub fn presentation_contexts(&self) -> &[PresentationContextResult] {
744 &self.presentation_contexts
745 }
746
747 pub fn acceptor_max_pdu_length(&self) -> u32 {
750 self.acceptor_max_pdu_length
751 }
752
753 pub fn requestor_max_pdu_length(&self) -> u32 {
760 self.requestor_max_pdu_length
761 }
762
763 pub fn send(&mut self, msg: &Pdu) -> Result<()> {
765 self.buffer.clear();
766 write_pdu(&mut self.buffer, msg).context(SendSnafu)?;
767 if self.buffer.len() > self.acceptor_max_pdu_length as usize {
768 return SendTooLongPduSnafu {
769 length: self.buffer.len(),
770 }
771 .fail();
772 }
773 self.socket.write_all(&self.buffer).context(WireSendSnafu)
774 }
775
776 pub fn receive(&mut self) -> Result<Pdu> {
778 read_pdu(&mut self.socket, self.requestor_max_pdu_length, self.strict).context(ReceiveSnafu)
779 }
780
781 pub fn release(mut self) -> Result<()> {
784 let out = self.release_impl();
785 let _ = self.socket.shutdown(std::net::Shutdown::Both);
786 out
787 }
788
789 pub fn abort(mut self) -> Result<()> {
792 let pdu = Pdu::AbortRQ {
793 source: AbortRQSource::ServiceUser,
794 };
795 let out = self.send(&pdu);
796 let _ = self.socket.shutdown(std::net::Shutdown::Both);
797 out
798 }
799
800 pub fn inner_stream(&mut self) -> &mut TcpStream {
810 &mut self.socket
811 }
812
813 pub fn send_pdata(&mut self, presentation_context_id: u8) -> PDataWriter<&mut TcpStream> {
819 PDataWriter::new(
820 &mut self.socket,
821 presentation_context_id,
822 self.acceptor_max_pdu_length,
823 )
824 }
825
826 pub fn receive_pdata(&mut self) -> PDataReader<&mut TcpStream> {
832 PDataReader::new(&mut self.socket, self.requestor_max_pdu_length)
833 }
834
835 fn release_impl(&mut self) -> Result<()> {
841 let pdu = Pdu::ReleaseRQ;
842 self.send(&pdu)?;
843 let pdu = read_pdu(&mut self.socket, self.requestor_max_pdu_length, self.strict)
844 .context(ReceiveSnafu)?;
845
846 match pdu {
847 Pdu::ReleaseRP => {}
848 pdu @ Pdu::AbortRQ { .. }
849 | pdu @ Pdu::AssociationAC { .. }
850 | pdu @ Pdu::AssociationRJ { .. }
851 | pdu @ Pdu::AssociationRQ { .. }
852 | pdu @ Pdu::PData { .. }
853 | pdu @ Pdu::ReleaseRQ { .. } => return UnexpectedResponseSnafu { pdu }.fail(),
854 pdu @ Pdu::Unknown { .. } => return UnknownResponseSnafu { pdu }.fail(),
855 }
856 Ok(())
857 }
858}
859
860impl Drop for ClientAssociation {
862 fn drop(&mut self) {
863 let _ = self.release_impl();
864 let _ = self.socket.shutdown(std::net::Shutdown::Both);
865 }
866}