From 6c7e718748a709de554640b2e4219bbee9f5b9c5 Mon Sep 17 00:00:00 2001 From: Steppy Date: Tue, 4 Feb 2025 23:43:11 +0100 Subject: [PATCH] Fix ServerReferences using invalid parser --- src/main.rs | 30 +++++++++++++++++++++--------- src/server.rs | 36 ++++++++++++++++-------------------- 2 files changed, 37 insertions(+), 29 deletions(-) diff --git a/src/main.rs b/src/main.rs index 067d368..cb964b9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -25,7 +25,6 @@ use std::hash::Hash; use std::io::Write; use std::iter::once; use std::path::{Path, PathBuf}; -use std::str::FromStr; use std::{env, fs, io}; const SERVERS_ENV_VAR: &str = "MSSH_SERVERS"; @@ -49,9 +48,8 @@ pub struct Args { #[command(subcommand)] command: Command, /// The ssh names and optionally home directories of the servers to perform the action on - //TODO from_str always uses Prod environment -> handwrite that section - #[arg(num_args = 0.., value_parser = ServerReference::from_str)] - servers: Vec, + #[arg(num_args = 0..)] + servers: Vec, /// How verbose logging output should be #[arg(long, default_value = "info", conflicts_with_all = ["quiet", "info"])] log_level: LogLevel, @@ -75,7 +73,7 @@ enum Command { /// When this option is set, the file path must be absolute, or relative to the server directory. /// The upload-directory has no influence on where the file will be taken from. #[arg(short = 'S', long)] - file_server: Option, + file_server: Option, /// How to handle older versions of the file #[arg(short = 'a', long, default_value = "delete", default_missing_value = "archive", num_args = 0..=1)] old_version_policy: OldVersionPolicy, @@ -156,7 +154,7 @@ where } pub fn run_with_args(&mut self, args: Args) -> Result<(), String> { - let env = &mut self.environment; + let _env = &mut self.environment; let logger = Logger { //all the below options are conflicting with each other so an if else is fine @@ -172,7 +170,13 @@ where let mut configured_servers = LazyCell::new(|| self.parse_server_configuration_from_env()); let servers = args .servers - .iter() + .into_iter() + .map(|ref_str| { + ServerReference::from_str(&ref_str, || self.get_home_directory()) + .map_err(|e| format!("Invalid server reference '{ref_str}': {e}")) + }) + .collect::, _>>()? + .into_iter() .map(|server_reference| { let server_identifier = server_reference.get_identifier(); server_reference @@ -203,7 +207,10 @@ where //resolve file server let file_server = match file_server { - Some(server_reference) => { + Some(ref_str) => { + let server_reference = + ServerReference::from_str(&ref_str, || self.get_home_directory()) + .map_err(|e| format!("Invalid file-server reference '{ref_str}': {e}"))?; let file_server_identifier = server_reference.get_identifier().to_string(); let server = server_reference.try_resolve_lazy(&mut configured_servers) .map_err(|e| format!("Can't resolve server directory for file-server '{file_server_identifier}': {e}"))? @@ -569,7 +576,12 @@ where download_directory.to_string_lossy() ); - if !args.quiet && self.confirm(format!("{duplication_notification}. Do you want to replace it?"), false) { + if !args.quiet + && self.confirm( + format!("{duplication_notification}. Do you want to replace it?"), + false, + ) + { break 'duplicate_check; } diff --git a/src/server.rs b/src/server.rs index cf240a3..7c62268 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,4 +1,3 @@ -use crate::environment::{Environment, Prod}; use std::cell::LazyCell; use std::error::Error; use std::fmt::{Display, Formatter}; @@ -6,7 +5,6 @@ use std::fs; use std::hash::{Hash, Hasher}; use std::ops::Deref; use std::path::PathBuf; -use std::str::FromStr; #[derive(Debug, Clone)] pub enum ServerReference { @@ -15,6 +13,19 @@ pub enum ServerReference { } impl ServerReference { + pub fn from_str(s: &str, get_home_directory: F) -> Result + where + F: FnOnce() -> Result, + { + Server::from_str( + s, + RelativeLocalPathAnker::CurrentDirectory, + get_home_directory, + ) + .map(Self::Resolved) + .or_else(|_| Ok(Self::Identifier(s.to_string()))) + } + pub fn get_identifier(&self) -> &str { match self { ServerReference::Resolved(server) => server.address.identifier(), @@ -68,20 +79,6 @@ impl ServerReference { } } -impl FromStr for ServerReference { - type Err = ServerReferenceParseError; - - fn from_str(s: &str) -> Result { - Server::from_str(s, RelativeLocalPathAnker::CurrentDirectory, || { - Prod::default() - .get_home_directory() - .ok_or("missing home directory".to_string()) - }) - .map(Self::Resolved) - .or_else(|_| Ok(Self::Identifier(s.to_string()))) - } -} - impl PartialEq for ServerReference { fn eq(&self, other: &Self) -> bool { self.get_identifier() == other.get_identifier() @@ -136,7 +133,7 @@ impl Server { get_home_directory: F, ) -> Result where - F: Fn() -> Result, + F: FnOnce() -> Result, { s.split_once(':') .ok_or(ServerParseError::MissingServerDirectory) @@ -253,13 +250,12 @@ impl Error for ServerParseError {} mod test_server_reference { use crate::server::{Server, ServerAddress, ServerReference}; use std::path::PathBuf; - use std::str::FromStr; #[test] fn test_from_str() { assert_eq!( ServerReference::Identifier("foo".to_string()), - ServerReference::from_str("foo").unwrap() + ServerReference::from_str("foo", || panic!("shouldn't be called")).unwrap() ); assert_eq!( ServerReference::Resolved(Server { @@ -268,7 +264,7 @@ mod test_server_reference { }, server_directory_path: PathBuf::from("server/creative2") }), - ServerReference::from_str("crea:server/creative2").unwrap() + ServerReference::from_str("crea:server/creative2", || panic!("shouldn't be called")).unwrap() ); } }