From 143014a91e695106d8383ed173c482b3b4519663 Mon Sep 17 00:00:00 2001 From: Keuin Date: Mon, 4 Sep 2023 01:57:22 +0800 Subject: initial version --- src/server.rs | 132 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 132 insertions(+) create mode 100644 src/server.rs (limited to 'src/server.rs') diff --git a/src/server.rs b/src/server.rs new file mode 100644 index 0000000..968bfeb --- /dev/null +++ b/src/server.rs @@ -0,0 +1,132 @@ +use std::net::{SocketAddr, ToSocketAddrs}; + +use tokio::io::{copy_bidirectional, AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpSocket, TcpStream}; +use tracing::{error, info}; + +use crate::config; + +pub struct SharedConfig { + pub max_srv_name_length: usize, + pub services: Vec, + pub allow_help: bool, +} + +const ERR_SERVICE_NAME_NOT_FOUND: &str = "-Invalid service name\r\n"; +const ERR_SERVICE_NAME_ENCODING: &str = "-Service name is not a UTF-8 string\r\n"; +const OK: &str = "+OK\r\n"; +const SERVICE_HELP: &str = "HELP"; + +pub async fn handle_client(cfg: &SharedConfig, mut sock: TcpStream, addr: SocketAddr) { + info!("New client: {}", addr); + let cap = cfg.max_srv_name_length + 2; + let mut service_name_buf = vec![0_u8; cap]; + let mut i = 0; + let mut ok = false; + while i < service_name_buf.len() { + service_name_buf[i] = match sock.read_u8().await { + Ok(b) => b, + Err(why) => { + error!("Read byte failed: {}", why); + return; + } + }; + if i >= 1 && service_name_buf[i] == b'\n' && service_name_buf[i - 1] == b'\r' { + ok = true; + break; + } + i += 1; + } + if !ok { + error!( + "Service name too long, can't be a valid one: {:?}", + service_name_buf + ); + // Cannot equal to any valid service name + let _ = sock.write_all(ERR_SERVICE_NAME_NOT_FOUND.as_bytes()).await; + return; + } + let wanted_srv_name: &str = match std::str::from_utf8(&service_name_buf.as_slice()[..i - 1]) { + Ok(s) => s, + Err(_) => { + let _ = sock.write_all(ERR_SERVICE_NAME_ENCODING.as_bytes()).await; + return; + } + }; + if cfg.allow_help && wanted_srv_name == SERVICE_HELP { + for s in &cfg.services { + if let Err(_) = sock.write_all(s.name.as_bytes()).await { + return; + } + if let Err(_) = sock.write_all(b"\r\n").await { + return; + } + } + return; + } + let mut srv = None; + for s in &cfg.services { + if s.name == wanted_srv_name { + srv = Some(s); + break; + } + } + if srv == None { + error!("Service name not found: {}", wanted_srv_name); + let _ = sock.write_all(ERR_SERVICE_NAME_NOT_FOUND.as_bytes()).await; + return; + } + let srv = srv.unwrap(); + info!("Client selected service: {}", srv.name); + if let Err(why) = sock.write_all(OK.as_bytes()).await { + error!("Failed to send command to client: {}", why); + return; + } + let remote = match srv.addr.to_socket_addrs() { + Ok(mut addrs) => match addrs.next() { + None => { + error!("Cannot resolve addr: {}", srv.addr); + return; + } + Some(sa) => sa, + }, + Err(why) => { + error!("Failed to resolve addr: {}", why); + return; + } + }; + let remote_sock; + if remote.is_ipv4() { + remote_sock = TcpSocket::new_v4(); + } else { + remote_sock = TcpSocket::new_v6(); + } + let remote_sock = match remote_sock { + Ok(rs) => rs, + Err(why) => { + error!("Failed to create socket: {}", why); + return; + } + }; + let mut remote_sock = match remote_sock.connect(remote).await { + Ok(s) => s, + Err(why) => { + error!( + "Failed to connect to upstream of service {}: {}", + &srv.name, why + ); + return; + } + }; + match copy_bidirectional(&mut sock, &mut remote_sock).await { + Ok((to_right, to_left)) => { + info!( + "Proxy session finished. Bytes: client to service: {}, service to client: {}", + to_right, to_left + ); + } + Err(why) => { + error!("Proxy connection was closed abnormally: {}", why); + } + }; +} -- cgit v1.2.3