summaryrefslogtreecommitdiffhomepage
path: root/tools/unitctl/unit-client-rs/src/control_socket_address.rs
diff options
context:
space:
mode:
Diffstat (limited to 'tools/unitctl/unit-client-rs/src/control_socket_address.rs')
-rw-r--r--tools/unitctl/unit-client-rs/src/control_socket_address.rs571
1 files changed, 571 insertions, 0 deletions
diff --git a/tools/unitctl/unit-client-rs/src/control_socket_address.rs b/tools/unitctl/unit-client-rs/src/control_socket_address.rs
new file mode 100644
index 00000000..b9ae5afc
--- /dev/null
+++ b/tools/unitctl/unit-client-rs/src/control_socket_address.rs
@@ -0,0 +1,571 @@
+use crate::control_socket_address::ControlSocket::{TcpSocket, UnixLocalAbstractSocket, UnixLocalSocket};
+use crate::control_socket_address::ControlSocketScheme::{HTTP, HTTPS};
+use crate::unit_client::UnitClientError;
+use hyper::http::uri::{Authority, PathAndQuery};
+use hyper::Uri;
+use std::fmt::{Display, Formatter};
+use std::fs;
+use std::os::unix::fs::FileTypeExt;
+use std::path::{PathBuf, MAIN_SEPARATOR};
+
+type AbstractSocketName = String;
+type UnixSocketPath = PathBuf;
+type Port = u16;
+
+#[derive(Debug, Clone)]
+pub enum ControlSocket {
+ UnixLocalAbstractSocket(AbstractSocketName),
+ UnixLocalSocket(UnixSocketPath),
+ TcpSocket(Uri),
+}
+
+#[derive(Debug)]
+pub enum ControlSocketScheme {
+ HTTP,
+ HTTPS,
+}
+
+impl ControlSocketScheme {
+ fn port(&self) -> Port {
+ match self {
+ HTTP => 80,
+ HTTPS => 443,
+ }
+ }
+}
+
+impl ControlSocket {
+ pub fn socket_scheme(&self) -> ControlSocketScheme {
+ match self {
+ UnixLocalAbstractSocket(_) => ControlSocketScheme::HTTP,
+ UnixLocalSocket(_) => ControlSocketScheme::HTTP,
+ TcpSocket(uri) => match uri.scheme_str().expect("Scheme should not be None") {
+ "http" => ControlSocketScheme::HTTP,
+ "https" => ControlSocketScheme::HTTPS,
+ _ => unreachable!("Scheme should be http or https"),
+ },
+ }
+ }
+
+ pub fn create_uri_with_path(&self, str_path: &str) -> Uri {
+ match self {
+ UnixLocalAbstractSocket(name) => {
+ let socket_path = PathBuf::from(format!("@{}", name));
+ hyperlocal::Uri::new(socket_path, str_path).into()
+ }
+ UnixLocalSocket(socket_path) => hyperlocal::Uri::new(socket_path, str_path).into(),
+ TcpSocket(uri) => {
+ if str_path.is_empty() {
+ uri.clone()
+ } else {
+ let authority = uri.authority().expect("Authority should not be None");
+ Uri::builder()
+ .scheme(uri.scheme_str().expect("Scheme should not be None"))
+ .authority(authority.clone())
+ .path_and_query(str_path)
+ .build()
+ .expect("URI should be valid")
+ }
+ }
+ }
+ }
+}
+
+impl Display for ControlSocket {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ match self {
+ UnixLocalAbstractSocket(name) => f.write_fmt(format_args!("unix:@{}", name)),
+ UnixLocalSocket(path) => f.write_fmt(format_args!("unix:{}", path.to_string_lossy())),
+ TcpSocket(uri) => uri.fmt(f),
+ }
+ }
+}
+
+impl From<ControlSocket> for String {
+ fn from(val: ControlSocket) -> Self {
+ val.to_string()
+ }
+}
+
+impl From<ControlSocket> for PathBuf {
+ fn from(val: ControlSocket) -> Self {
+ match val {
+ UnixLocalAbstractSocket(socket_name) => PathBuf::from(format!("@{}", socket_name)),
+ UnixLocalSocket(socket_path) => socket_path,
+ TcpSocket(_) => PathBuf::default(),
+ }
+ }
+}
+
+impl From<ControlSocket> for Uri {
+ fn from(val: ControlSocket) -> Self {
+ val.create_uri_with_path("")
+ }
+}
+
+impl ControlSocket {
+ pub fn validate_http_address(uri: Uri) -> Result<(), UnitClientError> {
+ let http_address = uri.to_string();
+ if uri.authority().is_none() {
+ return Err(UnitClientError::TcpSocketAddressParseError {
+ message: "No authority found in socket address".to_string(),
+ control_socket_address: http_address,
+ });
+ }
+ if uri.port_u16().is_none() {
+ return Err(UnitClientError::TcpSocketAddressNoPortError {
+ control_socket_address: http_address,
+ });
+ }
+ if !(uri.path().is_empty() || uri.path().eq("/")) {
+ return Err(UnitClientError::TcpSocketAddressParseError {
+ message: format!("Path is not empty or is not / [path={}]", uri.path()),
+ control_socket_address: http_address,
+ });
+ }
+
+ Ok(())
+ }
+
+ pub fn validate_unix_address(socket: PathBuf) -> Result<(), UnitClientError> {
+ if !socket.exists() {
+ return Err(UnitClientError::UnixSocketNotFound {
+ control_socket_address: socket.to_string_lossy().to_string(),
+ });
+ }
+ let metadata = fs::metadata(&socket).map_err(|error| UnitClientError::UnixSocketAddressError {
+ source: error,
+ control_socket_address: socket.to_string_lossy().to_string(),
+ })?;
+ let file_type = metadata.file_type();
+ if !file_type.is_socket() {
+ return Err(UnitClientError::UnixSocketAddressError {
+ source: std::io::Error::new(std::io::ErrorKind::Other, "Control socket path is not a socket"),
+ control_socket_address: socket.to_string_lossy().to_string(),
+ });
+ }
+
+ Ok(())
+ }
+
+ pub fn validate(&self) -> Result<Self, UnitClientError> {
+ match self {
+ UnixLocalAbstractSocket(socket_name) => {
+ let socket_path = PathBuf::from(format!("@{}", socket_name));
+ Self::validate_unix_address(socket_path.clone())
+ }
+ UnixLocalSocket(socket_path) => Self::validate_unix_address(socket_path.clone()),
+ TcpSocket(socket_uri) => Self::validate_http_address(socket_uri.clone()),
+ }
+ .map(|_| self.to_owned())
+ }
+
+ fn normalize_and_parse_http_address(http_address: String) -> Result<Uri, UnitClientError> {
+ // Convert *:1 style network addresses to URI format
+ let address = if http_address.starts_with("*:") {
+ http_address.replacen("*:", "http://127.0.0.1:", 1)
+ // Add scheme if not present
+ } else if !(http_address.starts_with("http://") || http_address.starts_with("https://")) {
+ format!("http://{}", http_address)
+ } else {
+ http_address.to_owned()
+ };
+
+ let is_https = address.starts_with("https://");
+
+ let parsed_uri =
+ Uri::try_from(address.as_str()).map_err(|error| UnitClientError::TcpSocketAddressUriError {
+ source: error,
+ control_socket_address: address,
+ })?;
+ let authority = parsed_uri.authority().expect("Authority should not be None");
+ let expected_port = if is_https { HTTPS.port() } else { HTTP.port() };
+ let normalized_authority = match authority.port_u16() {
+ Some(_) => authority.to_owned(),
+ None => {
+ let host = format!("{}:{}", authority.host(), expected_port);
+ Authority::try_from(host.as_str()).expect("Authority should be valid")
+ }
+ };
+
+ let normalized_uri = Uri::builder()
+ .scheme(parsed_uri.scheme_str().expect("Scheme should not be None"))
+ .authority(normalized_authority)
+ .path_and_query(PathAndQuery::from_static(""))
+ .build()
+ .map_err(|error| UnitClientError::TcpSocketAddressParseError {
+ message: error.to_string(),
+ control_socket_address: http_address.clone(),
+ })?;
+
+ Ok(normalized_uri)
+ }
+
+ /// Flexibly parse a textual representation of a socket address
+ fn parse_address<S: Into<String>>(socket_address: S) -> Result<Self, UnitClientError> {
+ let full_socket_address: String = socket_address.into();
+ let socket_prefix = "unix:";
+ let socket_uri_prefix = "unix://";
+ let mut buf = String::with_capacity(socket_prefix.len());
+ for (i, c) in full_socket_address.char_indices() {
+ // Abstract unix socket with no prefix
+ if i == 0 && c == '@' {
+ return Ok(UnixLocalAbstractSocket(full_socket_address[1..].to_string()));
+ }
+ buf.push(c);
+ // Unix socket with prefix
+ if i == socket_prefix.len() - 1 && buf.eq(socket_prefix) {
+ let path_text = full_socket_address[socket_prefix.len()..].to_string();
+ // Return here if this URI does not have a scheme followed by double slashes
+ if !path_text.starts_with("//") {
+ return match path_text.strip_prefix('@') {
+ Some(name) => Ok(UnixLocalAbstractSocket(name.to_string())),
+ None => {
+ let path = PathBuf::from(path_text);
+ Ok(UnixLocalSocket(path))
+ }
+ };
+ }
+ }
+
+ // Unix socket with URI prefix
+ if i == socket_uri_prefix.len() - 1 && buf.eq(socket_uri_prefix) {
+ let uri = Uri::try_from(full_socket_address.as_str()).map_err(|error| {
+ UnitClientError::TcpSocketAddressParseError {
+ message: error.to_string(),
+ control_socket_address: full_socket_address.clone(),
+ }
+ })?;
+ return ControlSocket::try_from(uri);
+ }
+ }
+
+ /* Sockets on Windows are not supported, so there is no need to check
+ * if the socket address is a valid path, so we can do this shortcut
+ * here to see if a path was specified without a unix: prefix. */
+ if buf.starts_with(MAIN_SEPARATOR) {
+ let path = PathBuf::from(buf);
+ return Ok(UnixLocalSocket(path));
+ }
+
+ let uri = Self::normalize_and_parse_http_address(buf)?;
+ Ok(TcpSocket(uri))
+ }
+
+ pub fn is_local_socket(&self) -> bool {
+ match self {
+ UnixLocalAbstractSocket(_) | UnixLocalSocket(_) => true,
+ TcpSocket(_) => false,
+ }
+ }
+}
+
+impl TryFrom<String> for ControlSocket {
+ type Error = UnitClientError;
+
+ fn try_from(socket_address: String) -> Result<Self, Self::Error> {
+ ControlSocket::parse_address(socket_address.as_str())
+ }
+}
+
+impl TryFrom<&str> for ControlSocket {
+ type Error = UnitClientError;
+
+ fn try_from(socket_address: &str) -> Result<Self, Self::Error> {
+ ControlSocket::parse_address(socket_address)
+ }
+}
+
+impl TryFrom<Uri> for ControlSocket {
+ type Error = UnitClientError;
+
+ fn try_from(socket_uri: Uri) -> Result<Self, Self::Error> {
+ match socket_uri.scheme_str() {
+ // URIs with the unix scheme will have a hostname that is a hex encoded string
+ // representing the path to the socket
+ Some("unix") => {
+ let host = match socket_uri.host() {
+ Some(host) => host,
+ None => {
+ return Err(UnitClientError::TcpSocketAddressParseError {
+ message: "No host found in socket address".to_string(),
+ control_socket_address: socket_uri.to_string(),
+ })
+ }
+ };
+ let bytes = hex::decode(host).map_err(|error| UnitClientError::TcpSocketAddressParseError {
+ message: error.to_string(),
+ control_socket_address: socket_uri.to_string(),
+ })?;
+ let path = String::from_utf8_lossy(&bytes);
+ ControlSocket::parse_address(path)
+ }
+ Some("http") | Some("https") => Ok(TcpSocket(socket_uri)),
+ Some(unknown) => Err(UnitClientError::TcpSocketAddressParseError {
+ message: format!("Unsupported scheme found in socket address: {}", unknown).to_string(),
+ control_socket_address: socket_uri.to_string(),
+ }),
+ None => Err(UnitClientError::TcpSocketAddressParseError {
+ message: "No scheme found in socket address".to_string(),
+ control_socket_address: socket_uri.to_string(),
+ }),
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use rand::distributions::{Alphanumeric, DistString};
+ use std::env::temp_dir;
+ use std::fmt::Display;
+ use std::io;
+ use std::os::unix::net::UnixListener;
+
+ use super::*;
+
+ struct TempSocket {
+ socket_path: PathBuf,
+ _listener: UnixListener,
+ }
+
+ impl TempSocket {
+ fn shutdown(&mut self) -> io::Result<()> {
+ fs::remove_file(&self.socket_path)
+ }
+ }
+
+ impl Display for TempSocket {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ write!(f, "unix:{}", self.socket_path.to_string_lossy().to_string())
+ }
+ }
+
+ impl Drop for TempSocket {
+ fn drop(&mut self) {
+ self.shutdown()
+ .expect(format!("Unable to shutdown socket {}", self.socket_path.to_string_lossy()).as_str());
+ }
+ }
+
+ #[test]
+ fn will_error_with_nonexistent_unix_socket() {
+ let socket_address = "unix:/tmp/some_random_filename_that_doesnt_exist.sock";
+ let control_socket =
+ ControlSocket::try_from(socket_address).expect("No error should be returned until validate() is called");
+ assert!(control_socket.is_local_socket(), "Not parsed as a local socket");
+ assert!(control_socket.validate().is_err(), "Socket should not be valid");
+ }
+
+ #[test]
+ fn can_parse_socket_with_prefix() {
+ let temp_socket = create_file_socket().expect("Unable to create socket");
+ let control_socket = ControlSocket::try_from(temp_socket.to_string()).expect("Error parsing good socket path");
+ assert!(control_socket.is_local_socket(), "Not parsed as a local socket");
+ if let Err(e) = control_socket.validate() {
+ panic!("Socket should be valid: {}", e);
+ }
+ }
+
+ #[test]
+ fn can_parse_socket_from_uri() {
+ let temp_socket = create_file_socket().expect("Unable to create socket");
+ let uri: Uri = hyperlocal::Uri::new(temp_socket.socket_path.clone(), "").into();
+ let control_socket = ControlSocket::try_from(uri).expect("Error parsing good socket path");
+ assert!(control_socket.is_local_socket(), "Not parsed as a local socket");
+ if let Err(e) = control_socket.validate() {
+ panic!("Socket should be valid: {}", e);
+ }
+ }
+
+ #[test]
+ fn can_parse_socket_from_uri_text() {
+ let temp_socket = create_file_socket().expect("Unable to create socket");
+ let uri: Uri = hyperlocal::Uri::new(temp_socket.socket_path.clone(), "").into();
+ let control_socket = ControlSocket::parse_address(uri.to_string()).expect("Error parsing good socket path");
+ assert!(control_socket.is_local_socket(), "Not parsed as a local socket");
+ if let Err(e) = control_socket.validate() {
+ panic!("Socket for input text should be valid: {}", e);
+ }
+ }
+
+ #[test]
+ #[cfg(target_os = "linux")]
+ fn can_parse_abstract_socket_from_uri() {
+ let temp_socket = create_abstract_socket().expect("Unable to create socket");
+ let uri: Uri = hyperlocal::Uri::new(temp_socket.socket_path.clone(), "").into();
+ let control_socket = ControlSocket::try_from(uri).expect("Error parsing good socket path");
+ assert!(control_socket.is_local_socket(), "Not parsed as a local socket");
+ if let Err(e) = control_socket.validate() {
+ panic!("Socket should be valid: {}", e);
+ }
+ }
+
+ #[test]
+ #[cfg(target_os = "linux")]
+ fn can_parse_abstract_socket_from_uri_text() {
+ let temp_socket = create_abstract_socket().expect("Unable to create socket");
+ let uri: Uri = hyperlocal::Uri::new(temp_socket.socket_path.clone(), "").into();
+ let control_socket = ControlSocket::parse_address(uri.to_string()).expect("Error parsing good socket path");
+ assert!(control_socket.is_local_socket(), "Not parsed as a local socket");
+ if let Err(e) = control_socket.validate() {
+ panic!("Socket should be valid: {}", e);
+ }
+ }
+
+ #[test]
+ fn can_parse_socket_without_prefix() {
+ let temp_socket = create_file_socket().expect("Unable to create socket");
+ let control_socket = ControlSocket::try_from(temp_socket.socket_path.to_string_lossy().to_string())
+ .expect("Error parsing good socket path");
+ assert!(control_socket.is_local_socket(), "Not parsed as a local socket");
+ if let Err(e) = control_socket.validate() {
+ panic!("Socket should be valid: {}", e);
+ }
+ }
+
+ #[cfg(target_os = "linux")]
+ #[test]
+ fn can_parse_abstract_socket() {
+ let temp_socket = create_abstract_socket().expect("Unable to create socket");
+ let control_socket = ControlSocket::try_from(temp_socket.to_string()).expect("Error parsing good socket path");
+ assert!(control_socket.is_local_socket(), "Not parsed as a local socket");
+ if let Err(e) = control_socket.validate() {
+ panic!("Socket should be valid: {}", e);
+ }
+ }
+
+ #[test]
+ fn can_normalize_good_http_socket_addresses() {
+ let valid_socket_addresses = vec![
+ "http://127.0.0.1:8080",
+ "https://127.0.0.1:8080",
+ "http://127.0.0.1:8080/",
+ "127.0.0.1:8080",
+ "http://0.0.0.0:8080",
+ "https://0.0.0.0:8080",
+ "http://0.0.0.0:8080/",
+ "0.0.0.0:8080",
+ "http://localhost:8080",
+ "https://localhost:8080",
+ "http://localhost:8080/",
+ "localhost:8080",
+ "http://[::1]:8080",
+ "https://[::1]:8080",
+ "http://[::1]:8080/",
+ "[::1]:8080",
+ "http://[0000:0000:0000:0000:0000:0000:0000:0000]:8080",
+ "https://[0000:0000:0000:0000:0000:0000:0000:0000]:8080",
+ "http://[0000:0000:0000:0000:0000:0000:0000:0000]:8080/",
+ "[0000:0000:0000:0000:0000:0000:0000:0000]:8080",
+ ];
+ for socket_address in valid_socket_addresses {
+ let mut expected = if socket_address.starts_with("http") {
+ socket_address.to_string().trim_end_matches('/').to_string()
+ } else {
+ format!("http://{}", socket_address).trim_end_matches('/').to_string()
+ };
+ expected.push('/');
+
+ let control_socket = ControlSocket::try_from(socket_address).expect("Error parsing good socket path");
+ assert!(!control_socket.is_local_socket(), "Not parsed as a local socket");
+ if let Err(e) = control_socket.validate() {
+ panic!("Socket should be valid: {}", e);
+ }
+ }
+ }
+
+ #[test]
+ fn can_normalize_wildcard_http_socket_address() {
+ let socket_address = "*:8080";
+ let expected = "http://127.0.0.1:8080/";
+ let normalized_result = ControlSocket::normalize_and_parse_http_address(socket_address.to_string());
+ let normalized = normalized_result
+ .expect("Unable to normalize socket address")
+ .to_string();
+ assert_eq!(normalized, expected);
+ }
+
+ #[test]
+ fn can_normalize_http_socket_address_with_no_port() {
+ let socket_address = "http://localhost";
+ let expected = "http://localhost:80/";
+ let normalized_result = ControlSocket::normalize_and_parse_http_address(socket_address.to_string());
+ let normalized = normalized_result
+ .expect("Unable to normalize socket address")
+ .to_string();
+ assert_eq!(normalized, expected);
+ }
+
+ #[test]
+ fn can_normalize_https_socket_address_with_no_port() {
+ let socket_address = "https://localhost";
+ let expected = "https://localhost:443/";
+ let normalized_result = ControlSocket::normalize_and_parse_http_address(socket_address.to_string());
+ let normalized = normalized_result
+ .expect("Unable to normalize socket address")
+ .to_string();
+ assert_eq!(normalized, expected);
+ }
+
+ #[test]
+ fn can_parse_http_addresses() {
+ let valid_socket_addresses = vec![
+ "http://127.0.0.1:8080",
+ "https://127.0.0.1:8080",
+ "http://127.0.0.1:8080/",
+ "127.0.0.1:8080",
+ "http://0.0.0.0:8080",
+ "https://0.0.0.0:8080",
+ "http://0.0.0.0:8080/",
+ "0.0.0.0:8080",
+ "http://localhost:8080",
+ "https://localhost:8080",
+ "http://localhost:8080/",
+ "localhost:8080",
+ "http://[::1]:8080",
+ "https://[::1]:8080",
+ "http://[::1]:8080/",
+ "[::1]:8080",
+ "http://[0000:0000:0000:0000:0000:0000:0000:0000]:8080",
+ "https://[0000:0000:0000:0000:0000:0000:0000:0000]:8080",
+ "http://[0000:0000:0000:0000:0000:0000:0000:0000]:8080/",
+ "[0000:0000:0000:0000:0000:0000:0000:0000]:8080",
+ ];
+ for socket_address in valid_socket_addresses {
+ let mut expected = if socket_address.starts_with("http") {
+ socket_address.to_string().trim_end_matches('/').to_string()
+ } else {
+ format!("http://{}", socket_address).trim_end_matches('/').to_string()
+ };
+ expected.push('/');
+
+ let normalized = ControlSocket::normalize_and_parse_http_address(socket_address.to_string())
+ .expect("Unable to normalize socket address")
+ .to_string();
+ assert_eq!(normalized, expected);
+ }
+ }
+
+ fn create_file_socket() -> Result<TempSocket, io::Error> {
+ let random = Alphanumeric.sample_string(&mut rand::thread_rng(), 10);
+ let socket_name = format!("unit-client-socket-test-{}.sock", random);
+ let socket_path = temp_dir().join(socket_name);
+ let listener = UnixListener::bind(&socket_path)?;
+ Ok(TempSocket {
+ socket_path,
+ _listener: listener,
+ })
+ }
+
+ #[cfg(target_os = "linux")]
+ fn create_abstract_socket() -> Result<TempSocket, io::Error> {
+ let random = Alphanumeric.sample_string(&mut rand::thread_rng(), 10);
+ let socket_name = format!("@unit-client-socket-test-{}.sock", random);
+ let socket_path = PathBuf::from(socket_name);
+ let listener = UnixListener::bind(&socket_path)?;
+ Ok(TempSocket {
+ socket_path,
+ _listener: listener,
+ })
+ }
+}