summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/arguments.rs17
-rw-r--r--src/client.rs112
-rw-r--r--src/config.rs30
-rw-r--r--src/main.rs133
-rw-r--r--src/server.rs132
5 files changed, 424 insertions, 0 deletions
diff --git a/src/arguments.rs b/src/arguments.rs
new file mode 100644
index 0000000..5f3d8e2
--- /dev/null
+++ b/src/arguments.rs
@@ -0,0 +1,17 @@
+use clap::{arg, Parser};
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+pub struct Args {
+ pub mode: Mode,
+
+ /// The config file to use
+ #[arg(short, long)]
+ pub config: String,
+}
+
+#[derive(clap::ValueEnum, Clone, Debug)]
+pub enum Mode {
+ SERVER,
+ CLIENT,
+}
diff --git a/src/client.rs b/src/client.rs
new file mode 100644
index 0000000..1e2edc4
--- /dev/null
+++ b/src/client.rs
@@ -0,0 +1,112 @@
+use std::io;
+use std::net::ToSocketAddrs;
+
+use tokio::io::{copy_bidirectional, AsyncReadExt, AsyncWriteExt};
+use tokio::net::{TcpSocket, TcpStream};
+use tracing::{error, info, warn};
+
+pub async fn handle_client(
+ mut sock: TcpStream,
+ upstream: &String,
+ service_name: &String,
+) -> io::Result<()> {
+ let remote = match upstream.to_socket_addrs() {
+ Ok(mut addrs) => match addrs.next() {
+ None => {
+ error!("Cannot resolve addr: {}", &upstream);
+ return Ok(());
+ }
+ Some(sa) => sa,
+ },
+ Err(why) => {
+ error!("Failed to resolve addr: {}", why);
+ return Ok(());
+ }
+ };
+ let remote_sock;
+ if remote.is_ipv4() {
+ remote_sock = TcpSocket::new_v4();
+ } else {
+ remote_sock = TcpSocket::new_v6();
+ }
+ let remote_sock = match remote_sock {
+ Ok(rs) => rs,
+ Err(why) => {
+ error!("Failed to create socket: {}", why);
+ return Ok(());
+ }
+ };
+ let mut remote_sock = match remote_sock.connect(remote).await {
+ Ok(s) => s,
+ Err(why) => {
+ error!("Failed to connect to upstream: {}", why);
+ return Ok(());
+ }
+ };
+ send_upstream(&mut remote_sock, service_name.as_bytes()).await?;
+ send_upstream(&mut remote_sock, b"\r\n").await?;
+ remote_sock.flush().await?;
+ let status = remote_sock.read_u8().await?;
+ let mut msg = vec![0_u8; 1024];
+ let mut i = 0;
+ let mut prev_is_cr = false;
+ loop {
+ let b = remote_sock.read_u8().await?;
+ if i < msg.len() {
+ msg[i] = b;
+ }
+ if b == b'\n' && prev_is_cr {
+ if i < msg.len() {
+ i -= 2; // remove CRLF from reported message string
+ }
+ break;
+ }
+ i += 1;
+ prev_is_cr = b == b'\r';
+ }
+ let msg = match std::str::from_utf8(&msg[..i]) {
+ Ok(s) => s,
+ Err(why) => {
+ warn!(
+ "Failed to decode server message as UTF-8 string, ignore: {}",
+ why
+ );
+ "???"
+ }
+ };
+ match status {
+ b'+' => {
+ info!("Upstream service selected successfully: {}", msg);
+ }
+ b'-' => {
+ error!("Upstream responded with negative status: {}", msg);
+ return Ok(());
+ }
+ b => {
+ error!("Invalid status returned from upstream, abort: {:x?}", b);
+ return Ok(());
+ }
+ }
+ match copy_bidirectional(&mut sock, &mut remote_sock).await {
+ Ok((to_right, to_left)) => {
+ info!(
+ "Proxy session finished. Bytes: client to upstream: {}, upstream to client: {}",
+ to_right, to_left
+ );
+ }
+ Err(why) => {
+ error!("Proxy connection was closed abnormally: {}", why);
+ }
+ };
+ Ok(())
+}
+
+async fn send_upstream(remote: &mut TcpStream, data: &[u8]) -> io::Result<()> {
+ match remote.write_all(data).await {
+ Ok(_) => Ok(()),
+ Err(why) => {
+ error!("Failed to send bytes to upstream: {}", why);
+ Err(why)
+ }
+ }
+}
diff --git a/src/config.rs b/src/config.rs
new file mode 100644
index 0000000..7eee270
--- /dev/null
+++ b/src/config.rs
@@ -0,0 +1,30 @@
+use serde_derive::{Deserialize, Serialize};
+
+#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
+pub struct Server {
+ pub listen: String,
+}
+
+#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
+pub struct Service {
+ pub name: String,
+ pub addr: String, // host:port
+}
+
+#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
+pub struct Client {
+ pub addr: String,
+}
+
+#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
+pub struct Root {
+ pub server: Option<Server>,
+ pub client: Option<Client>,
+ pub service: Vec<Service>,
+ #[serde(default = "default_true")]
+ pub allow_help: bool,
+}
+
+fn default_true() -> bool {
+ true
+}
diff --git a/src/main.rs b/src/main.rs
new file mode 100644
index 0000000..b588f2d
--- /dev/null
+++ b/src/main.rs
@@ -0,0 +1,133 @@
+use std::io;
+use std::path::Path;
+use std::process::exit;
+
+use clap::Parser;
+use tokio::net::TcpListener;
+use tracing::{debug, error, info};
+
+use crate::arguments::Mode;
+use crate::server::SharedConfig;
+
+mod arguments;
+mod client;
+mod config;
+mod server;
+
+fn load_config<P: AsRef<Path>>(file_path: P) -> config::Root {
+ let content = match std::fs::read_to_string(file_path) {
+ Ok(s) => s,
+ Err(why) => {
+ error!("Failed to read config file: {}", why);
+ exit(1);
+ }
+ };
+ let cfg: config::Root = match toml::from_str(content.as_str()) {
+ Ok(c) => c,
+ Err(why) => {
+ error!("Failed to parse config file: {}", why);
+ exit(1);
+ }
+ };
+ cfg
+}
+
+fn validate_config(cfg: &config::Root) {
+ for s in &cfg.service {
+ if s.name.contains("\r\n") {
+ error!("Service name contains CRLF: `{}`", &s.name);
+ exit(1);
+ }
+ }
+}
+
+#[tokio::main]
+async fn main() -> io::Result<()> {
+ tracing_subscriber::fmt::init();
+
+ let args = arguments::Args::parse();
+ info!("args: {:?}", args);
+ let cfg = load_config(&args.config);
+
+ validate_config(&cfg);
+
+ match args.mode {
+ Mode::SERVER => {
+ if cfg.server == None {
+ error!("Config `server` is required in current mode");
+ exit(1);
+ }
+ if cfg.service.is_empty() {
+ error!("No service is defined");
+ exit(1);
+ }
+ server_main(cfg).await
+ }
+ Mode::CLIENT => {
+ if cfg.client == None {
+ error!("Config `client` is required in current mode");
+ exit(1);
+ }
+ client_main(cfg).await
+ }
+ }
+}
+
+async fn server_main(cfg: config::Root) -> io::Result<()> {
+ let listener = TcpListener::bind(cfg.server.unwrap().listen).await?;
+ info!("Server listening on {:?}", listener);
+ let max_srv_name_length = cfg
+ .service
+ .iter()
+ .map(|s| s.name.len())
+ .max()
+ .expect("No service was defined in config");
+ debug!("Max service name length: {}", max_srv_name_length);
+ // this immutable object lives throughout the entire process lifetime,
+ // so we use 'static here for simplicity
+ let shared_config: &'static SharedConfig = Box::leak(Box::new(SharedConfig {
+ max_srv_name_length,
+ services: cfg.service,
+ allow_help: cfg.allow_help,
+ }));
+ loop {
+ let (sock, addr) = listener.accept().await?;
+ tokio::spawn(server::handle_client(shared_config, sock, addr));
+ }
+}
+
+async fn client_server_main(
+ upstream_addr: &'static String,
+ listener: TcpListener,
+ service_name: &'static String,
+) -> io::Result<()> {
+ loop {
+ let (sock, addr) = listener.accept().await?;
+ info!("Client connected: {}", addr);
+ tokio::spawn(client::handle_client(sock, upstream_addr, service_name));
+ }
+}
+
+async fn client_main(cfg: config::Root) -> io::Result<()> {
+ let mut fut_servers = Vec::new();
+
+ let upstream_addr = Box::leak(Box::new(cfg.client.unwrap().addr));
+
+ for s in cfg.service {
+ let listener = TcpListener::bind(&s.addr).await?;
+ info!("Server listening on {:?} (service {})", listener, &s.name);
+ let service_name = Box::leak(Box::new(s.name));
+ fut_servers.push(tokio::spawn(client_server_main(
+ upstream_addr,
+ listener,
+ service_name,
+ )));
+ }
+
+ for fut in fut_servers {
+ if let Err(why) = fut.await {
+ error!("Failed to join server future: {}", why);
+ }
+ }
+ Ok(())
+}
diff --git a/src/server.rs b/src/server.rs
new file mode 100644
index 0000000..968bfeb
--- /dev/null
+++ b/src/server.rs
@@ -0,0 +1,132 @@
+use std::net::{SocketAddr, ToSocketAddrs};
+
+use tokio::io::{copy_bidirectional, AsyncReadExt, AsyncWriteExt};
+use tokio::net::{TcpSocket, TcpStream};
+use tracing::{error, info};
+
+use crate::config;
+
+pub struct SharedConfig {
+ pub max_srv_name_length: usize,
+ pub services: Vec<config::Service>,
+ pub allow_help: bool,
+}
+
+const ERR_SERVICE_NAME_NOT_FOUND: &str = "-Invalid service name\r\n";
+const ERR_SERVICE_NAME_ENCODING: &str = "-Service name is not a UTF-8 string\r\n";
+const OK: &str = "+OK\r\n";
+const SERVICE_HELP: &str = "HELP";
+
+pub async fn handle_client(cfg: &SharedConfig, mut sock: TcpStream, addr: SocketAddr) {
+ info!("New client: {}", addr);
+ let cap = cfg.max_srv_name_length + 2;
+ let mut service_name_buf = vec![0_u8; cap];
+ let mut i = 0;
+ let mut ok = false;
+ while i < service_name_buf.len() {
+ service_name_buf[i] = match sock.read_u8().await {
+ Ok(b) => b,
+ Err(why) => {
+ error!("Read byte failed: {}", why);
+ return;
+ }
+ };
+ if i >= 1 && service_name_buf[i] == b'\n' && service_name_buf[i - 1] == b'\r' {
+ ok = true;
+ break;
+ }
+ i += 1;
+ }
+ if !ok {
+ error!(
+ "Service name too long, can't be a valid one: {:?}",
+ service_name_buf
+ );
+ // Cannot equal to any valid service name
+ let _ = sock.write_all(ERR_SERVICE_NAME_NOT_FOUND.as_bytes()).await;
+ return;
+ }
+ let wanted_srv_name: &str = match std::str::from_utf8(&service_name_buf.as_slice()[..i - 1]) {
+ Ok(s) => s,
+ Err(_) => {
+ let _ = sock.write_all(ERR_SERVICE_NAME_ENCODING.as_bytes()).await;
+ return;
+ }
+ };
+ if cfg.allow_help && wanted_srv_name == SERVICE_HELP {
+ for s in &cfg.services {
+ if let Err(_) = sock.write_all(s.name.as_bytes()).await {
+ return;
+ }
+ if let Err(_) = sock.write_all(b"\r\n").await {
+ return;
+ }
+ }
+ return;
+ }
+ let mut srv = None;
+ for s in &cfg.services {
+ if s.name == wanted_srv_name {
+ srv = Some(s);
+ break;
+ }
+ }
+ if srv == None {
+ error!("Service name not found: {}", wanted_srv_name);
+ let _ = sock.write_all(ERR_SERVICE_NAME_NOT_FOUND.as_bytes()).await;
+ return;
+ }
+ let srv = srv.unwrap();
+ info!("Client selected service: {}", srv.name);
+ if let Err(why) = sock.write_all(OK.as_bytes()).await {
+ error!("Failed to send command to client: {}", why);
+ return;
+ }
+ let remote = match srv.addr.to_socket_addrs() {
+ Ok(mut addrs) => match addrs.next() {
+ None => {
+ error!("Cannot resolve addr: {}", srv.addr);
+ return;
+ }
+ Some(sa) => sa,
+ },
+ Err(why) => {
+ error!("Failed to resolve addr: {}", why);
+ return;
+ }
+ };
+ let remote_sock;
+ if remote.is_ipv4() {
+ remote_sock = TcpSocket::new_v4();
+ } else {
+ remote_sock = TcpSocket::new_v6();
+ }
+ let remote_sock = match remote_sock {
+ Ok(rs) => rs,
+ Err(why) => {
+ error!("Failed to create socket: {}", why);
+ return;
+ }
+ };
+ let mut remote_sock = match remote_sock.connect(remote).await {
+ Ok(s) => s,
+ Err(why) => {
+ error!(
+ "Failed to connect to upstream of service {}: {}",
+ &srv.name, why
+ );
+ return;
+ }
+ };
+ match copy_bidirectional(&mut sock, &mut remote_sock).await {
+ Ok((to_right, to_left)) => {
+ info!(
+ "Proxy session finished. Bytes: client to service: {}, service to client: {}",
+ to_right, to_left
+ );
+ }
+ Err(why) => {
+ error!("Proxy connection was closed abnormally: {}", why);
+ }
+ };
+}