summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorKeuin <[email protected]>2023-09-05 01:52:56 +0800
committerKeuin <[email protected]>2023-09-05 01:56:21 +0800
commit50dbc034090614d004d097c7a45b0a28a3bbb80b (patch)
treeb8ad419bb8c2fed12ac419274755c716166eb90b /src
parent863473cdcb29d9989c39b4ff96bd54e14b13c6b6 (diff)
feature: 0-rtt connection phase extensionHEADv0.2.0master
Diffstat (limited to 'src')
-rw-r--r--src/client.rs66
-rw-r--r--src/config.rs8
-rw-r--r--src/main.rs22
-rw-r--r--src/protocol.rs35
-rw-r--r--src/server.rs30
5 files changed, 136 insertions, 25 deletions
diff --git a/src/client.rs b/src/client.rs
index 3fcf78f..13def31 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -5,11 +5,32 @@ use tokio::io::{copy_bidirectional, AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpSocket, TcpStream};
use tracing::{error, info, warn};
+use crate::client::Error::{IOError, InvalidProtocol, InvalidService};
+
+use crate::protocol::ServiceName;
+
+pub struct SharedConfig {
+ pub no_ack: bool,
+}
+
+pub enum Error {
+ IOError(io::Error),
+ InvalidService,
+ InvalidProtocol,
+}
+
+impl From<io::Error> for Error {
+ fn from(value: io::Error) -> Self {
+ IOError(value)
+ }
+}
+
pub async fn handle_client(
+ cfg: &SharedConfig,
mut sock: TcpStream,
upstream: &String,
service_name: &String,
-) -> io::Result<()> {
+) -> Result<(), Error> {
let remote = match upstream.to_socket_addrs() {
Ok(mut addrs) => match addrs.next() {
None => {
@@ -43,9 +64,33 @@ pub async fn handle_client(
return Ok(());
}
};
- send_upstream(&mut remote_sock, service_name.as_bytes()).await?;
+ let service_name = ServiceName {
+ service_name,
+ no_ack: cfg.no_ack,
+ };
+ send_upstream(&mut remote_sock, service_name.to_string().as_bytes()).await?;
send_upstream(&mut remote_sock, b"\r\n").await?;
remote_sock.flush().await?;
+
+ if !cfg.no_ack {
+ read_status(&mut remote_sock).await?;
+ }
+
+ match copy_bidirectional(&mut sock, &mut remote_sock).await {
+ Ok((to_right, to_left)) => {
+ info!(
+ "Proxy session finished. Bytes: client to upstream: {}, upstream to client: {}",
+ to_right, to_left
+ );
+ }
+ Err(why) => {
+ error!("Proxy connection was closed abnormally: {}", why);
+ }
+ };
+ Ok(())
+}
+
+async fn read_status(remote_sock: &mut TcpStream) -> Result<(), Error> {
let status = remote_sock.read_u8().await?;
let mut msg = vec![0_u8; 1024];
let mut i = 0;
@@ -78,28 +123,17 @@ pub async fn handle_client(
match status {
b'+' => {
info!("Upstream service selected successfully: {}", msg);
+ Ok(())
}
b'-' => {
error!("Upstream responded with negative status: {}", msg);
- return Ok(());
+ Err(InvalidService)
}
b => {
error!("Invalid status returned from upstream, abort: {:x?}", b);
- return Ok(());
+ Err(InvalidProtocol)
}
}
- match copy_bidirectional(&mut sock, &mut remote_sock).await {
- Ok((to_right, to_left)) => {
- info!(
- "Proxy session finished. Bytes: client to upstream: {}, upstream to client: {}",
- to_right, to_left
- );
- }
- Err(why) => {
- error!("Proxy connection was closed abnormally: {}", why);
- }
- };
- Ok(())
}
async fn send_upstream(remote: &mut TcpStream, data: &[u8]) -> io::Result<()> {
diff --git a/src/config.rs b/src/config.rs
index 7eee270..1fba27f 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -3,6 +3,8 @@ use serde_derive::{Deserialize, Serialize};
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Server {
pub listen: String,
+ #[serde(default = "default_false")]
+ pub no_ack_extension: bool,
}
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
@@ -14,6 +16,8 @@ pub struct Service {
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Client {
pub addr: String,
+ #[serde(default = "default_false")]
+ pub no_ack: bool,
}
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
@@ -28,3 +32,7 @@ pub struct Root {
fn default_true() -> bool {
true
}
+
+fn default_false() -> bool {
+ false
+}
diff --git a/src/main.rs b/src/main.rs
index b588f2d..da5c2f5 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -12,6 +12,7 @@ use crate::server::SharedConfig;
mod arguments;
mod client;
mod config;
+mod protocol;
mod server;
fn load_config<P: AsRef<Path>>(file_path: P) -> config::Root {
@@ -74,7 +75,8 @@ async fn main() -> io::Result<()> {
}
async fn server_main(cfg: config::Root) -> io::Result<()> {
- let listener = TcpListener::bind(cfg.server.unwrap().listen).await?;
+ let server_config = cfg.server.unwrap();
+ let listener = TcpListener::bind(server_config.listen).await?;
info!("Server listening on {:?}", listener);
let max_srv_name_length = cfg
.service
@@ -89,6 +91,7 @@ async fn server_main(cfg: config::Root) -> io::Result<()> {
max_srv_name_length,
services: cfg.service,
allow_help: cfg.allow_help,
+ no_ack_extension: server_config.no_ack_extension,
}));
loop {
let (sock, addr) = listener.accept().await?;
@@ -97,6 +100,7 @@ async fn server_main(cfg: config::Root) -> io::Result<()> {
}
async fn client_server_main(
+ cfg: &'static client::SharedConfig,
upstream_addr: &'static String,
listener: TcpListener,
service_name: &'static String,
@@ -104,20 +108,32 @@ async fn client_server_main(
loop {
let (sock, addr) = listener.accept().await?;
info!("Client connected: {}", addr);
- tokio::spawn(client::handle_client(sock, upstream_addr, service_name));
+ tokio::spawn(client::handle_client(
+ cfg,
+ sock,
+ upstream_addr,
+ service_name,
+ ));
}
}
async fn client_main(cfg: config::Root) -> io::Result<()> {
let mut fut_servers = Vec::new();
- let upstream_addr = Box::leak(Box::new(cfg.client.unwrap().addr));
+ let client_cfg = cfg.client.unwrap();
+
+ let upstream_addr = Box::leak(Box::new(client_cfg.addr));
+
+ let shared_config = Box::leak(Box::new(client::SharedConfig {
+ no_ack: client_cfg.no_ack,
+ }));
for s in cfg.service {
let listener = TcpListener::bind(&s.addr).await?;
info!("Server listening on {:?} (service {})", listener, &s.name);
let service_name = Box::leak(Box::new(s.name));
fut_servers.push(tokio::spawn(client_server_main(
+ shared_config,
upstream_addr,
listener,
service_name,
diff --git a/src/protocol.rs b/src/protocol.rs
new file mode 100644
index 0000000..7bfec51
--- /dev/null
+++ b/src/protocol.rs
@@ -0,0 +1,35 @@
+pub const PREFIX: &str = "__nonstd_ext_no_ack_";
+
+fn to_no_ack_service_name<T: AsRef<str>>(original_name: T) -> String {
+ PREFIX.to_owned() + original_name.as_ref()
+}
+
+fn parse_service_name(service_name: &str) -> (&str, bool) {
+ if service_name.starts_with(PREFIX) {
+ (&service_name[PREFIX.len()..], true)
+ } else {
+ (service_name, false)
+ }
+}
+
+pub struct ServiceName<'a> {
+ pub service_name: &'a str,
+ pub no_ack: bool,
+}
+
+impl<'a> ServiceName<'a> {
+ pub fn from(s: &'a str) -> Self {
+ let (service_name, no_ack) = parse_service_name(s);
+ ServiceName {
+ service_name,
+ no_ack,
+ }
+ }
+ pub fn to_string(&self) -> String {
+ if self.no_ack {
+ to_no_ack_service_name(self.service_name).to_owned()
+ } else {
+ self.service_name.to_string()
+ }
+ }
+}
diff --git a/src/server.rs b/src/server.rs
index 968bfeb..4e42105 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -5,11 +5,13 @@ use tokio::net::{TcpSocket, TcpStream};
use tracing::{error, info};
use crate::config;
+use crate::protocol::{ServiceName, PREFIX};
pub struct SharedConfig {
pub max_srv_name_length: usize,
pub services: Vec<config::Service>,
pub allow_help: bool,
+ pub no_ack_extension: bool,
}
const ERR_SERVICE_NAME_NOT_FOUND: &str = "-Invalid service name\r\n";
@@ -19,7 +21,13 @@ const SERVICE_HELP: &str = "HELP";
pub async fn handle_client(cfg: &SharedConfig, mut sock: TcpStream, addr: SocketAddr) {
info!("New client: {}", addr);
- let cap = cfg.max_srv_name_length + 2;
+ let cap = cfg.max_srv_name_length
+ + 2
+ + if cfg.no_ack_extension {
+ PREFIX.len()
+ } else {
+ 0
+ };
let mut service_name_buf = vec![0_u8; cap];
let mut i = 0;
let mut ok = false;
@@ -64,23 +72,33 @@ pub async fn handle_client(cfg: &SharedConfig, mut sock: TcpStream, addr: Socket
}
return;
}
+ let wanted_srv_name = if cfg.no_ack_extension {
+ ServiceName::from(wanted_srv_name)
+ } else {
+ ServiceName {
+ service_name: wanted_srv_name,
+ no_ack: false,
+ }
+ };
let mut srv = None;
for s in &cfg.services {
- if s.name == wanted_srv_name {
+ if s.name == wanted_srv_name.service_name {
srv = Some(s);
break;
}
}
if srv == None {
- error!("Service name not found: {}", wanted_srv_name);
+ error!("Service name not found: {}", wanted_srv_name.to_string());
let _ = sock.write_all(ERR_SERVICE_NAME_NOT_FOUND.as_bytes()).await;
return;
}
let srv = srv.unwrap();
info!("Client selected service: {}", srv.name);
- if let Err(why) = sock.write_all(OK.as_bytes()).await {
- error!("Failed to send command to client: {}", why);
- return;
+ if !wanted_srv_name.no_ack {
+ if let Err(why) = sock.write_all(OK.as_bytes()).await {
+ error!("Failed to send command to client: {}", why);
+ return;
+ }
}
let remote = match srv.addr.to_socket_addrs() {
Ok(mut addrs) => match addrs.next() {