diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/arguments.rs | 17 | ||||
-rw-r--r-- | src/client.rs | 112 | ||||
-rw-r--r-- | src/config.rs | 30 | ||||
-rw-r--r-- | src/main.rs | 133 | ||||
-rw-r--r-- | src/server.rs | 132 |
5 files changed, 424 insertions, 0 deletions
diff --git a/src/arguments.rs b/src/arguments.rs new file mode 100644 index 0000000..5f3d8e2 --- /dev/null +++ b/src/arguments.rs @@ -0,0 +1,17 @@ +use clap::{arg, Parser}; + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +pub struct Args { + pub mode: Mode, + + /// The config file to use + #[arg(short, long)] + pub config: String, +} + +#[derive(clap::ValueEnum, Clone, Debug)] +pub enum Mode { + SERVER, + CLIENT, +} diff --git a/src/client.rs b/src/client.rs new file mode 100644 index 0000000..1e2edc4 --- /dev/null +++ b/src/client.rs @@ -0,0 +1,112 @@ +use std::io; +use std::net::ToSocketAddrs; + +use tokio::io::{copy_bidirectional, AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpSocket, TcpStream}; +use tracing::{error, info, warn}; + +pub async fn handle_client( + mut sock: TcpStream, + upstream: &String, + service_name: &String, +) -> io::Result<()> { + let remote = match upstream.to_socket_addrs() { + Ok(mut addrs) => match addrs.next() { + None => { + error!("Cannot resolve addr: {}", &upstream); + return Ok(()); + } + Some(sa) => sa, + }, + Err(why) => { + error!("Failed to resolve addr: {}", why); + return Ok(()); + } + }; + let remote_sock; + if remote.is_ipv4() { + remote_sock = TcpSocket::new_v4(); + } else { + remote_sock = TcpSocket::new_v6(); + } + let remote_sock = match remote_sock { + Ok(rs) => rs, + Err(why) => { + error!("Failed to create socket: {}", why); + return Ok(()); + } + }; + let mut remote_sock = match remote_sock.connect(remote).await { + Ok(s) => s, + Err(why) => { + error!("Failed to connect to upstream: {}", why); + return Ok(()); + } + }; + send_upstream(&mut remote_sock, service_name.as_bytes()).await?; + send_upstream(&mut remote_sock, b"\r\n").await?; + remote_sock.flush().await?; + let status = remote_sock.read_u8().await?; + let mut msg = vec![0_u8; 1024]; + let mut i = 0; + let mut prev_is_cr = false; + loop { + let b = remote_sock.read_u8().await?; + if i < msg.len() { + msg[i] = b; + } + if b == b'\n' && prev_is_cr { + if i < msg.len() { + i -= 2; // remove CRLF from reported message string + } + break; + } + i += 1; + prev_is_cr = b == b'\r'; + } + let msg = match std::str::from_utf8(&msg[..i]) { + Ok(s) => s, + Err(why) => { + warn!( + "Failed to decode server message as UTF-8 string, ignore: {}", + why + ); + "???" + } + }; + match status { + b'+' => { + info!("Upstream service selected successfully: {}", msg); + } + b'-' => { + error!("Upstream responded with negative status: {}", msg); + return Ok(()); + } + b => { + error!("Invalid status returned from upstream, abort: {:x?}", b); + return Ok(()); + } + } + match copy_bidirectional(&mut sock, &mut remote_sock).await { + Ok((to_right, to_left)) => { + info!( + "Proxy session finished. Bytes: client to upstream: {}, upstream to client: {}", + to_right, to_left + ); + } + Err(why) => { + error!("Proxy connection was closed abnormally: {}", why); + } + }; + Ok(()) +} + +async fn send_upstream(remote: &mut TcpStream, data: &[u8]) -> io::Result<()> { + match remote.write_all(data).await { + Ok(_) => Ok(()), + Err(why) => { + error!("Failed to send bytes to upstream: {}", why); + Err(why) + } + } +} diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..7eee270 --- /dev/null +++ b/src/config.rs @@ -0,0 +1,30 @@ +use serde_derive::{Deserialize, Serialize}; + +#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct Server { + pub listen: String, +} + +#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct Service { + pub name: String, + pub addr: String, // host:port +} + +#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct Client { + pub addr: String, +} + +#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct Root { + pub server: Option<Server>, + pub client: Option<Client>, + pub service: Vec<Service>, + #[serde(default = "default_true")] + pub allow_help: bool, +} + +fn default_true() -> bool { + true +} diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..b588f2d --- /dev/null +++ b/src/main.rs @@ -0,0 +1,133 @@ +use std::io; +use std::path::Path; +use std::process::exit; + +use clap::Parser; +use tokio::net::TcpListener; +use tracing::{debug, error, info}; + +use crate::arguments::Mode; +use crate::server::SharedConfig; + +mod arguments; +mod client; +mod config; +mod server; + +fn load_config<P: AsRef<Path>>(file_path: P) -> config::Root { + let content = match std::fs::read_to_string(file_path) { + Ok(s) => s, + Err(why) => { + error!("Failed to read config file: {}", why); + exit(1); + } + }; + let cfg: config::Root = match toml::from_str(content.as_str()) { + Ok(c) => c, + Err(why) => { + error!("Failed to parse config file: {}", why); + exit(1); + } + }; + cfg +} + +fn validate_config(cfg: &config::Root) { + for s in &cfg.service { + if s.name.contains("\r\n") { + error!("Service name contains CRLF: `{}`", &s.name); + exit(1); + } + } +} + +#[tokio::main] +async fn main() -> io::Result<()> { + tracing_subscriber::fmt::init(); + + let args = arguments::Args::parse(); + info!("args: {:?}", args); + let cfg = load_config(&args.config); + + validate_config(&cfg); + + match args.mode { + Mode::SERVER => { + if cfg.server == None { + error!("Config `server` is required in current mode"); + exit(1); + } + if cfg.service.is_empty() { + error!("No service is defined"); + exit(1); + } + server_main(cfg).await + } + Mode::CLIENT => { + if cfg.client == None { + error!("Config `client` is required in current mode"); + exit(1); + } + client_main(cfg).await + } + } +} + +async fn server_main(cfg: config::Root) -> io::Result<()> { + let listener = TcpListener::bind(cfg.server.unwrap().listen).await?; + info!("Server listening on {:?}", listener); + let max_srv_name_length = cfg + .service + .iter() + .map(|s| s.name.len()) + .max() + .expect("No service was defined in config"); + debug!("Max service name length: {}", max_srv_name_length); + // this immutable object lives throughout the entire process lifetime, + // so we use 'static here for simplicity + let shared_config: &'static SharedConfig = Box::leak(Box::new(SharedConfig { + max_srv_name_length, + services: cfg.service, + allow_help: cfg.allow_help, + })); + loop { + let (sock, addr) = listener.accept().await?; + tokio::spawn(server::handle_client(shared_config, sock, addr)); + } +} + +async fn client_server_main( + upstream_addr: &'static String, + listener: TcpListener, + service_name: &'static String, +) -> io::Result<()> { + loop { + let (sock, addr) = listener.accept().await?; + info!("Client connected: {}", addr); + tokio::spawn(client::handle_client(sock, upstream_addr, service_name)); + } +} + +async fn client_main(cfg: config::Root) -> io::Result<()> { + let mut fut_servers = Vec::new(); + + let upstream_addr = Box::leak(Box::new(cfg.client.unwrap().addr)); + + for s in cfg.service { + let listener = TcpListener::bind(&s.addr).await?; + info!("Server listening on {:?} (service {})", listener, &s.name); + let service_name = Box::leak(Box::new(s.name)); + fut_servers.push(tokio::spawn(client_server_main( + upstream_addr, + listener, + service_name, + ))); + } + + for fut in fut_servers { + if let Err(why) = fut.await { + error!("Failed to join server future: {}", why); + } + } + Ok(()) +} diff --git a/src/server.rs b/src/server.rs new file mode 100644 index 0000000..968bfeb --- /dev/null +++ b/src/server.rs @@ -0,0 +1,132 @@ +use std::net::{SocketAddr, ToSocketAddrs}; + +use tokio::io::{copy_bidirectional, AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpSocket, TcpStream}; +use tracing::{error, info}; + +use crate::config; + +pub struct SharedConfig { + pub max_srv_name_length: usize, + pub services: Vec<config::Service>, + pub allow_help: bool, +} + +const ERR_SERVICE_NAME_NOT_FOUND: &str = "-Invalid service name\r\n"; +const ERR_SERVICE_NAME_ENCODING: &str = "-Service name is not a UTF-8 string\r\n"; +const OK: &str = "+OK\r\n"; +const SERVICE_HELP: &str = "HELP"; + +pub async fn handle_client(cfg: &SharedConfig, mut sock: TcpStream, addr: SocketAddr) { + info!("New client: {}", addr); + let cap = cfg.max_srv_name_length + 2; + let mut service_name_buf = vec![0_u8; cap]; + let mut i = 0; + let mut ok = false; + while i < service_name_buf.len() { + service_name_buf[i] = match sock.read_u8().await { + Ok(b) => b, + Err(why) => { + error!("Read byte failed: {}", why); + return; + } + }; + if i >= 1 && service_name_buf[i] == b'\n' && service_name_buf[i - 1] == b'\r' { + ok = true; + break; + } + i += 1; + } + if !ok { + error!( + "Service name too long, can't be a valid one: {:?}", + service_name_buf + ); + // Cannot equal to any valid service name + let _ = sock.write_all(ERR_SERVICE_NAME_NOT_FOUND.as_bytes()).await; + return; + } + let wanted_srv_name: &str = match std::str::from_utf8(&service_name_buf.as_slice()[..i - 1]) { + Ok(s) => s, + Err(_) => { + let _ = sock.write_all(ERR_SERVICE_NAME_ENCODING.as_bytes()).await; + return; + } + }; + if cfg.allow_help && wanted_srv_name == SERVICE_HELP { + for s in &cfg.services { + if let Err(_) = sock.write_all(s.name.as_bytes()).await { + return; + } + if let Err(_) = sock.write_all(b"\r\n").await { + return; + } + } + return; + } + let mut srv = None; + for s in &cfg.services { + if s.name == wanted_srv_name { + srv = Some(s); + break; + } + } + if srv == None { + error!("Service name not found: {}", wanted_srv_name); + let _ = sock.write_all(ERR_SERVICE_NAME_NOT_FOUND.as_bytes()).await; + return; + } + let srv = srv.unwrap(); + info!("Client selected service: {}", srv.name); + if let Err(why) = sock.write_all(OK.as_bytes()).await { + error!("Failed to send command to client: {}", why); + return; + } + let remote = match srv.addr.to_socket_addrs() { + Ok(mut addrs) => match addrs.next() { + None => { + error!("Cannot resolve addr: {}", srv.addr); + return; + } + Some(sa) => sa, + }, + Err(why) => { + error!("Failed to resolve addr: {}", why); + return; + } + }; + let remote_sock; + if remote.is_ipv4() { + remote_sock = TcpSocket::new_v4(); + } else { + remote_sock = TcpSocket::new_v6(); + } + let remote_sock = match remote_sock { + Ok(rs) => rs, + Err(why) => { + error!("Failed to create socket: {}", why); + return; + } + }; + let mut remote_sock = match remote_sock.connect(remote).await { + Ok(s) => s, + Err(why) => { + error!( + "Failed to connect to upstream of service {}: {}", + &srv.name, why + ); + return; + } + }; + match copy_bidirectional(&mut sock, &mut remote_sock).await { + Ok((to_right, to_left)) => { + info!( + "Proxy session finished. Bytes: client to service: {}, service to client: {}", + to_right, to_left + ); + } + Err(why) => { + error!("Proxy connection was closed abnormally: {}", why); + } + }; +} |