From 0632d7b0a98707a9b6e37db0b9ce7d620eb47348 Mon Sep 17 00:00:00 2001 From: Steppy Date: Tue, 4 Feb 2025 23:09:02 +0100 Subject: [PATCH] Handle input in environment --- src/environment.rs | 13 ++++-- src/main.rs | 106 +++++++++++++++++++++++++-------------------- 2 files changed, 69 insertions(+), 50 deletions(-) diff --git a/src/environment.rs b/src/environment.rs index a51a2bc..454e09f 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -1,18 +1,19 @@ -use std::env; use std::env::VarError; use std::ffi::{OsStr, OsString}; use std::path::PathBuf; +use std::{env, io}; pub trait Environment { fn args_os(&self) -> Vec; fn var(&self, key: K) -> Result where K: AsRef; - fn set_var(&self, key: K, value: V) + fn set_var(&mut self, key: K, value: V) where K: AsRef, V: AsRef; fn get_home_directory(&self) -> Option; + fn read_line(&mut self) -> Result; } #[derive(Debug, Default)] @@ -30,7 +31,7 @@ impl Environment for Prod { env::var(key) } - fn set_var(&self, key: K, value: V) + fn set_var(&mut self, key: K, value: V) where K: AsRef, V: AsRef @@ -41,4 +42,10 @@ impl Environment for Prod { fn get_home_directory(&self) -> Option { homedir::my_home().ok().flatten() } + + fn read_line(&mut self) -> Result { + let mut buffer = String::new(); + io::stdin().read_line(&mut buffer)?; + Ok(buffer.trim().to_string()) + } } diff --git a/src/main.rs b/src/main.rs index ad86667..067d368 100644 --- a/src/main.rs +++ b/src/main.rs @@ -141,24 +141,6 @@ enum OldVersionPolicy { Delete, } -//TODO IO would also need to be handled by the environment - -#[macro_export] -macro_rules! input { - ($prompt: tt) => {{ - print!($prompt); - io::stdout().flush().expect("failed to flush stdout"); - let mut buf = String::new(); - io::stdin() - .read_line(&mut buf) - .expect("failed to read stdin"); - buf.trim().to_string() - }}; - () => { - input!() - }; -} - #[derive(Debug, Default)] pub struct Application { pub environment: E, @@ -168,12 +150,14 @@ impl Application where E: Environment, { - pub fn run(&self) -> Result<(), String> { + pub fn run(&mut self) -> Result<(), String> { let args = Args::try_parse_from(self.environment.args_os()).map_err(|e| e.to_string())?; self.run_with_args(args) } - pub fn run_with_args(&self, args: Args) -> Result<(), String> { + pub fn run_with_args(&mut self, args: Args) -> Result<(), String> { + let env = &mut self.environment; + let logger = Logger { //all the below options are conflicting with each other so an if else is fine level: if args.quiet { @@ -216,7 +200,6 @@ where } => { Self::require_non_empty_servers(&servers)?; Self::require_non_empty(&files, "files to upload")?; - self.start_ssh_agent(&logger)?; //resolve file server let file_server = match file_server { @@ -230,6 +213,8 @@ where None => None, }; + self.start_ssh_agent(&logger)?; + //make sure files exist match &file_server { Some(file_server) => match &file_server.address { @@ -448,14 +433,9 @@ where } } - if !no_confirm { - match input!("Continue? [Y|n] ").to_lowercase().as_str() { - "n" | "no" => { - log!(logger, "Aborting..."); - return Ok(()); - } - _ => {} - } + if !no_confirm && !self.confirm("Continue?", true) { + log!(logger, "Aborting..."); + return Ok(()); } for server_actions in actions { @@ -557,7 +537,8 @@ where let download_directory = match download_directory { Some(download_directory) => download_directory, None => { - let home_dir = self.get_home_directory() + let home_dir = self + .get_home_directory() .map_err(|e| format!("Missing download-directory: {e}"))?; home_dir.join("Downloads") } @@ -588,14 +569,8 @@ where download_directory.to_string_lossy() ); - if !args.quiet { - match input!("{duplication_notification}. Do you want to replace it? [N|y] ") - .to_lowercase() - .as_str() - { - "y" | "yes" => break 'duplicate_check, - _ => {} - } + if !args.quiet && self.confirm(format!("{duplication_notification}. Do you want to replace it?"), false) { + break 'duplicate_check; } return Err(format!( @@ -677,8 +652,8 @@ where Ok(()) } - fn start_ssh_agent(&self, logger: &Logger) -> Result<(), String> { - let env = &self.environment; + fn start_ssh_agent(&mut self, logger: &Logger) -> Result<(), String> { + let env = &mut self.environment; //start the ssh agent let agent_output = ShellCmd::new("ssh-agent") @@ -711,9 +686,36 @@ where .map_err(|_| format!("Missing environment variable {}", SERVERS_ENV_VAR)) .and_then(|value| parse_server_configuration(&value, || self.get_home_directory())) } - + fn get_home_directory(&self) -> Result { - self.environment.get_home_directory().ok_or("Failed to find your home directory".to_string()) + self + .environment + .get_home_directory() + .ok_or("Failed to find your home directory".to_string()) + } + + fn confirm(&mut self, prompt: S, default_value: bool) -> bool + where + S: ToString, + { + loop { + print!( + "{}[{}]", + prompt.to_string(), + if default_value { "Y|n" } else { "y|N" } + ); + io::stdout().flush().expect("failed to flush stdout"); + let line = self + .environment + .read_line() + .expect("Failed to read console input"); + match line.to_lowercase().as_str() { + "" => return default_value, + "y" | "yes" => return true, + "n" | "no" => return false, + _ => println!("Invalid input, please choose one of the provided options"), + } + } } } @@ -736,12 +738,22 @@ fn osstring_from_ssh_output(output: Vec) -> OsString { } } -fn parse_server_configuration(config_str: &str, get_home_directory: F) -> Result, String> where F: Fn() -> Result { +fn parse_server_configuration( + config_str: &str, + get_home_directory: F, +) -> Result, String> +where + F: Fn() -> Result, +{ config_str .split(',') .map(|server_entry| { - Server::from_str(server_entry, RelativeLocalPathAnker::Home, &get_home_directory) - .map_err(|e| format!("Invalid server entry '{server_entry}': {e}")) + Server::from_str( + server_entry, + RelativeLocalPathAnker::Home, + &get_home_directory, + ) + .map_err(|e| format!("Invalid server entry '{server_entry}': {e}")) }) .collect() } @@ -755,8 +767,8 @@ mod test { #[test] fn test_parse_server_configuration() { - let servers = - parse_server_configuration("foo:bar,.:fizz/buzz", || Ok(PathBuf::from("/test"))).expect("valid server configuration"); + let servers = parse_server_configuration("foo:bar,.:fizz/buzz", || Ok(PathBuf::from("/test"))) + .expect("valid server configuration"); assert_eq!( vec![ Server {