use crate::control_socket_address::ControlSocket::{TcpSocket, UnixLocalAbstractSocket, UnixLocalSocket}; use crate::control_socket_address::ControlSocketScheme::{HTTP, HTTPS}; use crate::unit_client::UnitClientError; use hyper::http::uri::{Authority, PathAndQuery}; use hyper::Uri; use std::fmt::{Display, Formatter}; use std::fs; use std::os::unix::fs::FileTypeExt; use std::path::{PathBuf, MAIN_SEPARATOR}; type AbstractSocketName = String; type UnixSocketPath = PathBuf; type Port = u16; #[derive(Debug, Clone)] pub enum ControlSocket { UnixLocalAbstractSocket(AbstractSocketName), UnixLocalSocket(UnixSocketPath), TcpSocket(Uri), } #[derive(Debug)] pub enum ControlSocketScheme { HTTP, HTTPS, } impl ControlSocketScheme { fn port(&self) -> Port { match self { HTTP => 80, HTTPS => 443, } } } impl ControlSocket { pub fn socket_scheme(&self) -> ControlSocketScheme { match self { UnixLocalAbstractSocket(_) => ControlSocketScheme::HTTP, UnixLocalSocket(_) => ControlSocketScheme::HTTP, TcpSocket(uri) => match uri.scheme_str().expect("Scheme should not be None") { "http" => ControlSocketScheme::HTTP, "https" => ControlSocketScheme::HTTPS, _ => unreachable!("Scheme should be http or https"), }, } } pub fn create_uri_with_path(&self, str_path: &str) -> Uri { match self { UnixLocalAbstractSocket(name) => { let socket_path = PathBuf::from(format!("@{}", name)); hyperlocal::Uri::new(socket_path, str_path).into() } UnixLocalSocket(socket_path) => hyperlocal::Uri::new(socket_path, str_path).into(), TcpSocket(uri) => { if str_path.is_empty() { uri.clone() } else { let authority = uri.authority().expect("Authority should not be None"); Uri::builder() .scheme(uri.scheme_str().expect("Scheme should not be None")) .authority(authority.clone()) .path_and_query(str_path) .build() .expect("URI should be valid") } } } } } impl Display for ControlSocket { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { UnixLocalAbstractSocket(name) => f.write_fmt(format_args!("unix:@{}", name)), UnixLocalSocket(path) => f.write_fmt(format_args!("unix:{}", path.to_string_lossy())), TcpSocket(uri) => uri.fmt(f), } } } impl From for String { fn from(val: ControlSocket) -> Self { val.to_string() } } impl From for PathBuf { fn from(val: ControlSocket) -> Self { match val { UnixLocalAbstractSocket(socket_name) => PathBuf::from(format!("@{}", socket_name)), UnixLocalSocket(socket_path) => socket_path, TcpSocket(_) => PathBuf::default(), } } } impl From for Uri { fn from(val: ControlSocket) -> Self { val.create_uri_with_path("") } } impl ControlSocket { pub fn validate_http_address(uri: Uri) -> Result<(), UnitClientError> { let http_address = uri.to_string(); if uri.authority().is_none() { return Err(UnitClientError::TcpSocketAddressParseError { message: "No authority found in socket address".to_string(), control_socket_address: http_address, }); } if uri.port_u16().is_none() { return Err(UnitClientError::TcpSocketAddressNoPortError { control_socket_address: http_address, }); } if !(uri.path().is_empty() || uri.path().eq("/")) { return Err(UnitClientError::TcpSocketAddressParseError { message: format!("Path is not empty or is not / [path={}]", uri.path()), control_socket_address: http_address, }); } Ok(()) } pub fn validate_unix_address(socket: PathBuf) -> Result<(), UnitClientError> { if !socket.exists() { return Err(UnitClientError::UnixSocketNotFound { control_socket_address: socket.to_string_lossy().to_string(), }); } let metadata = fs::metadata(&socket).map_err(|error| UnitClientError::UnixSocketAddressError { source: error, control_socket_address: socket.to_string_lossy().to_string(), })?; let file_type = metadata.file_type(); if !file_type.is_socket() { return Err(UnitClientError::UnixSocketAddressError { source: std::io::Error::new(std::io::ErrorKind::Other, "Control socket path is not a socket"), control_socket_address: socket.to_string_lossy().to_string(), }); } Ok(()) } pub fn validate(&self) -> Result { match self { UnixLocalAbstractSocket(socket_name) => { let socket_path = PathBuf::from(format!("@{}", socket_name)); Self::validate_unix_address(socket_path.clone()) } UnixLocalSocket(socket_path) => Self::validate_unix_address(socket_path.clone()), TcpSocket(socket_uri) => Self::validate_http_address(socket_uri.clone()), } .map(|_| self.to_owned()) } fn normalize_and_parse_http_address(http_address: String) -> Result { // Convert *:1 style network addresses to URI format let address = if http_address.starts_with("*:") { http_address.replacen("*:", "http://127.0.0.1:", 1) // Add scheme if not present } else if !(http_address.starts_with("http://") || http_address.starts_with("https://")) { format!("http://{}", http_address) } else { http_address.to_owned() }; let is_https = address.starts_with("https://"); let parsed_uri = Uri::try_from(address.as_str()).map_err(|error| UnitClientError::TcpSocketAddressUriError { source: error, control_socket_address: address, })?; let authority = parsed_uri.authority().expect("Authority should not be None"); let expected_port = if is_https { HTTPS.port() } else { HTTP.port() }; let normalized_authority = match authority.port_u16() { Some(_) => authority.to_owned(), None => { let host = format!("{}:{}", authority.host(), expected_port); Authority::try_from(host.as_str()).expect("Authority should be valid") } }; let normalized_uri = Uri::builder() .scheme(parsed_uri.scheme_str().expect("Scheme should not be None")) .authority(normalized_authority) .path_and_query(PathAndQuery::from_static("")) .build() .map_err(|error| UnitClientError::TcpSocketAddressParseError { message: error.to_string(), control_socket_address: http_address.clone(), })?; Ok(normalized_uri) } /// Flexibly parse a textual representation of a socket address fn parse_address>(socket_address: S) -> Result { let full_socket_address: String = socket_address.into(); let socket_prefix = "unix:"; let socket_uri_prefix = "unix://"; let mut buf = String::with_capacity(socket_prefix.len()); for (i, c) in full_socket_address.char_indices() { // Abstract unix socket with no prefix if i == 0 && c == '@' { return Ok(UnixLocalAbstractSocket(full_socket_address[1..].to_string())); } buf.push(c); // Unix socket with prefix if i == socket_prefix.len() - 1 && buf.eq(socket_prefix) { let path_text = full_socket_address[socket_prefix.len()..].to_string(); // Return here if this URI does not have a scheme followed by double slashes if !path_text.starts_with("//") { return match path_text.strip_prefix('@') { Some(name) => Ok(UnixLocalAbstractSocket(name.to_string())), None => { let path = PathBuf::from(path_text); Ok(UnixLocalSocket(path)) } }; } } // Unix socket with URI prefix if i == socket_uri_prefix.len() - 1 && buf.eq(socket_uri_prefix) { let uri = Uri::try_from(full_socket_address.as_str()).map_err(|error| { UnitClientError::TcpSocketAddressParseError { message: error.to_string(), control_socket_address: full_socket_address.clone(), } })?; return ControlSocket::try_from(uri); } } /* Sockets on Windows are not supported, so there is no need to check * if the socket address is a valid path, so we can do this shortcut * here to see if a path was specified without a unix: prefix. */ if buf.starts_with(MAIN_SEPARATOR) { let path = PathBuf::from(buf); return Ok(UnixLocalSocket(path)); } let uri = Self::normalize_and_parse_http_address(buf)?; Ok(TcpSocket(uri)) } pub fn is_local_socket(&self) -> bool { match self { UnixLocalAbstractSocket(_) | UnixLocalSocket(_) => true, TcpSocket(_) => false, } } } impl TryFrom for ControlSocket { type Error = UnitClientError; fn try_from(socket_address: String) -> Result { ControlSocket::parse_address(socket_address.as_str()) } } impl TryFrom<&str> for ControlSocket { type Error = UnitClientError; fn try_from(socket_address: &str) -> Result { ControlSocket::parse_address(socket_address) } } impl TryFrom for ControlSocket { type Error = UnitClientError; fn try_from(socket_uri: Uri) -> Result { match socket_uri.scheme_str() { // URIs with the unix scheme will have a hostname that is a hex encoded string // representing the path to the socket Some("unix") => { let host = match socket_uri.host() { Some(host) => host, None => { return Err(UnitClientError::TcpSocketAddressParseError { message: "No host found in socket address".to_string(), control_socket_address: socket_uri.to_string(), }) } }; let bytes = hex::decode(host).map_err(|error| UnitClientError::TcpSocketAddressParseError { message: error.to_string(), control_socket_address: socket_uri.to_string(), })?; let path = String::from_utf8_lossy(&bytes); ControlSocket::parse_address(path) } Some("http") | Some("https") => Ok(TcpSocket(socket_uri)), Some(unknown) => Err(UnitClientError::TcpSocketAddressParseError { message: format!("Unsupported scheme found in socket address: {}", unknown).to_string(), control_socket_address: socket_uri.to_string(), }), None => Err(UnitClientError::TcpSocketAddressParseError { message: "No scheme found in socket address".to_string(), control_socket_address: socket_uri.to_string(), }), } } } #[cfg(test)] mod tests { use rand::distributions::{Alphanumeric, DistString}; use std::env::temp_dir; use std::fmt::Display; use std::io; use std::os::unix::net::UnixListener; use super::*; struct TempSocket { socket_path: PathBuf, _listener: UnixListener, } impl TempSocket { fn shutdown(&mut self) -> io::Result<()> { fs::remove_file(&self.socket_path) } } impl Display for TempSocket { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "unix:{}", self.socket_path.to_string_lossy().to_string()) } } impl Drop for TempSocket { fn drop(&mut self) { self.shutdown() .expect(format!("Unable to shutdown socket {}", self.socket_path.to_string_lossy()).as_str()); } } #[test] fn will_error_with_nonexistent_unix_socket() { let socket_address = "unix:/tmp/some_random_filename_that_doesnt_exist.sock"; let control_socket = ControlSocket::try_from(socket_address).expect("No error should be returned until validate() is called"); assert!(control_socket.is_local_socket(), "Not parsed as a local socket"); assert!(control_socket.validate().is_err(), "Socket should not be valid"); } #[test] fn can_parse_socket_with_prefix() { let temp_socket = create_file_socket().expect("Unable to create socket"); let control_socket = ControlSocket::try_from(temp_socket.to_string()).expect("Error parsing good socket path"); assert!(control_socket.is_local_socket(), "Not parsed as a local socket"); if let Err(e) = control_socket.validate() { panic!("Socket should be valid: {}", e); } } #[test] fn can_parse_socket_from_uri() { let temp_socket = create_file_socket().expect("Unable to create socket"); let uri: Uri = hyperlocal::Uri::new(temp_socket.socket_path.clone(), "").into(); let control_socket = ControlSocket::try_from(uri).expect("Error parsing good socket path"); assert!(control_socket.is_local_socket(), "Not parsed as a local socket"); if let Err(e) = control_socket.validate() { panic!("Socket should be valid: {}", e); } } #[test] fn can_parse_socket_from_uri_text() { let temp_socket = create_file_socket().expect("Unable to create socket"); let uri: Uri = hyperlocal::Uri::new(temp_socket.socket_path.clone(), "").into(); let control_socket = ControlSocket::parse_address(uri.to_string()).expect("Error parsing good socket path"); assert!(control_socket.is_local_socket(), "Not parsed as a local socket"); if let Err(e) = control_socket.validate() { panic!("Socket for input text should be valid: {}", e); } } #[test] #[cfg(target_os = "linux")] fn can_parse_abstract_socket_from_uri() { let temp_socket = create_abstract_socket().expect("Unable to create socket"); let uri: Uri = hyperlocal::Uri::new(temp_socket.socket_path.clone(), "").into(); let control_socket = ControlSocket::try_from(uri).expect("Error parsing good socket path"); assert!(control_socket.is_local_socket(), "Not parsed as a local socket"); if let Err(e) = control_socket.validate() { panic!("Socket should be valid: {}", e); } } #[test] #[cfg(target_os = "linux")] fn can_parse_abstract_socket_from_uri_text() { let temp_socket = create_abstract_socket().expect("Unable to create socket"); let uri: Uri = hyperlocal::Uri::new(temp_socket.socket_path.clone(), "").into(); let control_socket = ControlSocket::parse_address(uri.to_string()).expect("Error parsing good socket path"); assert!(control_socket.is_local_socket(), "Not parsed as a local socket"); if let Err(e) = control_socket.validate() { panic!("Socket should be valid: {}", e); } } #[test] fn can_parse_socket_without_prefix() { let temp_socket = create_file_socket().expect("Unable to create socket"); let control_socket = ControlSocket::try_from(temp_socket.socket_path.to_string_lossy().to_string()) .expect("Error parsing good socket path"); assert!(control_socket.is_local_socket(), "Not parsed as a local socket"); if let Err(e) = control_socket.validate() { panic!("Socket should be valid: {}", e); } } #[cfg(target_os = "linux")] #[test] fn can_parse_abstract_socket() { let temp_socket = create_abstract_socket().expect("Unable to create socket"); let control_socket = ControlSocket::try_from(temp_socket.to_string()).expect("Error parsing good socket path"); assert!(control_socket.is_local_socket(), "Not parsed as a local socket"); if let Err(e) = control_socket.validate() { panic!("Socket should be valid: {}", e); } } #[test] fn can_normalize_good_http_socket_addresses() { let valid_socket_addresses = vec![ "http://127.0.0.1:8080", "https://127.0.0.1:8080", "http://127.0.0.1:8080/", "127.0.0.1:8080", "http://0.0.0.0:8080", "https://0.0.0.0:8080", "http://0.0.0.0:8080/", "0.0.0.0:8080", "http://localhost:8080", "https://localhost:8080", "http://localhost:8080/", "localhost:8080", "http://[::1]:8080", "https://[::1]:8080", "http://[::1]:8080/", "[::1]:8080", "http://[0000:0000:0000:0000:0000:0000:0000:0000]:8080", "https://[0000:0000:0000:0000:0000:0000:0000:0000]:8080", "http://[0000:0000:0000:0000:0000:0000:0000:0000]:8080/", "[0000:0000:0000:0000:0000:0000:0000:0000]:8080", ]; for socket_address in valid_socket_addresses { let mut expected = if socket_address.starts_with("http") { socket_address.to_string().trim_end_matches('/').to_string() } else { format!("http://{}", socket_address).trim_end_matches('/').to_string() }; expected.push('/'); let control_socket = ControlSocket::try_from(socket_address).expect("Error parsing good socket path"); assert!(!control_socket.is_local_socket(), "Not parsed as a local socket"); if let Err(e) = control_socket.validate() { panic!("Socket should be valid: {}", e); } } } #[test] fn can_normalize_wildcard_http_socket_address() { let socket_address = "*:8080"; let expected = "http://127.0.0.1:8080/"; let normalized_result = ControlSocket::normalize_and_parse_http_address(socket_address.to_string()); let normalized = normalized_result .expect("Unable to normalize socket address") .to_string(); assert_eq!(normalized, expected); } #[test] fn can_normalize_http_socket_address_with_no_port() { let socket_address = "http://localhost"; let expected = "http://localhost:80/"; let normalized_result = ControlSocket::normalize_and_parse_http_address(socket_address.to_string()); let normalized = normalized_result .expect("Unable to normalize socket address") .to_string(); assert_eq!(normalized, expected); } #[test] fn can_normalize_https_socket_address_with_no_port() { let socket_address = "https://localhost"; let expected = "https://localhost:443/"; let normalized_result = ControlSocket::normalize_and_parse_http_address(socket_address.to_string()); let normalized = normalized_result .expect("Unable to normalize socket address") .to_string(); assert_eq!(normalized, expected); } #[test] fn can_parse_http_addresses() { let valid_socket_addresses = vec![ "http://127.0.0.1:8080", "https://127.0.0.1:8080", "http://127.0.0.1:8080/", "127.0.0.1:8080", "http://0.0.0.0:8080", "https://0.0.0.0:8080", "http://0.0.0.0:8080/", "0.0.0.0:8080", "http://localhost:8080", "https://localhost:8080", "http://localhost:8080/", "localhost:8080", "http://[::1]:8080", "https://[::1]:8080", "http://[::1]:8080/", "[::1]:8080", "http://[0000:0000:0000:0000:0000:0000:0000:0000]:8080", "https://[0000:0000:0000:0000:0000:0000:0000:0000]:8080", "http://[0000:0000:0000:0000:0000:0000:0000:0000]:8080/", "[0000:0000:0000:0000:0000:0000:0000:0000]:8080", ]; for socket_address in valid_socket_addresses { let mut expected = if socket_address.starts_with("http") { socket_address.to_string().trim_end_matches('/').to_string() } else { format!("http://{}", socket_address).trim_end_matches('/').to_string() }; expected.push('/'); let normalized = ControlSocket::normalize_and_parse_http_address(socket_address.to_string()) .expect("Unable to normalize socket address") .to_string(); assert_eq!(normalized, expected); } } fn create_file_socket() -> Result { let random = Alphanumeric.sample_string(&mut rand::thread_rng(), 10); let socket_name = format!("unit-client-socket-test-{}.sock", random); let socket_path = temp_dir().join(socket_name); let listener = UnixListener::bind(&socket_path)?; Ok(TempSocket { socket_path, _listener: listener, }) } #[cfg(target_os = "linux")] fn create_abstract_socket() -> Result { let random = Alphanumeric.sample_string(&mut rand::thread_rng(), 10); let socket_name = format!("@unit-client-socket-test-{}.sock", random); let socket_path = PathBuf::from(socket_name); let listener = UnixListener::bind(&socket_path)?; Ok(TempSocket { socket_path, _listener: listener, }) } }