summaryrefslogtreecommitdiff
path: root/src/client.rs
blob: 1e2edc4344a182dd373fff7cbdbc58bc154435d4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
use std::io;
use std::net::ToSocketAddrs;

use tokio::io::{copy_bidirectional, AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpSocket, TcpStream};
use tracing::{error, info, warn};

pub async fn handle_client(
    mut sock: TcpStream,
    upstream: &String,
    service_name: &String,
) -> io::Result<()> {
    let remote = match upstream.to_socket_addrs() {
        Ok(mut addrs) => match addrs.next() {
            None => {
                error!("Cannot resolve addr: {}", &upstream);
                return Ok(());
            }
            Some(sa) => sa,
        },
        Err(why) => {
            error!("Failed to resolve addr: {}", why);
            return Ok(());
        }
    };
    let remote_sock;
    if remote.is_ipv4() {
        remote_sock = TcpSocket::new_v4();
    } else {
        remote_sock = TcpSocket::new_v6();
    }
    let remote_sock = match remote_sock {
        Ok(rs) => rs,
        Err(why) => {
            error!("Failed to create socket: {}", why);
            return Ok(());
        }
    };
    let mut remote_sock = match remote_sock.connect(remote).await {
        Ok(s) => s,
        Err(why) => {
            error!("Failed to connect to upstream: {}", why);
            return Ok(());
        }
    };
    send_upstream(&mut remote_sock, service_name.as_bytes()).await?;
    send_upstream(&mut remote_sock, b"\r\n").await?;
    remote_sock.flush().await?;
    let status = remote_sock.read_u8().await?;
    let mut msg = vec![0_u8; 1024];
    let mut i = 0;
    let mut prev_is_cr = false;
    loop {
        let b = remote_sock.read_u8().await?;
        if i < msg.len() {
            msg[i] = b;
        }
        if b == b'\n' && prev_is_cr {
            if i < msg.len() {
                i -= 2; // remove CRLF from reported message string
            }
            break;
        }
        i += 1;
        prev_is_cr = b == b'\r';
    }
    let msg = match std::str::from_utf8(&msg[..i]) {
        Ok(s) => s,
        Err(why) => {
            warn!(
                "Failed to decode server message as UTF-8 string, ignore: {}",
                why
            );
            "???"
        }
    };
    match status {
        b'+' => {
            info!("Upstream service selected successfully: {}", msg);
        }
        b'-' => {
            error!("Upstream responded with negative status: {}", msg);
            return Ok(());
        }
        b => {
            error!("Invalid status returned from upstream, abort: {:x?}", b);
            return Ok(());
        }
    }
    match copy_bidirectional(&mut sock, &mut remote_sock).await {
        Ok((to_right, to_left)) => {
            info!(
                "Proxy session finished. Bytes: client to upstream: {}, upstream to client: {}",
                to_right, to_left
            );
        }
        Err(why) => {
            error!("Proxy connection was closed abnormally: {}", why);
        }
    };
    Ok(())
}

async fn send_upstream(remote: &mut TcpStream, data: &[u8]) -> io::Result<()> {
    match remote.write_all(data).await {
        Ok(_) => Ok(()),
        Err(why) => {
            error!("Failed to send bytes to upstream: {}", why);
            Err(why)
        }
    }
}