use std::io; use std::path::Path; use std::process::exit; use clap::Parser; use tokio::net::TcpListener; use tracing::{debug, error, info}; use crate::arguments::Mode; use crate::server::SharedConfig; mod arguments; mod client; mod config; mod protocol; mod server; fn load_config>(file_path: P) -> config::Root { let content = match std::fs::read_to_string(file_path) { Ok(s) => s, Err(why) => { error!("Failed to read config file: {}", why); exit(1); } }; let cfg: config::Root = match toml::from_str(content.as_str()) { Ok(c) => c, Err(why) => { error!("Failed to parse config file: {}", why); exit(1); } }; cfg } fn validate_config(cfg: &config::Root) { for s in &cfg.service { if s.name.contains("\r\n") { error!("Service name contains CRLF: `{}`", &s.name); exit(1); } } } #[tokio::main] async fn main() -> io::Result<()> { tracing_subscriber::fmt::init(); let args = arguments::Args::parse(); info!("args: {:?}", args); let cfg = load_config(&args.config); validate_config(&cfg); match args.mode { Mode::SERVER => { if cfg.server == None { error!("Config `server` is required in current mode"); exit(1); } if cfg.service.is_empty() { error!("No service is defined"); exit(1); } server_main(cfg).await } Mode::CLIENT => { if cfg.client == None { error!("Config `client` is required in current mode"); exit(1); } client_main(cfg).await } } } async fn server_main(cfg: config::Root) -> io::Result<()> { let server_config = cfg.server.unwrap(); let listener = TcpListener::bind(server_config.listen).await?; info!("Server listening on {:?}", listener); let max_srv_name_length = cfg .service .iter() .map(|s| s.name.len()) .max() .expect("No service was defined in config"); debug!("Max service name length: {}", max_srv_name_length); // this immutable object lives throughout the entire process lifetime, // so we use 'static here for simplicity let shared_config: &'static SharedConfig = Box::leak(Box::new(SharedConfig { max_srv_name_length, services: cfg.service, allow_help: cfg.allow_help, no_ack_extension: server_config.no_ack_extension, })); loop { let (sock, addr) = listener.accept().await?; tokio::spawn(server::handle_client(shared_config, sock, addr)); } } async fn client_server_main( cfg: &'static client::SharedConfig, upstream_addr: &'static String, listener: TcpListener, service_name: &'static String, ) -> io::Result<()> { loop { let (sock, addr) = listener.accept().await?; info!("Client connected: {}", addr); tokio::spawn(client::handle_client( cfg, sock, upstream_addr, service_name, )); } } async fn client_main(cfg: config::Root) -> io::Result<()> { let mut fut_servers = Vec::new(); let client_cfg = cfg.client.unwrap(); let upstream_addr = Box::leak(Box::new(client_cfg.addr)); let shared_config = Box::leak(Box::new(client::SharedConfig { no_ack: client_cfg.no_ack, })); for s in cfg.service { let listener = TcpListener::bind(&s.addr).await?; info!("Server listening on {:?} (service {})", listener, &s.name); let service_name = Box::leak(Box::new(s.name)); fut_servers.push(tokio::spawn(client_server_main( shared_config, upstream_addr, listener, service_name, ))); } for fut in fut_servers { if let Err(why) = fut.await { error!("Failed to join server future: {}", why); } } Ok(()) }