summaryrefslogtreecommitdiff
path: root/src/client.rs
blob: 13def311b2158b0f34d7fcbe27902ba56cf58224 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
use std::io;
use std::net::ToSocketAddrs;

use tokio::io::{copy_bidirectional, AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpSocket, TcpStream};
use tracing::{error, info, warn};

use crate::client::Error::{IOError, InvalidProtocol, InvalidService};

use crate::protocol::ServiceName;

/// Settings shared by every proxied client connection.
#[derive(Debug, Clone, Copy, Default)]
pub struct SharedConfig {
    /// When `true`, `handle_client` does not wait for the upstream's status
    /// line after sending the service selection; the flag is also forwarded
    /// to the upstream as part of the `ServiceName` it sends.
    pub no_ack: bool,
}

/// Reasons a proxied client session can fail after the upstream connection
/// has been established.
#[derive(Debug)]
pub enum Error {
    /// Reading from or writing to the upstream connection failed.
    IOError(io::Error),
    /// The upstream answered the service selection with a negative (`-`) status.
    InvalidService,
    /// The upstream answered with a status byte other than `+` or `-`.
    InvalidProtocol,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::IOError(e) => write!(f, "I/O error: {}", e),
            Error::InvalidService => write!(f, "upstream rejected the requested service"),
            Error::InvalidProtocol => write!(f, "upstream sent an invalid status"),
        }
    }
}

impl std::error::Error for Error {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Error::IOError(e) => Some(e),
            Error::InvalidService => None,
            Error::InvalidProtocol => None,
        }
    }
}

/// Lets `?` turn any `io::Error` into [`Error::IOError`].
impl From<io::Error> for Error {
    fn from(value: io::Error) -> Self {
        Error::IOError(value)
    }
}

pub async fn handle_client(
    cfg: &SharedConfig,
    mut sock: TcpStream,
    upstream: &String,
    service_name: &String,
) -> Result<(), Error> {
    let remote = match upstream.to_socket_addrs() {
        Ok(mut addrs) => match addrs.next() {
            None => {
                error!("Cannot resolve addr: {}", &upstream);
                return Ok(());
            }
            Some(sa) => sa,
        },
        Err(why) => {
            error!("Failed to resolve addr: {}", why);
            return Ok(());
        }
    };
    let remote_sock;
    if remote.is_ipv4() {
        remote_sock = TcpSocket::new_v4();
    } else {
        remote_sock = TcpSocket::new_v6();
    }
    let remote_sock = match remote_sock {
        Ok(rs) => rs,
        Err(why) => {
            error!("Failed to create socket: {}", why);
            return Ok(());
        }
    };
    let mut remote_sock = match remote_sock.connect(remote).await {
        Ok(s) => s,
        Err(why) => {
            error!("Failed to connect to upstream: {}", why);
            return Ok(());
        }
    };
    let service_name = ServiceName {
        service_name,
        no_ack: cfg.no_ack,
    };
    send_upstream(&mut remote_sock, service_name.to_string().as_bytes()).await?;
    send_upstream(&mut remote_sock, b"\r\n").await?;
    remote_sock.flush().await?;

    if !cfg.no_ack {
        read_status(&mut remote_sock).await?;
    }

    match copy_bidirectional(&mut sock, &mut remote_sock).await {
        Ok((to_right, to_left)) => {
            info!(
                "Proxy session finished. Bytes: client to upstream: {}, upstream to client: {}",
                to_right, to_left
            );
        }
        Err(why) => {
            error!("Proxy connection was closed abnormally: {}", why);
        }
    };
    Ok(())
}

/// Reads and checks the upstream's reply to the service-selection line.
///
/// The reply is one status byte followed by a message line terminated by
/// CRLF. `+` means the service was selected; `-` means it was rejected. At
/// most `MAX_MSG_LEN` bytes of the message are kept for logging. The rest of
/// the line is still consumed, so the stream is positioned right after the CRLF.
///
/// Bytes are read one at a time on purpose: whatever follows the CRLF is
/// proxied later by `copy_bidirectional`, so a buffered reader here could
/// swallow data that must still be forwarded.
///
/// # Errors
///
/// * [`Error::IOError`] if reading fails, including the connection closing
///   before the CRLF.
/// * [`Error::InvalidService`] if the status byte is `-`.
/// * [`Error::InvalidProtocol`] for any other status byte.
async fn read_status(remote_sock: &mut TcpStream) -> Result<(), Error> {
    // Upper bound on the number of message bytes retained for logging.
    const MAX_MSG_LEN: usize = 1024;

    let status = remote_sock.read_u8().await?;

    let mut msg = Vec::with_capacity(MAX_MSG_LEN);
    // A CR is held back until the next byte shows whether it starts the CRLF
    // terminator (dropped) or belongs to the message (kept). This also ensures
    // the terminator never leaks into a message that fills the buffer.
    let mut pending_cr = false;
    loop {
        let b = remote_sock.read_u8().await?;
        if pending_cr {
            if b == b'\n' {
                break;
            }
            if msg.len() < MAX_MSG_LEN {
                msg.push(b'\r');
            }
        }
        pending_cr = b == b'\r';
        if !pending_cr && msg.len() < MAX_MSG_LEN {
            msg.push(b);
        }
    }

    // Truncation can split a multibyte character, so decode lossily instead
    // of discarding the whole message.
    let msg = match std::str::from_utf8(&msg) {
        Ok(s) => std::borrow::Cow::Borrowed(s),
        Err(why) => {
            warn!(
                "Failed to decode server message as UTF-8 string, showing it lossily: {}",
                why
            );
            String::from_utf8_lossy(&msg)
        }
    };

    match status {
        b'+' => {
            info!("Upstream service selected successfully: {}", msg);
            Ok(())
        }
        b'-' => {
            error!("Upstream responded with negative status: {}", msg);
            Err(InvalidService)
        }
        b => {
            error!("Invalid status returned from upstream, abort: {:x?}", b);
            Err(InvalidProtocol)
        }
    }
}

async fn send_upstream(remote: &mut TcpStream, data: &[u8]) -> io::Result<()> {
    match remote.write_all(data).await {
        Ok(_) => Ok(()),
        Err(why) => {
            error!("Failed to send bytes to upstream: {}", why);
            Err(why)
        }
    }
}