diff options
Diffstat (limited to 'src/main.rs')
-rw-r--r-- | src/main.rs | 133 |
1 files changed, 133 insertions, 0 deletions
diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..b588f2d --- /dev/null +++ b/src/main.rs @@ -0,0 +1,133 @@ +use std::io; +use std::path::Path; +use std::process::exit; + +use clap::Parser; +use tokio::net::TcpListener; +use tracing::{debug, error, info}; + +use crate::arguments::Mode; +use crate::server::SharedConfig; + +mod arguments; +mod client; +mod config; +mod server; + +fn load_config<P: AsRef<Path>>(file_path: P) -> config::Root { + let content = match std::fs::read_to_string(file_path) { + Ok(s) => s, + Err(why) => { + error!("Failed to read config file: {}", why); + exit(1); + } + }; + let cfg: config::Root = match toml::from_str(content.as_str()) { + Ok(c) => c, + Err(why) => { + error!("Failed to parse config file: {}", why); + exit(1); + } + }; + cfg +} + +fn validate_config(cfg: &config::Root) { + for s in &cfg.service { + if s.name.contains("\r\n") { + error!("Service name contains CRLF: `{}`", &s.name); + exit(1); + } + } +} + +#[tokio::main] +async fn main() -> io::Result<()> { + tracing_subscriber::fmt::init(); + + let args = arguments::Args::parse(); + info!("args: {:?}", args); + let cfg = load_config(&args.config); + + validate_config(&cfg); + + match args.mode { + Mode::SERVER => { + if cfg.server == None { + error!("Config `server` is required in current mode"); + exit(1); + } + if cfg.service.is_empty() { + error!("No service is defined"); + exit(1); + } + server_main(cfg).await + } + Mode::CLIENT => { + if cfg.client == None { + error!("Config `client` is required in current mode"); + exit(1); + } + client_main(cfg).await + } + } +} + +async fn server_main(cfg: config::Root) -> io::Result<()> { + let listener = TcpListener::bind(cfg.server.unwrap().listen).await?; + info!("Server listening on {:?}", listener); + let max_srv_name_length = cfg + .service + .iter() + .map(|s| s.name.len()) + .max() + .expect("No service was defined in config"); + debug!("Max service name length: {}", max_srv_name_length); + // this immutable object lives throughout the entire process lifetime, + // so we use 'static here for simplicity + let shared_config: &'static SharedConfig = Box::leak(Box::new(SharedConfig { + max_srv_name_length, + services: cfg.service, + allow_help: cfg.allow_help, + })); + loop { + let (sock, addr) = listener.accept().await?; + tokio::spawn(server::handle_client(shared_config, sock, addr)); + } +} + +async fn client_server_main( + upstream_addr: &'static String, + listener: TcpListener, + service_name: &'static String, +) -> io::Result<()> { + loop { + let (sock, addr) = listener.accept().await?; + info!("Client connected: {}", addr); + tokio::spawn(client::handle_client(sock, upstream_addr, service_name)); + } +} + +async fn client_main(cfg: config::Root) -> io::Result<()> { + let mut fut_servers = Vec::new(); + + let upstream_addr = Box::leak(Box::new(cfg.client.unwrap().addr)); + + for s in cfg.service { + let listener = TcpListener::bind(&s.addr).await?; + info!("Server listening on {:?} (service {})", listener, &s.name); + let service_name = Box::leak(Box::new(s.name)); + fut_servers.push(tokio::spawn(client_server_main( + upstream_addr, + listener, + service_name, + ))); + } + + for fut in fut_servers { + if let Err(why) = fut.await { + error!("Failed to join server future: {}", why); + } + } + Ok(()) +} |