use std::net::{SocketAddr, ToSocketAddrs}; use tokio::io::{copy_bidirectional, AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpSocket, TcpStream}; use tracing::{error, info}; use crate::config; use crate::protocol::{ServiceName, PREFIX}; pub struct SharedConfig { pub max_srv_name_length: usize, pub services: Vec, pub allow_help: bool, pub no_ack_extension: bool, } const ERR_SERVICE_NAME_NOT_FOUND: &str = "-Invalid service name\r\n"; const ERR_SERVICE_NAME_ENCODING: &str = "-Service name is not a UTF-8 string\r\n"; const OK: &str = "+OK\r\n"; const SERVICE_HELP: &str = "HELP"; pub async fn handle_client(cfg: &SharedConfig, mut sock: TcpStream, addr: SocketAddr) { info!("New client: {}", addr); let cap = cfg.max_srv_name_length + 2 + if cfg.no_ack_extension { PREFIX.len() } else { 0 }; let mut service_name_buf = vec![0_u8; cap]; let mut i = 0; let mut ok = false; while i < service_name_buf.len() { service_name_buf[i] = match sock.read_u8().await { Ok(b) => b, Err(why) => { error!("Read byte failed: {}", why); return; } }; if i >= 1 && service_name_buf[i] == b'\n' && service_name_buf[i - 1] == b'\r' { ok = true; break; } i += 1; } if !ok { error!( "Service name too long, can't be a valid one: {:?}", service_name_buf ); // Cannot equal to any valid service name let _ = sock.write_all(ERR_SERVICE_NAME_NOT_FOUND.as_bytes()).await; return; } let wanted_srv_name: &str = match std::str::from_utf8(&service_name_buf.as_slice()[..i - 1]) { Ok(s) => s, Err(_) => { let _ = sock.write_all(ERR_SERVICE_NAME_ENCODING.as_bytes()).await; return; } }; if cfg.allow_help && wanted_srv_name == SERVICE_HELP { for s in &cfg.services { if let Err(_) = sock.write_all(s.name.as_bytes()).await { return; } if let Err(_) = sock.write_all(b"\r\n").await { return; } } return; } let wanted_srv_name = if cfg.no_ack_extension { ServiceName::from(wanted_srv_name) } else { ServiceName { service_name: wanted_srv_name, no_ack: false, } }; let mut srv = None; for s in &cfg.services { if s.name == wanted_srv_name.service_name { srv = Some(s); break; } } if srv == None { error!("Service name not found: {}", wanted_srv_name.to_string()); let _ = sock.write_all(ERR_SERVICE_NAME_NOT_FOUND.as_bytes()).await; return; } let srv = srv.unwrap(); info!("Client selected service: {}", srv.name); if !wanted_srv_name.no_ack { if let Err(why) = sock.write_all(OK.as_bytes()).await { error!("Failed to send command to client: {}", why); return; } } let remote = match srv.addr.to_socket_addrs() { Ok(mut addrs) => match addrs.next() { None => { error!("Cannot resolve addr: {}", srv.addr); return; } Some(sa) => sa, }, Err(why) => { error!("Failed to resolve addr: {}", why); return; } }; let remote_sock; if remote.is_ipv4() { remote_sock = TcpSocket::new_v4(); } else { remote_sock = TcpSocket::new_v6(); } let remote_sock = match remote_sock { Ok(rs) => rs, Err(why) => { error!("Failed to create socket: {}", why); return; } }; let mut remote_sock = match remote_sock.connect(remote).await { Ok(s) => s, Err(why) => { error!( "Failed to connect to upstream of service {}: {}", &srv.name, why ); return; } }; match copy_bidirectional(&mut sock, &mut remote_sock).await { Ok((to_right, to_left)) => { info!( "Proxy session finished. Bytes: client to service: {}, service to client: {}", to_right, to_left ); } Err(why) => { error!("Proxy connection was closed abnormally: {}", why); } }; }