1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
|
use std::net::{SocketAddr, ToSocketAddrs};
use tokio::io::{copy_bidirectional, AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpSocket, TcpStream};
use tracing::{error, info};
use crate::config;
use crate::protocol::{ServiceName, PREFIX};
/// Settings shared by every client session served by [`handle_client`].
pub struct SharedConfig {
    /// Services a client may select; a request matches when it equals a service's `name`.
    pub services: Vec<config::Service>,
    /// Upper bound, in bytes, on the requested service name that `handle_client` will read.
    /// The `\r\n` terminator (and, with `no_ack_extension`, `PREFIX.len()` extra bytes)
    /// is allowed on top of this; a longer request line is rejected.
    pub max_srv_name_length: usize,
    /// When `true`, a client sending `HELP` receives the list of configured service names.
    pub allow_help: bool,
    /// When `true`, the requested name is parsed with `ServiceName::from`, whose `no_ack`
    /// result suppresses the `+OK` acknowledgement.
    pub no_ack_extension: bool,
}
/// Request line that, when help is enabled, asks for the list of configured services.
const SERVICE_HELP: &str = "HELP";
/// Acknowledgement sent after a valid service is selected (unless the client opted out).
const OK: &str = "+OK\r\n";
/// Error reply for an unknown service name or an over-long request line.
const ERR_SERVICE_NAME_NOT_FOUND: &str = "-Invalid service name\r\n";
/// Error reply for a requested service name that is not valid UTF-8.
const ERR_SERVICE_NAME_ENCODING: &str = "-Service name is not a UTF-8 string\r\n";
pub async fn handle_client(cfg: &SharedConfig, mut sock: TcpStream, addr: SocketAddr) {
info!("New client: {}", addr);
let cap = cfg.max_srv_name_length
+ 2
+ if cfg.no_ack_extension {
PREFIX.len()
} else {
0
};
let mut service_name_buf = vec![0_u8; cap];
let mut i = 0;
let mut ok = false;
while i < service_name_buf.len() {
service_name_buf[i] = match sock.read_u8().await {
Ok(b) => b,
Err(why) => {
error!("Read byte failed: {}", why);
return;
}
};
if i >= 1 && service_name_buf[i] == b'\n' && service_name_buf[i - 1] == b'\r' {
ok = true;
break;
}
i += 1;
}
if !ok {
error!(
"Service name too long, can't be a valid one: {:?}",
service_name_buf
);
// Cannot equal to any valid service name
let _ = sock.write_all(ERR_SERVICE_NAME_NOT_FOUND.as_bytes()).await;
return;
}
let wanted_srv_name: &str = match std::str::from_utf8(&service_name_buf.as_slice()[..i - 1]) {
Ok(s) => s,
Err(_) => {
let _ = sock.write_all(ERR_SERVICE_NAME_ENCODING.as_bytes()).await;
return;
}
};
if cfg.allow_help && wanted_srv_name == SERVICE_HELP {
for s in &cfg.services {
if let Err(_) = sock.write_all(s.name.as_bytes()).await {
return;
}
if let Err(_) = sock.write_all(b"\r\n").await {
return;
}
}
return;
}
let wanted_srv_name = if cfg.no_ack_extension {
ServiceName::from(wanted_srv_name)
} else {
ServiceName {
service_name: wanted_srv_name,
no_ack: false,
}
};
let mut srv = None;
for s in &cfg.services {
if s.name == wanted_srv_name.service_name {
srv = Some(s);
break;
}
}
if srv == None {
error!("Service name not found: {}", wanted_srv_name.to_string());
let _ = sock.write_all(ERR_SERVICE_NAME_NOT_FOUND.as_bytes()).await;
return;
}
let srv = srv.unwrap();
info!("Client selected service: {}", srv.name);
if !wanted_srv_name.no_ack {
if let Err(why) = sock.write_all(OK.as_bytes()).await {
error!("Failed to send command to client: {}", why);
return;
}
}
let remote = match srv.addr.to_socket_addrs() {
Ok(mut addrs) => match addrs.next() {
None => {
error!("Cannot resolve addr: {}", srv.addr);
return;
}
Some(sa) => sa,
},
Err(why) => {
error!("Failed to resolve addr: {}", why);
return;
}
};
let remote_sock;
if remote.is_ipv4() {
remote_sock = TcpSocket::new_v4();
} else {
remote_sock = TcpSocket::new_v6();
}
let remote_sock = match remote_sock {
Ok(rs) => rs,
Err(why) => {
error!("Failed to create socket: {}", why);
return;
}
};
let mut remote_sock = match remote_sock.connect(remote).await {
Ok(s) => s,
Err(why) => {
error!(
"Failed to connect to upstream of service {}: {}",
&srv.name, why
);
return;
}
};
match copy_bidirectional(&mut sock, &mut remote_sock).await {
Ok((to_right, to_left)) => {
info!(
"Proxy session finished. Bytes: client to service: {}, service to client: {}",
to_right, to_left
);
}
Err(why) => {
error!("Proxy connection was closed abnormally: {}", why);
}
};
}
|