diff options
Diffstat (limited to 'tools/unitctl/unit-client-rs/src/control_socket_address.rs')
-rw-r--r-- | tools/unitctl/unit-client-rs/src/control_socket_address.rs | 571 |
1 files changed, 571 insertions, 0 deletions
diff --git a/tools/unitctl/unit-client-rs/src/control_socket_address.rs b/tools/unitctl/unit-client-rs/src/control_socket_address.rs new file mode 100644 index 00000000..b9ae5afc --- /dev/null +++ b/tools/unitctl/unit-client-rs/src/control_socket_address.rs @@ -0,0 +1,571 @@ +use crate::control_socket_address::ControlSocket::{TcpSocket, UnixLocalAbstractSocket, UnixLocalSocket}; +use crate::control_socket_address::ControlSocketScheme::{HTTP, HTTPS}; +use crate::unit_client::UnitClientError; +use hyper::http::uri::{Authority, PathAndQuery}; +use hyper::Uri; +use std::fmt::{Display, Formatter}; +use std::fs; +use std::os::unix::fs::FileTypeExt; +use std::path::{PathBuf, MAIN_SEPARATOR}; + +type AbstractSocketName = String; +type UnixSocketPath = PathBuf; +type Port = u16; + +#[derive(Debug, Clone)] +pub enum ControlSocket { + UnixLocalAbstractSocket(AbstractSocketName), + UnixLocalSocket(UnixSocketPath), + TcpSocket(Uri), +} + +#[derive(Debug)] +pub enum ControlSocketScheme { + HTTP, + HTTPS, +} + +impl ControlSocketScheme { + fn port(&self) -> Port { + match self { + HTTP => 80, + HTTPS => 443, + } + } +} + +impl ControlSocket { + pub fn socket_scheme(&self) -> ControlSocketScheme { + match self { + UnixLocalAbstractSocket(_) => ControlSocketScheme::HTTP, + UnixLocalSocket(_) => ControlSocketScheme::HTTP, + TcpSocket(uri) => match uri.scheme_str().expect("Scheme should not be None") { + "http" => ControlSocketScheme::HTTP, + "https" => ControlSocketScheme::HTTPS, + _ => unreachable!("Scheme should be http or https"), + }, + } + } + + pub fn create_uri_with_path(&self, str_path: &str) -> Uri { + match self { + UnixLocalAbstractSocket(name) => { + let socket_path = PathBuf::from(format!("@{}", name)); + hyperlocal::Uri::new(socket_path, str_path).into() + } + UnixLocalSocket(socket_path) => hyperlocal::Uri::new(socket_path, str_path).into(), + TcpSocket(uri) => { + if str_path.is_empty() { + uri.clone() + } else { + let authority = uri.authority().expect("Authority should not be None"); + Uri::builder() + .scheme(uri.scheme_str().expect("Scheme should not be None")) + .authority(authority.clone()) + .path_and_query(str_path) + .build() + .expect("URI should be valid") + } + } + } + } +} + +impl Display for ControlSocket { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + UnixLocalAbstractSocket(name) => f.write_fmt(format_args!("unix:@{}", name)), + UnixLocalSocket(path) => f.write_fmt(format_args!("unix:{}", path.to_string_lossy())), + TcpSocket(uri) => uri.fmt(f), + } + } +} + +impl From<ControlSocket> for String { + fn from(val: ControlSocket) -> Self { + val.to_string() + } +} + +impl From<ControlSocket> for PathBuf { + fn from(val: ControlSocket) -> Self { + match val { + UnixLocalAbstractSocket(socket_name) => PathBuf::from(format!("@{}", socket_name)), + UnixLocalSocket(socket_path) => socket_path, + TcpSocket(_) => PathBuf::default(), + } + } +} + +impl From<ControlSocket> for Uri { + fn from(val: ControlSocket) -> Self { + val.create_uri_with_path("") + } +} + +impl ControlSocket { + pub fn validate_http_address(uri: Uri) -> Result<(), UnitClientError> { + let http_address = uri.to_string(); + if uri.authority().is_none() { + return Err(UnitClientError::TcpSocketAddressParseError { + message: "No authority found in socket address".to_string(), + control_socket_address: http_address, + }); + } + if uri.port_u16().is_none() { + return Err(UnitClientError::TcpSocketAddressNoPortError { + control_socket_address: http_address, + }); + } + if !(uri.path().is_empty() || uri.path().eq("/")) { + return Err(UnitClientError::TcpSocketAddressParseError { + message: format!("Path is not empty or is not / [path={}]", uri.path()), + control_socket_address: http_address, + }); + } + + Ok(()) + } + + pub fn validate_unix_address(socket: PathBuf) -> Result<(), UnitClientError> { + if !socket.exists() { + return Err(UnitClientError::UnixSocketNotFound { + control_socket_address: socket.to_string_lossy().to_string(), + }); + } + let metadata = fs::metadata(&socket).map_err(|error| UnitClientError::UnixSocketAddressError { + source: error, + control_socket_address: socket.to_string_lossy().to_string(), + })?; + let file_type = metadata.file_type(); + if !file_type.is_socket() { + return Err(UnitClientError::UnixSocketAddressError { + source: std::io::Error::new(std::io::ErrorKind::Other, "Control socket path is not a socket"), + control_socket_address: socket.to_string_lossy().to_string(), + }); + } + + Ok(()) + } + + pub fn validate(&self) -> Result<Self, UnitClientError> { + match self { + UnixLocalAbstractSocket(socket_name) => { + let socket_path = PathBuf::from(format!("@{}", socket_name)); + Self::validate_unix_address(socket_path.clone()) + } + UnixLocalSocket(socket_path) => Self::validate_unix_address(socket_path.clone()), + TcpSocket(socket_uri) => Self::validate_http_address(socket_uri.clone()), + } + .map(|_| self.to_owned()) + } + + fn normalize_and_parse_http_address(http_address: String) -> Result<Uri, UnitClientError> { + // Convert *:1 style network addresses to URI format + let address = if http_address.starts_with("*:") { + http_address.replacen("*:", "http://127.0.0.1:", 1) + // Add scheme if not present + } else if !(http_address.starts_with("http://") || http_address.starts_with("https://")) { + format!("http://{}", http_address) + } else { + http_address.to_owned() + }; + + let is_https = address.starts_with("https://"); + + let parsed_uri = + Uri::try_from(address.as_str()).map_err(|error| UnitClientError::TcpSocketAddressUriError { + source: error, + control_socket_address: address, + })?; + let authority = parsed_uri.authority().expect("Authority should not be None"); + let expected_port = if is_https { HTTPS.port() } else { HTTP.port() }; + let normalized_authority = match authority.port_u16() { + Some(_) => authority.to_owned(), + None => { + let host = format!("{}:{}", authority.host(), expected_port); + Authority::try_from(host.as_str()).expect("Authority should be valid") + } + }; + + let normalized_uri = Uri::builder() + .scheme(parsed_uri.scheme_str().expect("Scheme should not be None")) + .authority(normalized_authority) + .path_and_query(PathAndQuery::from_static("")) + .build() + .map_err(|error| UnitClientError::TcpSocketAddressParseError { + message: error.to_string(), + control_socket_address: http_address.clone(), + })?; + + Ok(normalized_uri) + } + + /// Flexibly parse a textual representation of a socket address + fn parse_address<S: Into<String>>(socket_address: S) -> Result<Self, UnitClientError> { + let full_socket_address: String = socket_address.into(); + let socket_prefix = "unix:"; + let socket_uri_prefix = "unix://"; + let mut buf = String::with_capacity(socket_prefix.len()); + for (i, c) in full_socket_address.char_indices() { + // Abstract unix socket with no prefix + if i == 0 && c == '@' { + return Ok(UnixLocalAbstractSocket(full_socket_address[1..].to_string())); + } + buf.push(c); + // Unix socket with prefix + if i == socket_prefix.len() - 1 && buf.eq(socket_prefix) { + let path_text = full_socket_address[socket_prefix.len()..].to_string(); + // Return here if this URI does not have a scheme followed by double slashes + if !path_text.starts_with("//") { + return match path_text.strip_prefix('@') { + Some(name) => Ok(UnixLocalAbstractSocket(name.to_string())), + None => { + let path = PathBuf::from(path_text); + Ok(UnixLocalSocket(path)) + } + }; + } + } + + // Unix socket with URI prefix + if i == socket_uri_prefix.len() - 1 && buf.eq(socket_uri_prefix) { + let uri = Uri::try_from(full_socket_address.as_str()).map_err(|error| { + UnitClientError::TcpSocketAddressParseError { + message: error.to_string(), + control_socket_address: full_socket_address.clone(), + } + })?; + return ControlSocket::try_from(uri); + } + } + + /* Sockets on Windows are not supported, so there is no need to check + * if the socket address is a valid path, so we can do this shortcut + * here to see if a path was specified without a unix: prefix. */ + if buf.starts_with(MAIN_SEPARATOR) { + let path = PathBuf::from(buf); + return Ok(UnixLocalSocket(path)); + } + + let uri = Self::normalize_and_parse_http_address(buf)?; + Ok(TcpSocket(uri)) + } + + pub fn is_local_socket(&self) -> bool { + match self { + UnixLocalAbstractSocket(_) | UnixLocalSocket(_) => true, + TcpSocket(_) => false, + } + } +} + +impl TryFrom<String> for ControlSocket { + type Error = UnitClientError; + + fn try_from(socket_address: String) -> Result<Self, Self::Error> { + ControlSocket::parse_address(socket_address.as_str()) + } +} + +impl TryFrom<&str> for ControlSocket { + type Error = UnitClientError; + + fn try_from(socket_address: &str) -> Result<Self, Self::Error> { + ControlSocket::parse_address(socket_address) + } +} + +impl TryFrom<Uri> for ControlSocket { + type Error = UnitClientError; + + fn try_from(socket_uri: Uri) -> Result<Self, Self::Error> { + match socket_uri.scheme_str() { + // URIs with the unix scheme will have a hostname that is a hex encoded string + // representing the path to the socket + Some("unix") => { + let host = match socket_uri.host() { + Some(host) => host, + None => { + return Err(UnitClientError::TcpSocketAddressParseError { + message: "No host found in socket address".to_string(), + control_socket_address: socket_uri.to_string(), + }) + } + }; + let bytes = hex::decode(host).map_err(|error| UnitClientError::TcpSocketAddressParseError { + message: error.to_string(), + control_socket_address: socket_uri.to_string(), + })?; + let path = String::from_utf8_lossy(&bytes); + ControlSocket::parse_address(path) + } + Some("http") | Some("https") => Ok(TcpSocket(socket_uri)), + Some(unknown) => Err(UnitClientError::TcpSocketAddressParseError { + message: format!("Unsupported scheme found in socket address: {}", unknown).to_string(), + control_socket_address: socket_uri.to_string(), + }), + None => Err(UnitClientError::TcpSocketAddressParseError { + message: "No scheme found in socket address".to_string(), + control_socket_address: socket_uri.to_string(), + }), + } + } +} + +#[cfg(test)] +mod tests { + use rand::distributions::{Alphanumeric, DistString}; + use std::env::temp_dir; + use std::fmt::Display; + use std::io; + use std::os::unix::net::UnixListener; + + use super::*; + + struct TempSocket { + socket_path: PathBuf, + _listener: UnixListener, + } + + impl TempSocket { + fn shutdown(&mut self) -> io::Result<()> { + fs::remove_file(&self.socket_path) + } + } + + impl Display for TempSocket { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "unix:{}", self.socket_path.to_string_lossy().to_string()) + } + } + + impl Drop for TempSocket { + fn drop(&mut self) { + self.shutdown() + .expect(format!("Unable to shutdown socket {}", self.socket_path.to_string_lossy()).as_str()); + } + } + + #[test] + fn will_error_with_nonexistent_unix_socket() { + let socket_address = "unix:/tmp/some_random_filename_that_doesnt_exist.sock"; + let control_socket = + ControlSocket::try_from(socket_address).expect("No error should be returned until validate() is called"); + assert!(control_socket.is_local_socket(), "Not parsed as a local socket"); + assert!(control_socket.validate().is_err(), "Socket should not be valid"); + } + + #[test] + fn can_parse_socket_with_prefix() { + let temp_socket = create_file_socket().expect("Unable to create socket"); + let control_socket = ControlSocket::try_from(temp_socket.to_string()).expect("Error parsing good socket path"); + assert!(control_socket.is_local_socket(), "Not parsed as a local socket"); + if let Err(e) = control_socket.validate() { + panic!("Socket should be valid: {}", e); + } + } + + #[test] + fn can_parse_socket_from_uri() { + let temp_socket = create_file_socket().expect("Unable to create socket"); + let uri: Uri = hyperlocal::Uri::new(temp_socket.socket_path.clone(), "").into(); + let control_socket = ControlSocket::try_from(uri).expect("Error parsing good socket path"); + assert!(control_socket.is_local_socket(), "Not parsed as a local socket"); + if let Err(e) = control_socket.validate() { + panic!("Socket should be valid: {}", e); + } + } + + #[test] + fn can_parse_socket_from_uri_text() { + let temp_socket = create_file_socket().expect("Unable to create socket"); + let uri: Uri = hyperlocal::Uri::new(temp_socket.socket_path.clone(), "").into(); + let control_socket = ControlSocket::parse_address(uri.to_string()).expect("Error parsing good socket path"); + assert!(control_socket.is_local_socket(), "Not parsed as a local socket"); + if let Err(e) = control_socket.validate() { + panic!("Socket for input text should be valid: {}", e); + } + } + + #[test] + #[cfg(target_os = "linux")] + fn can_parse_abstract_socket_from_uri() { + let temp_socket = create_abstract_socket().expect("Unable to create socket"); + let uri: Uri = hyperlocal::Uri::new(temp_socket.socket_path.clone(), "").into(); + let control_socket = ControlSocket::try_from(uri).expect("Error parsing good socket path"); + assert!(control_socket.is_local_socket(), "Not parsed as a local socket"); + if let Err(e) = control_socket.validate() { + panic!("Socket should be valid: {}", e); + } + } + + #[test] + #[cfg(target_os = "linux")] + fn can_parse_abstract_socket_from_uri_text() { + let temp_socket = create_abstract_socket().expect("Unable to create socket"); + let uri: Uri = hyperlocal::Uri::new(temp_socket.socket_path.clone(), "").into(); + let control_socket = ControlSocket::parse_address(uri.to_string()).expect("Error parsing good socket path"); + assert!(control_socket.is_local_socket(), "Not parsed as a local socket"); + if let Err(e) = control_socket.validate() { + panic!("Socket should be valid: {}", e); + } + } + + #[test] + fn can_parse_socket_without_prefix() { + let temp_socket = create_file_socket().expect("Unable to create socket"); + let control_socket = ControlSocket::try_from(temp_socket.socket_path.to_string_lossy().to_string()) + .expect("Error parsing good socket path"); + assert!(control_socket.is_local_socket(), "Not parsed as a local socket"); + if let Err(e) = control_socket.validate() { + panic!("Socket should be valid: {}", e); + } + } + + #[cfg(target_os = "linux")] + #[test] + fn can_parse_abstract_socket() { + let temp_socket = create_abstract_socket().expect("Unable to create socket"); + let control_socket = ControlSocket::try_from(temp_socket.to_string()).expect("Error parsing good socket path"); + assert!(control_socket.is_local_socket(), "Not parsed as a local socket"); + if let Err(e) = control_socket.validate() { + panic!("Socket should be valid: {}", e); + } + } + + #[test] + fn can_normalize_good_http_socket_addresses() { + let valid_socket_addresses = vec![ + "http://127.0.0.1:8080", + "https://127.0.0.1:8080", + "http://127.0.0.1:8080/", + "127.0.0.1:8080", + "http://0.0.0.0:8080", + "https://0.0.0.0:8080", + "http://0.0.0.0:8080/", + "0.0.0.0:8080", + "http://localhost:8080", + "https://localhost:8080", + "http://localhost:8080/", + "localhost:8080", + "http://[::1]:8080", + "https://[::1]:8080", + "http://[::1]:8080/", + "[::1]:8080", + "http://[0000:0000:0000:0000:0000:0000:0000:0000]:8080", + "https://[0000:0000:0000:0000:0000:0000:0000:0000]:8080", + "http://[0000:0000:0000:0000:0000:0000:0000:0000]:8080/", + "[0000:0000:0000:0000:0000:0000:0000:0000]:8080", + ]; + for socket_address in valid_socket_addresses { + let mut expected = if socket_address.starts_with("http") { + socket_address.to_string().trim_end_matches('/').to_string() + } else { + format!("http://{}", socket_address).trim_end_matches('/').to_string() + }; + expected.push('/'); + + let control_socket = ControlSocket::try_from(socket_address).expect("Error parsing good socket path"); + assert!(!control_socket.is_local_socket(), "Not parsed as a local socket"); + if let Err(e) = control_socket.validate() { + panic!("Socket should be valid: {}", e); + } + } + } + + #[test] + fn can_normalize_wildcard_http_socket_address() { + let socket_address = "*:8080"; + let expected = "http://127.0.0.1:8080/"; + let normalized_result = ControlSocket::normalize_and_parse_http_address(socket_address.to_string()); + let normalized = normalized_result + .expect("Unable to normalize socket address") + .to_string(); + assert_eq!(normalized, expected); + } + + #[test] + fn can_normalize_http_socket_address_with_no_port() { + let socket_address = "http://localhost"; + let expected = "http://localhost:80/"; + let normalized_result = ControlSocket::normalize_and_parse_http_address(socket_address.to_string()); + let normalized = normalized_result + .expect("Unable to normalize socket address") + .to_string(); + assert_eq!(normalized, expected); + } + + #[test] + fn can_normalize_https_socket_address_with_no_port() { + let socket_address = "https://localhost"; + let expected = "https://localhost:443/"; + let normalized_result = ControlSocket::normalize_and_parse_http_address(socket_address.to_string()); + let normalized = normalized_result + .expect("Unable to normalize socket address") + .to_string(); + assert_eq!(normalized, expected); + } + + #[test] + fn can_parse_http_addresses() { + let valid_socket_addresses = vec![ + "http://127.0.0.1:8080", + "https://127.0.0.1:8080", + "http://127.0.0.1:8080/", + "127.0.0.1:8080", + "http://0.0.0.0:8080", + "https://0.0.0.0:8080", + "http://0.0.0.0:8080/", + "0.0.0.0:8080", + "http://localhost:8080", + "https://localhost:8080", + "http://localhost:8080/", + "localhost:8080", + "http://[::1]:8080", + "https://[::1]:8080", + "http://[::1]:8080/", + "[::1]:8080", + "http://[0000:0000:0000:0000:0000:0000:0000:0000]:8080", + "https://[0000:0000:0000:0000:0000:0000:0000:0000]:8080", + "http://[0000:0000:0000:0000:0000:0000:0000:0000]:8080/", + "[0000:0000:0000:0000:0000:0000:0000:0000]:8080", + ]; + for socket_address in valid_socket_addresses { + let mut expected = if socket_address.starts_with("http") { + socket_address.to_string().trim_end_matches('/').to_string() + } else { + format!("http://{}", socket_address).trim_end_matches('/').to_string() + }; + expected.push('/'); + + let normalized = ControlSocket::normalize_and_parse_http_address(socket_address.to_string()) + .expect("Unable to normalize socket address") + .to_string(); + assert_eq!(normalized, expected); + } + } + + fn create_file_socket() -> Result<TempSocket, io::Error> { + let random = Alphanumeric.sample_string(&mut rand::thread_rng(), 10); + let socket_name = format!("unit-client-socket-test-{}.sock", random); + let socket_path = temp_dir().join(socket_name); + let listener = UnixListener::bind(&socket_path)?; + Ok(TempSocket { + socket_path, + _listener: listener, + }) + } + + #[cfg(target_os = "linux")] + fn create_abstract_socket() -> Result<TempSocket, io::Error> { + let random = Alphanumeric.sample_string(&mut rand::thread_rng(), 10); + let socket_name = format!("@unit-client-socket-test-{}.sock", random); + let socket_path = PathBuf::from(socket_name); + let listener = UnixListener::bind(&socket_path)?; + Ok(TempSocket { + socket_path, + _listener: listener, + }) + } +} |