summaryrefslogtreecommitdiff
path: root/src/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/main.rs')
-rw-r--r--src/main.rs133
1 files changed, 133 insertions, 0 deletions
diff --git a/src/main.rs b/src/main.rs
new file mode 100644
index 0000000..b588f2d
--- /dev/null
+++ b/src/main.rs
@@ -0,0 +1,133 @@
+use std::io;
+use std::path::Path;
+use std::process::exit;
+
+use clap::Parser;
+use tokio::net::TcpListener;
+use tracing::{debug, error, info};
+
+use crate::arguments::Mode;
+use crate::server::SharedConfig;
+
+mod arguments;
+mod client;
+mod config;
+mod server;
+
+fn load_config<P: AsRef<Path>>(file_path: P) -> config::Root {
+ let content = match std::fs::read_to_string(file_path) {
+ Ok(s) => s,
+ Err(why) => {
+ error!("Failed to read config file: {}", why);
+ exit(1);
+ }
+ };
+ let cfg: config::Root = match toml::from_str(content.as_str()) {
+ Ok(c) => c,
+ Err(why) => {
+ error!("Failed to parse config file: {}", why);
+ exit(1);
+ }
+ };
+ cfg
+}
+
+fn validate_config(cfg: &config::Root) {
+ for s in &cfg.service {
+ if s.name.contains("\r\n") {
+ error!("Service name contains CRLF: `{}`", &s.name);
+ exit(1);
+ }
+ }
+}
+
+#[tokio::main]
+async fn main() -> io::Result<()> {
+ tracing_subscriber::fmt::init();
+
+ let args = arguments::Args::parse();
+ info!("args: {:?}", args);
+ let cfg = load_config(&args.config);
+
+ validate_config(&cfg);
+
+ match args.mode {
+ Mode::SERVER => {
+ if cfg.server == None {
+ error!("Config `server` is required in current mode");
+ exit(1);
+ }
+ if cfg.service.is_empty() {
+ error!("No service is defined");
+ exit(1);
+ }
+ server_main(cfg).await
+ }
+ Mode::CLIENT => {
+ if cfg.client == None {
+ error!("Config `client` is required in current mode");
+ exit(1);
+ }
+ client_main(cfg).await
+ }
+ }
+}
+
+async fn server_main(cfg: config::Root) -> io::Result<()> {
+ let listener = TcpListener::bind(cfg.server.unwrap().listen).await?;
+ info!("Server listening on {:?}", listener);
+ let max_srv_name_length = cfg
+ .service
+ .iter()
+ .map(|s| s.name.len())
+ .max()
+ .expect("No service was defined in config");
+ debug!("Max service name length: {}", max_srv_name_length);
+ // this immutable object lives throughout the entire process lifetime,
+ // so we use 'static here for simplicity
+ let shared_config: &'static SharedConfig = Box::leak(Box::new(SharedConfig {
+ max_srv_name_length,
+ services: cfg.service,
+ allow_help: cfg.allow_help,
+ }));
+ loop {
+ let (sock, addr) = listener.accept().await?;
+ tokio::spawn(server::handle_client(shared_config, sock, addr));
+ }
+}
+
+async fn client_server_main(
+ upstream_addr: &'static String,
+ listener: TcpListener,
+ service_name: &'static String,
+) -> io::Result<()> {
+ loop {
+ let (sock, addr) = listener.accept().await?;
+ info!("Client connected: {}", addr);
+ tokio::spawn(client::handle_client(sock, upstream_addr, service_name));
+ }
+}
+
+async fn client_main(cfg: config::Root) -> io::Result<()> {
+ let mut fut_servers = Vec::new();
+
+ let upstream_addr = Box::leak(Box::new(cfg.client.unwrap().addr));
+
+ for s in cfg.service {
+ let listener = TcpListener::bind(&s.addr).await?;
+ info!("Server listening on {:?} (service {})", listener, &s.name);
+ let service_name = Box::leak(Box::new(s.name));
+ fut_servers.push(tokio::spawn(client_server_main(
+ upstream_addr,
+ listener,
+ service_name,
+ )));
+ }
+
+ for fut in fut_servers {
+ if let Err(why) = fut.await {
+ error!("Failed to join server future: {}", why);
+ }
+ }
+ Ok(())
+}