diff options
author | Keuin <[email protected]> | 2023-09-05 01:52:56 +0800 |
---|---|---|
committer | Keuin <[email protected]> | 2023-09-05 01:56:21 +0800 |
commit | 50dbc034090614d004d097c7a45b0a28a3bbb80b (patch) | |
tree | b8ad419bb8c2fed12ac419274755c716166eb90b /src | |
parent | 863473cdcb29d9989c39b4ff96bd54e14b13c6b6 (diff) |
Diffstat (limited to 'src')
-rw-r--r-- | src/client.rs | 66 | ||||
-rw-r--r-- | src/config.rs | 8 | ||||
-rw-r--r-- | src/main.rs | 22 | ||||
-rw-r--r-- | src/protocol.rs | 35 | ||||
-rw-r--r-- | src/server.rs | 30 |
5 files changed, 136 insertions, 25 deletions
diff --git a/src/client.rs b/src/client.rs index 3fcf78f..13def31 100644 --- a/src/client.rs +++ b/src/client.rs @@ -5,11 +5,32 @@ use tokio::io::{copy_bidirectional, AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpSocket, TcpStream}; use tracing::{error, info, warn}; +use crate::client::Error::{IOError, InvalidProtocol, InvalidService}; + +use crate::protocol::ServiceName; + +pub struct SharedConfig { + pub no_ack: bool, +} + +pub enum Error { + IOError(io::Error), + InvalidService, + InvalidProtocol, +} + +impl From<io::Error> for Error { + fn from(value: io::Error) -> Self { + IOError(value) + } +} + pub async fn handle_client( + cfg: &SharedConfig, mut sock: TcpStream, upstream: &String, service_name: &String, -) -> io::Result<()> { +) -> Result<(), Error> { let remote = match upstream.to_socket_addrs() { Ok(mut addrs) => match addrs.next() { None => { @@ -43,9 +64,33 @@ pub async fn handle_client( return Ok(()); } }; - send_upstream(&mut remote_sock, service_name.as_bytes()).await?; + let service_name = ServiceName { + service_name, + no_ack: cfg.no_ack, + }; + send_upstream(&mut remote_sock, service_name.to_string().as_bytes()).await?; send_upstream(&mut remote_sock, b"\r\n").await?; remote_sock.flush().await?; + + if !cfg.no_ack { + read_status(&mut remote_sock).await?; + } + + match copy_bidirectional(&mut sock, &mut remote_sock).await { + Ok((to_right, to_left)) => { + info!( + "Proxy session finished. Bytes: client to upstream: {}, upstream to client: {}", + to_right, to_left + ); + } + Err(why) => { + error!("Proxy connection was closed abnormally: {}", why); + } + }; + Ok(()) +} + +async fn read_status(remote_sock: &mut TcpStream) -> Result<(), Error> { let status = remote_sock.read_u8().await?; let mut msg = vec![0_u8; 1024]; let mut i = 0; @@ -78,28 +123,17 @@ pub async fn handle_client( match status { b'+' => { info!("Upstream service selected successfully: {}", msg); + Ok(()) } b'-' => { error!("Upstream responded with negative status: {}", msg); - return Ok(()); + Err(InvalidService) } b => { error!("Invalid status returned from upstream, abort: {:x?}", b); - return Ok(()); + Err(InvalidProtocol) } } - match copy_bidirectional(&mut sock, &mut remote_sock).await { - Ok((to_right, to_left)) => { - info!( - "Proxy session finished. Bytes: client to upstream: {}, upstream to client: {}", - to_right, to_left - ); - } - Err(why) => { - error!("Proxy connection was closed abnormally: {}", why); - } - }; - Ok(()) } async fn send_upstream(remote: &mut TcpStream, data: &[u8]) -> io::Result<()> { diff --git a/src/config.rs b/src/config.rs index 7eee270..1fba27f 100644 --- a/src/config.rs +++ b/src/config.rs @@ -3,6 +3,8 @@ use serde_derive::{Deserialize, Serialize}; #[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct Server { pub listen: String, + #[serde(default = "default_false")] + pub no_ack_extension: bool, } #[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -14,6 +16,8 @@ pub struct Service { #[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct Client { pub addr: String, + #[serde(default = "default_false")] + pub no_ack: bool, } #[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -28,3 +32,7 @@ pub struct Root { fn default_true() -> bool { true } + +fn default_false() -> bool { + false +} diff --git a/src/main.rs b/src/main.rs index b588f2d..da5c2f5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,6 +12,7 @@ use crate::server::SharedConfig; mod arguments; mod client; mod config; +mod protocol; mod server; fn load_config<P: AsRef<Path>>(file_path: P) -> config::Root { @@ -74,7 +75,8 @@ async fn main() -> io::Result<()> { } async fn server_main(cfg: config::Root) -> io::Result<()> { - let listener = TcpListener::bind(cfg.server.unwrap().listen).await?; + let server_config = cfg.server.unwrap(); + let listener = TcpListener::bind(server_config.listen).await?; info!("Server listening on {:?}", listener); let max_srv_name_length = cfg .service @@ -89,6 +91,7 @@ async fn server_main(cfg: config::Root) -> io::Result<()> { max_srv_name_length, services: cfg.service, allow_help: cfg.allow_help, + no_ack_extension: server_config.no_ack_extension, })); loop { let (sock, addr) = listener.accept().await?; @@ -97,6 +100,7 @@ async fn server_main(cfg: config::Root) -> io::Result<()> { } async fn client_server_main( + cfg: &'static client::SharedConfig, upstream_addr: &'static String, listener: TcpListener, service_name: &'static String, @@ -104,20 +108,32 @@ async fn client_server_main( loop { let (sock, addr) = listener.accept().await?; info!("Client connected: {}", addr); - tokio::spawn(client::handle_client(sock, upstream_addr, service_name)); + tokio::spawn(client::handle_client( + cfg, + sock, + upstream_addr, + service_name, + )); } } async fn client_main(cfg: config::Root) -> io::Result<()> { let mut fut_servers = Vec::new(); - let upstream_addr = Box::leak(Box::new(cfg.client.unwrap().addr)); + let client_cfg = cfg.client.unwrap(); + + let upstream_addr = Box::leak(Box::new(client_cfg.addr)); + + let shared_config = Box::leak(Box::new(client::SharedConfig { + no_ack: client_cfg.no_ack, + })); for s in cfg.service { let listener = TcpListener::bind(&s.addr).await?; info!("Server listening on {:?} (service {})", listener, &s.name); let service_name = Box::leak(Box::new(s.name)); fut_servers.push(tokio::spawn(client_server_main( + shared_config, upstream_addr, listener, service_name, diff --git a/src/protocol.rs b/src/protocol.rs new file mode 100644 index 0000000..7bfec51 --- /dev/null +++ b/src/protocol.rs @@ -0,0 +1,35 @@ +pub const PREFIX: &str = "__nonstd_ext_no_ack_"; + +fn to_no_ack_service_name<T: AsRef<str>>(original_name: T) -> String { + PREFIX.to_owned() + original_name.as_ref() +} + +fn parse_service_name(service_name: &str) -> (&str, bool) { + if service_name.starts_with(PREFIX) { + (&service_name[PREFIX.len()..], true) + } else { + (service_name, false) + } +} + +pub struct ServiceName<'a> { + pub service_name: &'a str, + pub no_ack: bool, +} + +impl<'a> ServiceName<'a> { + pub fn from(s: &'a str) -> Self { + let (service_name, no_ack) = parse_service_name(s); + ServiceName { + service_name, + no_ack, + } + } + pub fn to_string(&self) -> String { + if self.no_ack { + to_no_ack_service_name(self.service_name).to_owned() + } else { + self.service_name.to_string() + } + } +} diff --git a/src/server.rs b/src/server.rs index 968bfeb..4e42105 100644 --- a/src/server.rs +++ b/src/server.rs @@ -5,11 +5,13 @@ use tokio::net::{TcpSocket, TcpStream}; use tracing::{error, info}; use crate::config; +use crate::protocol::{ServiceName, PREFIX}; pub struct SharedConfig { pub max_srv_name_length: usize, pub services: Vec<config::Service>, pub allow_help: bool, + pub no_ack_extension: bool, } const ERR_SERVICE_NAME_NOT_FOUND: &str = "-Invalid service name\r\n"; @@ -19,7 +21,13 @@ const SERVICE_HELP: &str = "HELP"; pub async fn handle_client(cfg: &SharedConfig, mut sock: TcpStream, addr: SocketAddr) { info!("New client: {}", addr); - let cap = cfg.max_srv_name_length + 2; + let cap = cfg.max_srv_name_length + + 2 + + if cfg.no_ack_extension { + PREFIX.len() + } else { + 0 + }; let mut service_name_buf = vec![0_u8; cap]; let mut i = 0; let mut ok = false; @@ -64,23 +72,33 @@ pub async fn handle_client(cfg: &SharedConfig, mut sock: TcpStream, addr: Socket } return; } + let wanted_srv_name = if cfg.no_ack_extension { + ServiceName::from(wanted_srv_name) + } else { + ServiceName { + service_name: wanted_srv_name, + no_ack: false, + } + }; let mut srv = None; for s in &cfg.services { - if s.name == wanted_srv_name { + if s.name == wanted_srv_name.service_name { srv = Some(s); break; } } if srv == None { - error!("Service name not found: {}", wanted_srv_name); + error!("Service name not found: {}", wanted_srv_name.to_string()); let _ = sock.write_all(ERR_SERVICE_NAME_NOT_FOUND.as_bytes()).await; return; } let srv = srv.unwrap(); info!("Client selected service: {}", srv.name); - if let Err(why) = sock.write_all(OK.as_bytes()).await { - error!("Failed to send command to client: {}", why); - return; + if !wanted_srv_name.no_ack { + if let Err(why) = sock.write_all(OK.as_bytes()).await { + error!("Failed to send command to client: {}", why); + return; + } } let remote = match srv.addr.to_socket_addrs() { Ok(mut addrs) => match addrs.next() { |