Skip to content
131 changes: 118 additions & 13 deletions src/config_observer.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,26 @@
use std::fs;
use std::path::PathBuf;
use directories::UserDirs;

pub fn expand_tilde(path: &str) -> PathBuf {
if path == "~" {
if let Some(home) = UserDirs::new().map(|d| d.home_dir().to_path_buf()) {
return home;
}
} else if let Some(rest) = path.strip_prefix("~/")
&& let Some(home) = UserDirs::new().map(|d| d.home_dir().to_path_buf()) {
return home.join(rest);
}
PathBuf::from(path)
}

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SshHost {
pub alias: String,
pub hostname: String,
pub user: Option<String>,
pub port: Option<u16>,
pub identity_file: Option<String>,
}

pub fn get_default_config_path() -> Option<std::path::PathBuf> {
Expand Down Expand Up @@ -55,6 +69,7 @@ pub fn parse_ssh_config(content: &str) -> Vec<SshHost> {
hostname: String::new(),
user: None,
port: None,
identity_file: None,
});
}
"hostname" => {
Expand All @@ -68,10 +83,14 @@ pub fn parse_ssh_config(content: &str) -> Vec<SshHost> {
}
}
"port" => {
if let Some(ref mut host) = current_host {
if let Ok(p) = value.parse::<u16>() {
if let Some(ref mut host) = current_host
&& let Ok(p) = value.parse::<u16>() {
host.port = Some(p);
}
}
"identityfile" => {
if let Some(ref mut host) = current_host {
host.identity_file = Some(value.to_string());
}
}
_ => {}
Expand All @@ -88,7 +107,6 @@ pub fn parse_ssh_config(content: &str) -> Vec<SshHost> {

pub fn add_host_to_config(host: &SshHost) -> anyhow::Result<()> {
let path = get_default_config_path().ok_or_else(|| anyhow::anyhow!("Could not find SSH config path"))?;

if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent)?;
}
Expand All @@ -109,23 +127,33 @@ pub fn add_host_to_config(host: &SshHost) -> anyhow::Result<()> {
host.alias.clone()
};

let entry = format!(
let mut entry = format!(
"\nHost {}\n HostName {}\n User {}\n Port {}\n",
alias_quoted,
host.hostname,
host.user.as_deref().unwrap_or("root"),
host.port.unwrap_or(22)
);

if let Some(ref id_file) = host.identity_file {
let id_file_quoted = if id_file.contains(' ') {
format!("\"{}\"", id_file)
} else {
id_file.clone()
};
entry.push_str(&format!(" IdentityFile {}\n", id_file_quoted));
}

content.push_str(&entry);
std::fs::write(path, content)?;
let tmp_path = path.with_extension("tmp");
std::fs::write(&tmp_path, &content)?;
std::fs::rename(tmp_path, path)?;
Ok(())
}

pub fn delete_host_from_config(alias: &str) -> anyhow::Result<()> {
let path = get_default_config_path().ok_or_else(|| anyhow::anyhow!("No config path"))?;
if !path.exists() { return Ok(()); }

let content = std::fs::read_to_string(&path)?;
let mut new_lines = Vec::new();
let mut skip = false;
Expand All @@ -138,27 +166,24 @@ pub fn delete_host_from_config(alias: &str) -> anyhow::Result<()> {
if val.starts_with('"') && val.ends_with('"') && val.len() >= 2 {
val = &val[1..val.len()-1];
}

if val == target_alias {
skip = true;
continue;
} else {
skip = false;
}
}

if skip && (line.starts_with(' ') || line.starts_with('\t') || line.trim().is_empty()) {
continue;
}

if skip {
skip = false;
}

new_lines.push(line);
}

std::fs::write(path, new_lines.join("\n"))?;
let tmp_path = path.with_extension("tmp");
std::fs::write(&tmp_path, new_lines.join("\n"))?;
std::fs::rename(tmp_path, path)?;
Ok(())
}

Expand All @@ -184,4 +209,84 @@ mod tests {
assert_eq!(hosts.len(), 1);
assert_eq!(hosts[0].alias, "My Server");
}
}

#[test]
fn test_parse_ssh_config_with_identity_file() {
let config = "Host my-server\n HostName 1.2.3.4\n User root\n Port 22\n IdentityFile ~/.ssh/id_ed25519";
let hosts = parse_ssh_config(config);
assert_eq!(hosts.len(), 1);
assert_eq!(hosts[0].identity_file, Some("~/.ssh/id_ed25519".to_string()));
}

#[test]
fn test_parse_ssh_config_with_quoted_identity_file() {
let config = "Host my-server\n HostName 1.2.3.4\n User root\n IdentityFile \"/home/user/my keys/id_ed25519\"";
let hosts = parse_ssh_config(config);
assert_eq!(hosts.len(), 1);
assert_eq!(hosts[0].identity_file, Some("/home/user/my keys/id_ed25519".to_string()));
}

#[test]
fn test_add_host_to_config_emits_identity_file() {
let host = SshHost {
alias: "test-host".to_string(),
hostname: "192.168.1.1".to_string(),
user: Some("admin".to_string()),
port: Some(22),
identity_file: Some("~/.ssh/id_ed25519".to_string()),
};
let alias_quoted = if host.alias.contains(' ') {
format!("\"{}\"", host.alias)
} else {
host.alias.clone()
};
let mut entry = format!(
"\nHost {}\n HostName {}\n User {}\n Port {}\n",
alias_quoted,
host.hostname,
host.user.as_deref().unwrap_or("root"),
host.port.unwrap_or(22)
);
if let Some(ref id_file) = host.identity_file {
let id_file_quoted = if id_file.contains(' ') {
format!("\"{}\"", id_file)
} else {
id_file.clone()
};
entry.push_str(&format!(" IdentityFile {}\n", id_file_quoted));
}
assert!(entry.contains("IdentityFile ~/.ssh/id_ed25519"));
let hosts = parse_ssh_config(&entry);
assert_eq!(hosts.len(), 1);
assert_eq!(hosts[0].identity_file, Some("~/.ssh/id_ed25519".to_string()));
}

#[test]
fn test_add_host_to_config_quotes_identity_file_with_spaces() {
let host = SshHost {
alias: "spaced-host".to_string(),
hostname: "10.0.0.1".to_string(),
user: Some("user".to_string()),
port: Some(22),
identity_file: Some("/home/user/my keys/id_rsa".to_string()),
};
let mut entry = format!(
"\nHost {}\n HostName {}\n User {}\n Port {}\n",
host.alias, host.hostname,
host.user.as_deref().unwrap_or("root"),
host.port.unwrap_or(22)
);
if let Some(ref id_file) = host.identity_file {
let id_file_quoted = if id_file.contains(' ') {
format!("\"{}\"", id_file)
} else {
id_file.clone()
};
entry.push_str(&format!(" IdentityFile {}\n", id_file_quoted));
}
assert!(entry.contains("IdentityFile \"/home/user/my keys/id_rsa\""));
let hosts = parse_ssh_config(&entry);
assert_eq!(hosts.len(), 1);
assert_eq!(hosts[0].identity_file, Some("/home/user/my keys/id_rsa".to_string()));
}
}
4 changes: 1 addition & 3 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ fn log_debug(msg: &str) {
#[tokio::main]
async fn main() {
let _args: Vec<String> = std::env::args().collect();

if let Ok(alias) = std::env::var("RUSTMIUS_ASKPASS_ALIAS") {
log_debug(&format!("AskPass triggered for alias: {}", alias));
if let Ok(keyring) = oo7::Keyring::new().await {
Expand All @@ -32,7 +31,6 @@ async fn main() {
&& let Ok(password) = item.secret().await
&& let Ok(pass_str) = std::str::from_utf8(&password) {
log_debug("Password retrieved successfully, sending to SSH");

print!("{}", pass_str);
std::process::exit(0);
}
Expand All @@ -49,4 +47,4 @@ async fn main() {

app.connect_activate(build_ui);
app.run_with_args::<&str>(&[]);
}
}
59 changes: 38 additions & 21 deletions src/sftp_engine.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use ssh2::Session;
use std::net::TcpStream;
use crate::config_observer::SshHost;
use std::net::{TcpStream, ToSocketAddrs};
use std::time::Duration;
use crate::config_observer::{SshHost, expand_tilde};
use std::path::Path;
use std::io::{Read, Write};
use std::sync::{Arc, Mutex, OnceLock};
Expand All @@ -14,7 +15,7 @@ pub struct RemoteFile {
}

pub struct ActiveSession {
_sess: Session, // Keep the session alive for the Sftp pointer
_sess: Session,
pub sftp: ssh2::Sftp,
}

Expand All @@ -25,48 +26,66 @@ fn get_session_pool() -> &'static Mutex<HashMap<String, Arc<ActiveSession>>> {

fn get_or_connect_sftp(host: &SshHost, password: &Option<String>) -> anyhow::Result<Arc<ActiveSession>> {
let host_key = format!("{}@{}", host.user.as_deref().unwrap_or("root"), host.hostname);

if let Ok(mut pool) = get_session_pool().lock() {
if let Some(active) = pool.get(&host_key) {
// Check if connection is still alive using a basic stat
if let Ok(mut pool) = get_session_pool().lock()
&& let Some(active) = pool.get(&host_key) {
if active.sftp.stat(Path::new(".")).is_ok() {
return Ok(active.clone());
} else {
// Connection died, remove it to force reconnect
pool.remove(&host_key);
}
}
}

let port = host.port.unwrap_or(22);
let tcp = TcpStream::connect(format!("{}:{}", host.hostname, port))?;
let addrs = format!("{}:{}", host.hostname, port).to_socket_addrs()?;
let mut tcp_opt = None;
for addr in addrs {
if let Ok(stream) = TcpStream::connect_timeout(&addr, Duration::from_secs(5)) {
tcp_opt = Some(stream);
break;
}
}
let tcp = tcp_opt.ok_or_else(|| anyhow::anyhow!("Connection timeout to {}", host.hostname))?;
let mut sess = Session::new()?;
sess.set_tcp_stream(tcp);
sess.handshake()?;

let user = host.user.as_deref().unwrap_or("root");
if let Some(pass) = password {
sess.userauth_password(user, pass)?;
} else {
sess.userauth_agent(user)?;
let mut authenticated = false;
if let Some(ref key_path) = host.identity_file {
let path = expand_tilde(key_path);
if sess.userauth_pubkey_file(user, None, &path, None).is_ok() {
println!("[DEBUG] SFTP connected to {} via Configure SSH Key ({})", host.hostname, key_path);
authenticated = true;
}
}
if !authenticated
&& sess.userauth_agent(user).is_ok() {
println!("[DEBUG] SFTP connected to {} via SSH Agent", host.hostname);
authenticated = true;
}

if !authenticated
&& let Some(pass) = password
&& sess.userauth_password(user, pass).is_ok() {
println!("[DEBUG] SFTP connected to {} via Password", host.hostname);
authenticated = true;
}
if !authenticated {
return Err(anyhow::anyhow!("Authentication failed (tried key, password, and agent)"));
}

let sftp = sess.sftp()?;

let active = Arc::new(ActiveSession { _sess: sess, sftp });

if let Ok(mut pool) = get_session_pool().lock() {
pool.insert(host_key, active.clone());
}

Ok(active)
}

pub async fn list_files(host: SshHost, password: Option<String>, path: String) -> anyhow::Result<Vec<RemoteFile>> {
tokio::task::spawn_blocking(move || {
let active = get_or_connect_sftp(&host, &password)?;
let dir = active.sftp.readdir(Path::new(&path))?;

let mut files = Vec::new();
for (path, stat) in dir {
if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
Expand All @@ -77,7 +96,6 @@ pub async fn list_files(host: SshHost, password: Option<String>, path: String) -
});
}
}

files.sort_by(|a, b| {
if a.is_dir != b.is_dir {
b.is_dir.cmp(&a.is_dir)
Expand Down Expand Up @@ -132,7 +150,6 @@ pub async fn upload_file(host: SshHost, password: Option<String>, local_path: St
let active = get_or_connect_sftp(&host, &password)?;
let mut local_file = std::fs::File::open(local_path)?;
let mut remote_file = active.sftp.create(Path::new(&remote_path))?;

let mut buffer = [0; 16384];
while let Ok(n) = local_file.read(&mut buffer) {
if n == 0 { break; }
Expand Down Expand Up @@ -160,4 +177,4 @@ pub fn download_file_sync(host: SshHost, password: Option<String>, remote_path:
local_file.write_all(&buffer[..n])?;
}
Ok(())
}
}
Loading
Loading