summaryrefslogtreecommitdiff
path: root/src/server.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/server.rs')
-rw-r--r--src/server.rs132
1 files changed, 132 insertions, 0 deletions
diff --git a/src/server.rs b/src/server.rs
new file mode 100644
index 0000000..968bfeb
--- /dev/null
+++ b/src/server.rs
@@ -0,0 +1,132 @@
+use std::net::{SocketAddr, ToSocketAddrs};
+
+use tokio::io::{copy_bidirectional, AsyncReadExt, AsyncWriteExt};
+use tokio::net::{TcpSocket, TcpStream};
+use tracing::{error, info};
+
+use crate::config;
+
+pub struct SharedConfig {
+ pub max_srv_name_length: usize,
+ pub services: Vec<config::Service>,
+ pub allow_help: bool,
+}
+
+const ERR_SERVICE_NAME_NOT_FOUND: &str = "-Invalid service name\r\n";
+const ERR_SERVICE_NAME_ENCODING: &str = "-Service name is not a UTF-8 string\r\n";
+const OK: &str = "+OK\r\n";
+const SERVICE_HELP: &str = "HELP";
+
+pub async fn handle_client(cfg: &SharedConfig, mut sock: TcpStream, addr: SocketAddr) {
+ info!("New client: {}", addr);
+ let cap = cfg.max_srv_name_length + 2;
+ let mut service_name_buf = vec![0_u8; cap];
+ let mut i = 0;
+ let mut ok = false;
+ while i < service_name_buf.len() {
+ service_name_buf[i] = match sock.read_u8().await {
+ Ok(b) => b,
+ Err(why) => {
+ error!("Read byte failed: {}", why);
+ return;
+ }
+ };
+ if i >= 1 && service_name_buf[i] == b'\n' && service_name_buf[i - 1] == b'\r' {
+ ok = true;
+ break;
+ }
+ i += 1;
+ }
+ if !ok {
+ error!(
+ "Service name too long, can't be a valid one: {:?}",
+ service_name_buf
+ );
+ // Cannot equal to any valid service name
+ let _ = sock.write_all(ERR_SERVICE_NAME_NOT_FOUND.as_bytes()).await;
+ return;
+ }
+ let wanted_srv_name: &str = match std::str::from_utf8(&service_name_buf.as_slice()[..i - 1]) {
+ Ok(s) => s,
+ Err(_) => {
+ let _ = sock.write_all(ERR_SERVICE_NAME_ENCODING.as_bytes()).await;
+ return;
+ }
+ };
+ if cfg.allow_help && wanted_srv_name == SERVICE_HELP {
+ for s in &cfg.services {
+ if let Err(_) = sock.write_all(s.name.as_bytes()).await {
+ return;
+ }
+ if let Err(_) = sock.write_all(b"\r\n").await {
+ return;
+ }
+ }
+ return;
+ }
+ let mut srv = None;
+ for s in &cfg.services {
+ if s.name == wanted_srv_name {
+ srv = Some(s);
+ break;
+ }
+ }
+ if srv == None {
+ error!("Service name not found: {}", wanted_srv_name);
+ let _ = sock.write_all(ERR_SERVICE_NAME_NOT_FOUND.as_bytes()).await;
+ return;
+ }
+ let srv = srv.unwrap();
+ info!("Client selected service: {}", srv.name);
+ if let Err(why) = sock.write_all(OK.as_bytes()).await {
+ error!("Failed to send command to client: {}", why);
+ return;
+ }
+ let remote = match srv.addr.to_socket_addrs() {
+ Ok(mut addrs) => match addrs.next() {
+ None => {
+ error!("Cannot resolve addr: {}", srv.addr);
+ return;
+ }
+ Some(sa) => sa,
+ },
+ Err(why) => {
+ error!("Failed to resolve addr: {}", why);
+ return;
+ }
+ };
+ let remote_sock;
+ if remote.is_ipv4() {
+ remote_sock = TcpSocket::new_v4();
+ } else {
+ remote_sock = TcpSocket::new_v6();
+ }
+ let remote_sock = match remote_sock {
+ Ok(rs) => rs,
+ Err(why) => {
+ error!("Failed to create socket: {}", why);
+ return;
+ }
+ };
+ let mut remote_sock = match remote_sock.connect(remote).await {
+ Ok(s) => s,
+ Err(why) => {
+ error!(
+ "Failed to connect to upstream of service {}: {}",
+ &srv.name, why
+ );
+ return;
+ }
+ };
+ match copy_bidirectional(&mut sock, &mut remote_sock).await {
+ Ok((to_right, to_left)) => {
+ info!(
+ "Proxy session finished. Bytes: client to service: {}, service to client: {}",
+ to_right, to_left
+ );
+ }
+ Err(why) => {
+ error!("Proxy connection was closed abnormally: {}", why);
+ }
+ };
+}