use std::io; use std::net::ToSocketAddrs; use tokio::io::{copy_bidirectional, AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpSocket, TcpStream}; use tracing::{error, info, warn}; use crate::client::Error::{IOError, InvalidProtocol, InvalidService}; use crate::protocol::ServiceName; pub struct SharedConfig { pub no_ack: bool, } pub enum Error { IOError(io::Error), InvalidService, InvalidProtocol, } impl From for Error { fn from(value: io::Error) -> Self { IOError(value) } } pub async fn handle_client( cfg: &SharedConfig, mut sock: TcpStream, upstream: &String, service_name: &String, ) -> Result<(), Error> { let remote = match upstream.to_socket_addrs() { Ok(mut addrs) => match addrs.next() { None => { error!("Cannot resolve addr: {}", &upstream); return Ok(()); } Some(sa) => sa, }, Err(why) => { error!("Failed to resolve addr: {}", why); return Ok(()); } }; let remote_sock; if remote.is_ipv4() { remote_sock = TcpSocket::new_v4(); } else { remote_sock = TcpSocket::new_v6(); } let remote_sock = match remote_sock { Ok(rs) => rs, Err(why) => { error!("Failed to create socket: {}", why); return Ok(()); } }; let mut remote_sock = match remote_sock.connect(remote).await { Ok(s) => s, Err(why) => { error!("Failed to connect to upstream: {}", why); return Ok(()); } }; let service_name = ServiceName { service_name, no_ack: cfg.no_ack, }; send_upstream(&mut remote_sock, service_name.to_string().as_bytes()).await?; send_upstream(&mut remote_sock, b"\r\n").await?; remote_sock.flush().await?; if !cfg.no_ack { read_status(&mut remote_sock).await?; } match copy_bidirectional(&mut sock, &mut remote_sock).await { Ok((to_right, to_left)) => { info!( "Proxy session finished. Bytes: client to upstream: {}, upstream to client: {}", to_right, to_left ); } Err(why) => { error!("Proxy connection was closed abnormally: {}", why); } }; Ok(()) } async fn read_status(remote_sock: &mut TcpStream) -> Result<(), Error> { let status = remote_sock.read_u8().await?; let mut msg = vec![0_u8; 1024]; let mut i = 0; let mut prev_is_cr = false; loop { let b = remote_sock.read_u8().await?; if i < msg.len() { msg[i] = b; } if b == b'\n' && prev_is_cr { if i < msg.len() { i -= 1; // remove CRLF from reported message string } break; } i += 1; prev_is_cr = b == b'\r'; } let msg = if i < msg.len() { &msg[..i] } else { &msg }; let msg = match std::str::from_utf8(msg) { Ok(s) => s, Err(why) => { warn!( "Failed to decode server message as UTF-8 string, ignore: {}", why ); "???" } }; match status { b'+' => { info!("Upstream service selected successfully: {}", msg); Ok(()) } b'-' => { error!("Upstream responded with negative status: {}", msg); Err(InvalidService) } b => { error!("Invalid status returned from upstream, abort: {:x?}", b); Err(InvalidProtocol) } } } async fn send_upstream(remote: &mut TcpStream, data: &[u8]) -> io::Result<()> { match remote.write_all(data).await { Ok(_) => Ok(()), Err(why) => { error!("Failed to send bytes to upstream: {}", why); Err(why) } } }