diff --git a/Cargo.lock b/Cargo.lock index 432bb32..7e3e4d9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -160,6 +160,12 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "anyhow" +version = "1.0.102" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" + [[package]] name = "arboard" version = "3.6.1" @@ -173,7 +179,7 @@ dependencies = [ "objc2-foundation 0.3.2", "parking_lot", "percent-encoding", - "windows-sys 0.52.0", + "windows-sys 0.60.2", "x11rb", ] @@ -556,6 +562,17 @@ dependencies = [ "libc", ] +[[package]] +name = "chacha20" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f8d983286843e49675a4b7a2d174efe136dc93a18d69130dd18198a6c167601" +dependencies = [ + "cfg-if", + "cpufeatures 0.3.0", + "rand_core 0.10.1", +] + [[package]] name = "chrono" version = "0.4.44" @@ -753,6 +770,15 @@ dependencies = [ "libc", ] +[[package]] +name = "cpufeatures" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b2a41393f66f16b0823bb79094d54ac5fbd34ab292ddafb9a0456ac9f87d201" +dependencies = [ + "libc", +] + [[package]] name = "crc32fast" version = "1.5.0" @@ -1070,18 +1096,6 @@ dependencies = [ "serde", ] -[[package]] -name = "enum-as-inner" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1e6a265c649f3f5979b601d26f1d05ada116434c87741c9493cb56218f76cbc" -dependencies = [ - "heck", - "proc-macro2", - "quote", - "syn 2.0.117", -] - [[package]] name = "enumn" version = "0.1.14" @@ -1417,8 +1431,22 @@ checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" dependencies = [ "cfg-if", "libc", - "r-efi", + "r-efi 5.3.0", + "wasip2", +] + +[[package]] +name = "getrandom" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de51e6874e94e7bf76d726fc5d13ba782deca734ff60d5bb2fb2607c7406555" +dependencies = [ + "cfg-if", + "libc", + "r-efi 6.0.0", + "rand_core 0.10.1", "wasip2", + "wasip3", ] [[package]] @@ -1654,25 +1682,19 @@ checksum = "dfa686283ad6dd069f105e5ab091b04c62850d3e4cf5d67debad1933f55023df" [[package]] name = "hickory-proto" -version = "0.25.2" +version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8a6fe56c0038198998a6f217ca4e7ef3a5e51f46163bd6dd60b5c71ca6c6502" +checksum = "a916d0494600d99ecb15aadfab677ad97c4de559e8f1af0c129353a733ac1fcc" dependencies = [ - "async-trait", - "cfg-if", "data-encoding", - "enum-as-inner", - "futures-channel", - "futures-io", - "futures-util", "idna", "ipnet", + "jni 0.22.4", "once_cell", - "rand 0.9.4", + "rand 0.10.1", "ring", "thiserror 2.0.18", "tinyvec", - "tokio", "tracing", "url", ] @@ -1810,6 +1832,12 @@ dependencies = [ "zerovec", ] +[[package]] +name = "id-arena" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" + [[package]] name = "idna" version = "1.1.0" @@ -1852,6 +1880,8 @@ checksum = "d466e9454f08e4a911e14806c24e16fba1b4c121d1ea474396f396069cf949d9" dependencies = [ "equivalent", "hashbrown 0.17.0", + "serde", + "serde_core", ] [[package]] @@ -2030,6 +2060,12 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + [[package]] name = "libc" version = "0.2.185" @@ -2861,6 +2897,16 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e8cf8e6a8aa66ce33f63993ffc4ea4271eb5b0530a9002db8455ea6050c77bfa" +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn 2.0.117", +] + [[package]] name = "proc-macro-crate" version = "3.5.0" @@ -2915,6 +2961,12 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "r-efi" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" + [[package]] name = "rand" version = "0.8.6" @@ -2936,6 +2988,17 @@ dependencies = [ "rand_core 0.9.5", ] +[[package]] +name = "rand" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2e8e8bcc7961af1fdac401278c6a831614941f6164ee3bf4ce61b7edb162207" +dependencies = [ + "chacha20", + "getrandom 0.4.2", + "rand_core 0.10.1", +] + [[package]] name = "rand_chacha" version = "0.3.1" @@ -2974,6 +3037,12 @@ dependencies = [ "getrandom 0.3.4", ] +[[package]] +name = "rand_core" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63b8176103e19a2643978565ca18b50549f6101881c443590420e4dc998a3c69" + [[package]] name = "raw-window-handle" version = "0.5.2" @@ -3163,7 +3232,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys 0.4.15", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -3305,7 +3374,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" dependencies = [ "cfg-if", - "cpufeatures", + "cpufeatures 0.2.17", "digest", ] @@ -3461,9 +3530,9 @@ dependencies = [ [[package]] name = "socks5-impl" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1eae7c78f163b7805f66493c787d7bad4816146faf0cf655d57c78b90c383ce3" +checksum = "150816c2d954315f351129f438f851285e1ddb6d6ccc850ddd45c523d19abda0" dependencies = [ "async-trait", "bytes", @@ -3567,7 +3636,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32497e9a4c7b38532efcdebeef879707aa9f794296a4f0244f6f69e9bc8574bd" dependencies = [ "fastrand", - "getrandom 0.3.4", + "getrandom 0.4.2", "once_cell", "rustix 1.1.4", "windows-sys 0.61.2", @@ -3891,9 +3960,8 @@ dependencies = [ [[package]] name = "tun2proxy" -version = "0.7.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0576f75fd691ad86cdc4348f29fb8770037ab8140179f1f9f8f6991f7ebd2176" +version = "0.7.21" +source = "git+https://github.com/yyoyoian-pixel/tun2proxy?branch=feat%2Fudpgw-jni-param#dfc24ed12cdee69987bdd321ea55c6b940f2d0f0" dependencies = [ "android_logger", "async-trait", @@ -4061,7 +4129,16 @@ version = "1.0.3+wasi-0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "20064672db26d7cdc89c7798c48a0fdfac8213434a1186e5ef29fd560ae223d6" dependencies = [ - "wit-bindgen", + "wit-bindgen 0.57.1", +] + +[[package]] +name = "wasip3" +version = "0.4.0+wasi-0.3.0-rc-2026-01-06" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" +dependencies = [ + "wit-bindgen 0.51.0", ] [[package]] @@ -4119,6 +4196,40 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "wasm-encoder" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" +dependencies = [ + "anyhow", + "indexmap", + "wasm-encoder", + "wasmparser", +] + +[[package]] +name = "wasmparser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" +dependencies = [ + "bitflags 2.11.1", + "hashbrown 0.15.5", + "indexmap", + "semver", +] + [[package]] name = "wayland-backend" version = "0.3.15" @@ -4467,7 +4578,7 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.61.2", + "windows-sys 0.48.0", ] [[package]] @@ -4945,12 +5056,100 @@ dependencies = [ "winreg", ] +[[package]] +name = "wit-bindgen" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" +dependencies = [ + "wit-bindgen-rust-macro", +] + [[package]] name = "wit-bindgen" version = "0.57.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ebf944e87a7c253233ad6766e082e3cd714b5d03812acc24c318f549614536e" +[[package]] +name = "wit-bindgen-core" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" +dependencies = [ + "anyhow", + "heck", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rust" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" +dependencies = [ + "anyhow", + "heck", + "indexmap", + "prettyplease", + "syn 2.0.117", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn 2.0.117", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" +dependencies = [ + "anyhow", + "bitflags 2.11.1", + "indexmap", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" +dependencies = [ + "anyhow", + "id-arena", + "indexmap", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] + [[package]] name = "writeable" version = "0.6.3" diff --git a/Cargo.toml b/Cargo.toml index 01af7c4..c623ed8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -94,12 +94,18 @@ libc = "0.2" # traffic black-holes (symptom: Chrome shows DNS_PROBE_STARTED). [target.'cfg(target_os = "android")'.dependencies] jni = { version = "0.21", default-features = false } -tun2proxy = { version = "0.7", default-features = false } +tun2proxy = { version = "0.7", default-features = false, features = ["udpgw"] } [dev-dependencies] # Used in mitm tests to sanity-check the cert extensions we emit. x509-parser = "0.16" +# Temporary patch: adds udpgw_server parameter to the Android JNI run() +# function. Upstream PR: https://github.com/tun2proxy/tun2proxy/pull/247 +# Remove this section once tun2proxy >= 0.8 ships with the change. +[patch.crates-io] +tun2proxy = { git = "https://github.com/yyoyoian-pixel/tun2proxy", branch = "feat/udpgw-jni-param" } + [profile.release] panic = "abort" codegen-units = 1 diff --git a/android/app/src/main/java/com/github/shadowsocks/bg/Tun2proxy.kt b/android/app/src/main/java/com/github/shadowsocks/bg/Tun2proxy.kt index 03953be..4b1e3bf 100644 --- a/android/app/src/main/java/com/github/shadowsocks/bg/Tun2proxy.kt +++ b/android/app/src/main/java/com/github/shadowsocks/bg/Tun2proxy.kt @@ -59,6 +59,7 @@ object Tun2proxy { tunMtu: Char, verbosity: Int, dnsStrategy: Int, + udpgwServer: String, ): Int /** Signals the running `run()` to shut down. Idempotent. */ diff --git a/android/app/src/main/java/com/therealaleph/mhrv/MhrvVpnService.kt b/android/app/src/main/java/com/therealaleph/mhrv/MhrvVpnService.kt index 121b9af..f8219b3 100644 --- a/android/app/src/main/java/com/therealaleph/mhrv/MhrvVpnService.kt +++ b/android/app/src/main/java/com/therealaleph/mhrv/MhrvVpnService.kt @@ -249,6 +249,10 @@ class MhrvVpnService : VpnService() { // the sole owner once it's running. val detachedFd = parcelFd.detachFd() tun2proxyRunning.set(true) + // In full mode, enable udpgw so UDP traffic (DNS, QUIC, …) is + // forwarded through the tunnel-node's native udpgw handler. + // 198.18.0.1:7300 is a magic address the tunnel-node intercepts. + val udpgwAddr = if (cfg.mode == Mode.FULL) "198.18.0.1:7300" else "" val worker = Thread({ try { val rc = Tun2proxy.run( @@ -258,6 +262,7 @@ class MhrvVpnService : VpnService() { MTU.toChar(), /* verbosity = info */ 3, /* dnsStrategy = virtual */ 0, + udpgwAddr, ) Log.i(TAG, "tun2proxy exited rc=$rc") } catch (t: Throwable) { diff --git a/tunnel-node/src/main.rs b/tunnel-node/src/main.rs index e03ff5e..2200624 100644 --- a/tunnel-node/src/main.rs +++ b/tunnel-node/src/main.rs @@ -22,12 +22,14 @@ use axum::{routing::post, Json, Router}; use base64::engine::general_purpose::STANDARD as B64; use base64::Engine; use serde::{Deserialize, Serialize}; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; +use tokio::net::tcp::OwnedWriteHalf; use tokio::net::{lookup_host, TcpStream, UdpSocket}; use tokio::sync::{mpsc, Mutex, Notify}; use tokio::task::JoinSet; +mod udpgw; + /// Structured error code returned when the tunnel-node receives an op it /// doesn't recognize. Clients use this (rather than string-matching `e`) to /// detect a version mismatch and gracefully fall back. @@ -95,8 +97,30 @@ const UDP_QUEUE_DROP_LOG_STRIDE: u64 = 100; // Session // --------------------------------------------------------------------------- +/// Writer half — either a real TCP socket or an in-process duplex channel +/// (used for virtual sessions like udpgw). +enum SessionWriter { + Tcp(OwnedWriteHalf), + Duplex(tokio::io::WriteHalf), +} + +impl SessionWriter { + async fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> { + match self { + SessionWriter::Tcp(w) => w.write_all(buf).await, + SessionWriter::Duplex(w) => w.write_all(buf).await, + } + } + async fn flush(&mut self) -> std::io::Result<()> { + match self { + SessionWriter::Tcp(w) => w.flush().await, + SessionWriter::Duplex(w) => w.flush().await, + } + } +} + struct SessionInner { - writer: Mutex, + writer: Mutex, read_buf: Mutex>, eof: AtomicBool, last_active: Mutex, @@ -110,6 +134,17 @@ struct SessionInner { struct ManagedSession { inner: Arc, reader_handle: tokio::task::JoinHandle<()>, + /// For udpgw sessions, the server task handle (so we can abort on close). + udpgw_handle: Option>, +} + +impl ManagedSession { + fn abort_all(&self) { + self.reader_handle.abort(); + if let Some(ref h) = self.udpgw_handle { + h.abort(); + } + } } /// UDP equivalent of `SessionInner`. Holds a *connected* `UdpSocket` @@ -148,7 +183,7 @@ async fn create_session(host: &str, port: u16) -> std::io::Result std::io::Result ManagedSession { + let (client_half, server_half) = tokio::io::duplex(65536); + let (read_half, write_half) = tokio::io::split(client_half); + + let inner = Arc::new(SessionInner { + writer: Mutex::new(SessionWriter::Duplex(write_half)), + read_buf: Mutex::new(Vec::with_capacity(32768)), + eof: AtomicBool::new(false), + last_active: Mutex::new(Instant::now()), + notify: Notify::new(), + }); + + let inner_ref = inner.clone(); + let reader_handle = tokio::spawn(reader_task(read_half, inner_ref)); + let udpgw_handle = Some(tokio::spawn(udpgw::udpgw_server_task(server_half))); + + ManagedSession { inner, reader_handle, udpgw_handle } } -async fn reader_task(mut reader: OwnedReadHalf, session: Arc) { +async fn reader_task(mut reader: impl AsyncRead + Unpin, session: Arc) { let mut buf = vec![0u8; 65536]; loop { match reader.read(&mut buf).await { @@ -971,9 +1026,13 @@ async fn handle_connect(state: &AppState, host: Option, port: Option v, Err(r) => return r, }; - let session = match create_session(&host, port).await { - Ok(s) => s, - Err(e) => return TunnelResponse::error(format!("connect failed: {}", e)), + let session = if udpgw::is_udpgw_dest(&host, port) { + create_udpgw_session() + } else { + match create_session(&host, port).await { + Ok(s) => s, + Err(e) => return TunnelResponse::error(format!("connect failed: {}", e)), + } }; let sid = uuid::Uuid::new_v4().to_string(); tracing::info!("session {} -> {}:{}", sid, host, port); @@ -995,9 +1054,13 @@ async fn handle_connect_data_phase1( ) -> Result<(String, Arc), TunnelResponse> { let (host, port) = validate_host_port(host, port)?; - let session = create_session(&host, port) - .await - .map_err(|e| TunnelResponse::error(format!("connect failed: {}", e)))?; + let session = if udpgw::is_udpgw_dest(&host, port) { + create_udpgw_session() + } else { + create_session(&host, port) + .await + .map_err(|e| TunnelResponse::error(format!("connect failed: {}", e)))? + }; // Any failure below this point must abort the reader task, otherwise // the newly-opened upstream TCP connection would leak. Keep the @@ -1146,7 +1209,7 @@ async fn handle_close(state: &AppState, sid: Option) -> TunnelResponse { _ => return TunnelResponse::error("missing sid"), }; if let Some(s) = state.sessions.lock().await.remove(&sid) { - s.reader_handle.abort(); + s.abort_all(); tracing::info!("session {} closed by client", sid); } if let Some(s) = state.udp_sessions.lock().await.remove(&sid) { @@ -1430,7 +1493,7 @@ mod tests { let (_reader, writer) = client.into_split(); Arc::new(SessionInner { - writer: Mutex::new(writer), + writer: Mutex::new(SessionWriter::Tcp(writer)), read_buf: Mutex::new(Vec::new()), eof: AtomicBool::new(false), last_active: Mutex::new(Instant::now()), @@ -1597,7 +1660,7 @@ mod tests { let stream = TcpStream::connect(addr).await.unwrap(); let (reader, writer) = stream.into_split(); let inner = Arc::new(SessionInner { - writer: Mutex::new(writer), + writer: Mutex::new(SessionWriter::Tcp(writer)), read_buf: Mutex::new(Vec::new()), eof: AtomicBool::new(false), last_active: Mutex::new(Instant::now()), diff --git a/tunnel-node/src/udpgw.rs b/tunnel-node/src/udpgw.rs new file mode 100644 index 0000000..3c6e180 --- /dev/null +++ b/tunnel-node/src/udpgw.rs @@ -0,0 +1,512 @@ +//! Native implementation of the tun2proxy udpgw wire protocol. +//! +//! Wire format (all fields big-endian): +//! ```text +//! +-----+-------+---------+------+----------+----------+----------+ +//! | LEN | FLAGS | CONN_ID | ATYP | DST.ADDR | DST.PORT | DATA | +//! +-----+-------+---------+------+----------+----------+----------+ +//! | 2 | 1 | 2 | 1 | Variable | 2 | Variable | +//! +-----+-------+---------+------+----------+----------+----------+ +//! ``` +//! +//! Flags: KEEPALIVE=0x01, DATA=0x02, ERR=0x20 +//! ATYP: 0x01=IPv4(4B), 0x03=Domain(1B len + name), 0x04=IPv6(16B) + +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; +use std::sync::Arc; + +use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream}; +use tokio::net::UdpSocket; + +/// Magic address that the client connects to via the tunnel protocol. +/// `198.18.0.0/15` is reserved for benchmarking (RFC 2544) and will +/// never be a real destination. +pub const UDPGW_MAGIC_IP: [u8; 4] = [198, 18, 0, 1]; +pub const UDPGW_MAGIC_PORT: u16 = 7300; + +const FLAG_KEEPALIVE: u8 = 0x01; +const FLAG_DATA: u8 = 0x02; +const FLAG_ERR: u8 = 0x20; + +const ATYP_IPV4: u8 = 0x01; +const ATYP_DOMAIN: u8 = 0x03; +const ATYP_IPV6: u8 = 0x04; + +/// Maximum UDP payload we'll handle. +const UDP_MTU: usize = 10240; + +// ------------------------------------------------------------------------- +// Frame types +// ------------------------------------------------------------------------- + +#[derive(Debug, Clone)] +enum DstAddr { + V4(Ipv4Addr, u16), + V6(Ipv6Addr, u16), + Domain(String, u16), +} + +impl DstAddr { + fn to_socket_addr(&self) -> std::io::Result { + match self { + DstAddr::V4(ip, port) => Ok(SocketAddr::V4(SocketAddrV4::new(*ip, *port))), + DstAddr::V6(ip, port) => Ok(SocketAddr::V6(SocketAddrV6::new(*ip, *port, 0, 0))), + DstAddr::Domain(name, port) => { + use std::net::ToSocketAddrs; + (name.as_str(), *port) + .to_socket_addrs()? + .next() + .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::AddrNotAvailable, "DNS resolution failed")) + } + } + } + + /// Serialise into SOCKS5 address format: ATYP + addr + port. + fn write_to(&self, buf: &mut Vec) { + match self { + DstAddr::V4(ip, port) => { + buf.push(ATYP_IPV4); + buf.extend_from_slice(&ip.octets()); + buf.extend_from_slice(&port.to_be_bytes()); + } + DstAddr::V6(ip, port) => { + buf.push(ATYP_IPV6); + buf.extend_from_slice(&ip.octets()); + buf.extend_from_slice(&port.to_be_bytes()); + } + DstAddr::Domain(name, port) => { + buf.push(ATYP_DOMAIN); + buf.push(name.len() as u8); + buf.extend_from_slice(name.as_bytes()); + buf.extend_from_slice(&port.to_be_bytes()); + } + } + } + + fn serialised_len(&self) -> usize { + match self { + DstAddr::V4(..) => 1 + 4 + 2, // ATYP + IPv4 + port + DstAddr::V6(..) => 1 + 16 + 2, // ATYP + IPv6 + port + DstAddr::Domain(n, _) => 1 + 1 + n.len() + 2, // ATYP + len + name + port + } + } +} + +#[derive(Debug)] +struct Frame { + flags: u8, + conn_id: u16, + addr: Option, + payload: Vec, +} + +// ------------------------------------------------------------------------- +// Parse / serialise +// ------------------------------------------------------------------------- + +/// Try to parse one frame from `buf`. Returns `(frame, bytes_consumed)` or +/// `None` if the buffer doesn't contain a complete frame yet. +fn try_parse_frame(buf: &[u8]) -> Result, std::io::Error> { + if buf.len() < 2 { + return Ok(None); + } + let body_len = u16::from_be_bytes([buf[0], buf[1]]) as usize; + let total = 2 + body_len; + if buf.len() < total { + return Ok(None); + } + + let body = &buf[2..total]; + if body.len() < 3 { + return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "frame too short")); + } + let flags = body[0]; + let conn_id = u16::from_be_bytes([body[1], body[2]]); + let rest = &body[3..]; + + let (addr, payload_start) = if flags & FLAG_DATA != 0 { + // Parse SOCKS5-style address. + if rest.is_empty() { + return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "missing ATYP")); + } + let atyp = rest[0]; + match atyp { + ATYP_IPV4 => { + if rest.len() < 1 + 4 + 2 { + return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "short IPv4 addr")); + } + let ip = Ipv4Addr::new(rest[1], rest[2], rest[3], rest[4]); + let port = u16::from_be_bytes([rest[5], rest[6]]); + (Some(DstAddr::V4(ip, port)), 7) + } + ATYP_IPV6 => { + if rest.len() < 1 + 16 + 2 { + return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "short IPv6 addr")); + } + let mut octets = [0u8; 16]; + octets.copy_from_slice(&rest[1..17]); + let ip = Ipv6Addr::from(octets); + let port = u16::from_be_bytes([rest[17], rest[18]]); + (Some(DstAddr::V6(ip, port)), 19) + } + ATYP_DOMAIN => { + if rest.len() < 2 { + return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "short domain addr")); + } + let dlen = rest[1] as usize; + if rest.len() < 2 + dlen + 2 { + return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "short domain addr")); + } + let name = String::from_utf8_lossy(&rest[2..2 + dlen]).into_owned(); + let port = u16::from_be_bytes([rest[2 + dlen], rest[3 + dlen]]); + (Some(DstAddr::Domain(name, port)), 2 + dlen + 2) + } + _ => { + return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, format!("unknown ATYP 0x{:02x}", atyp))); + } + } + } else { + (None, 0) + }; + + let payload = rest[payload_start..].to_vec(); + + Ok(Some((Frame { flags, conn_id, addr, payload }, total))) +} + +fn serialise_frame(frame: &Frame) -> Vec { + // Body = flags(1) + conn_id(2) + [addr] + payload + let addr_len = frame.addr.as_ref().map_or(0, |a| a.serialised_len()); + let body_len = 1 + 2 + addr_len + frame.payload.len(); + + let mut buf = Vec::with_capacity(2 + body_len); + buf.extend_from_slice(&(body_len as u16).to_be_bytes()); + buf.push(frame.flags); + buf.extend_from_slice(&frame.conn_id.to_be_bytes()); + if let Some(ref addr) = frame.addr { + addr.write_to(&mut buf); + } + buf.extend_from_slice(&frame.payload); + buf +} + +// ------------------------------------------------------------------------- +// Public API +// ------------------------------------------------------------------------- + +/// Returns `true` if the connect destination is the magic udpgw address. +pub fn is_udpgw_dest(host: &str, port: u16) -> bool { + port == UDPGW_MAGIC_PORT && host == format!("{}.{}.{}.{}", UDPGW_MAGIC_IP[0], UDPGW_MAGIC_IP[1], UDPGW_MAGIC_IP[2], UDPGW_MAGIC_IP[3]) +} + +/// Per-conn_id persistent UDP socket with a background reader that +/// continuously receives datagrams and queues response frames. +struct ConnSocket { + sock: Arc, + _reader: tokio::task::AbortHandle, +} + +/// Run the udpgw server over a duplex stream. Reads udpgw frames from the +/// client half, sends real UDP datagrams, and writes response frames back. +/// Maintains persistent sockets per conn_id so Telegram VoIP (which expects +/// a stable source port) works correctly. +pub async fn udpgw_server_task(stream: DuplexStream) { + let (tx, mut rx) = tokio::sync::mpsc::channel::>(256); + + // Writer task: drains response channel → duplex stream. + let mut read_half = { + let (read_half, write_half) = tokio::io::split(stream); + tokio::spawn(async move { + let mut w = write_half; + while let Some(data) = rx.recv().await { + if w.write_all(&data).await.is_err() { + break; + } + let _ = w.flush().await; + } + }); + read_half + }; + + // Persistent sockets keyed by (conn_id, dest_addr). + let mut sockets: std::collections::HashMap<(u16, SocketAddr), ConnSocket> = std::collections::HashMap::new(); + + let mut buf = Vec::with_capacity(65536); + let mut tmp = [0u8; 65536]; + + loop { + let n = match read_half.read(&mut tmp).await { + Ok(0) | Err(_) => break, + Ok(n) => n, + }; + buf.extend_from_slice(&tmp[..n]); + + loop { + match try_parse_frame(&buf) { + Ok(Some((frame, consumed))) => { + buf.drain(..consumed); + handle_frame(&frame, &tx, &mut sockets).await; + } + Ok(None) => break, + Err(e) => { + tracing::warn!("udpgw frame parse error: {}", e); + if buf.len() >= 2 { + let skip = 2 + u16::from_be_bytes([buf[0], buf[1]]) as usize; + buf.drain(..skip.min(buf.len())); + } else { + buf.clear(); + } + break; + } + } + } + } + + // AbortHandle::drop aborts each reader task automatically. + drop(sockets); + tracing::debug!("udpgw session ended"); +} + +/// Get or create a persistent UDP socket for this (conn_id, dest_addr) pair. +/// A background reader task continuously receives datagrams and queues +/// response frames — no per-packet timeout needed. +async fn get_or_create_socket( + conn_id: u16, + dst: &SocketAddr, + addr: &DstAddr, + tx: &tokio::sync::mpsc::Sender>, + sockets: &mut std::collections::HashMap<(u16, SocketAddr), ConnSocket>, +) -> Option> { + let key = (conn_id, *dst); + if let Some(cs) = sockets.get(&key) { + return Some(cs.sock.clone()); + } + + let bind_addr: SocketAddr = if dst.is_ipv6() { + "[::]:0".parse().unwrap() + } else { + "0.0.0.0:0".parse().unwrap() + }; + let sock = match UdpSocket::bind(bind_addr).await { + Ok(s) => Arc::new(s), + Err(e) => { + tracing::debug!("udpgw bind failed: {}", e); + return None; + } + }; + if let Err(e) = sock.connect(dst).await { + tracing::debug!("udpgw connect {} failed: {}", dst, e); + return None; + } + + // Spawn continuous reader for this socket. + let sock_clone = sock.clone(); + let tx_clone = tx.clone(); + let addr_clone = addr.clone(); + let reader = tokio::spawn(async move { + let mut recv_buf = vec![0u8; UDP_MTU]; + loop { + match sock_clone.recv(&mut recv_buf).await { + Ok(n) => { + let resp = serialise_frame(&Frame { + flags: FLAG_DATA, + conn_id, + addr: Some(addr_clone.clone()), + payload: recv_buf[..n].to_vec(), + }); + if tx_clone.send(resp).await.is_err() { + break; + } + } + Err(_) => break, + } + } + }); + + sockets.insert(key, ConnSocket { sock: sock.clone(), _reader: reader.abort_handle() }); + Some(sock) +} + +async fn handle_frame( + frame: &Frame, + tx: &tokio::sync::mpsc::Sender>, + sockets: &mut std::collections::HashMap<(u16, SocketAddr), ConnSocket>, +) { + if frame.flags & FLAG_KEEPALIVE != 0 { + let resp = serialise_frame(&Frame { + flags: FLAG_KEEPALIVE, + conn_id: frame.conn_id, + addr: None, + payload: vec![], + }); + let _ = tx.send(resp).await; + return; + } + + if frame.flags & FLAG_DATA == 0 { + return; + } + + let Some(ref dst) = frame.addr else { + let _ = tx.send(serialise_err(frame.conn_id)).await; + return; + }; + + // Block QUIC (UDP 443) and DNS (UDP 53) from udpgw: + // - QUIC: forces browsers to fall back to TCP/HTTP2 which is much + // faster through the batch tunnel pipeline. + // - DNS: let tun2proxy's virtual DNS / SOCKS5 UDP associate handle + // it instead — more reliable on the per-session path. + // VoIP (Telegram, Meet) still flows through udpgw normally. + let dst_port = match dst { + DstAddr::V4(_, p) | DstAddr::V6(_, p) | DstAddr::Domain(_, p) => *p, + }; + if dst_port == 443 || dst_port == 53 { + let _ = tx.send(serialise_err(frame.conn_id)).await; + return; + } + + let dst_addr = match dst.to_socket_addr() { + Ok(a) => a, + Err(e) => { + tracing::debug!("udpgw resolve failed: {}", e); + let _ = tx.send(serialise_err(frame.conn_id)).await; + return; + } + }; + + let Some(sock) = get_or_create_socket(frame.conn_id, &dst_addr, dst, tx, sockets).await else { + let _ = tx.send(serialise_err(frame.conn_id)).await; + return; + }; + + // Send the datagram. Response comes asynchronously via the reader task. + if let Err(e) = sock.send(&frame.payload).await { + tracing::debug!("udpgw send to {} failed: {}", dst_addr, e); + let _ = tx.send(serialise_err(frame.conn_id)).await; + } +} + +fn serialise_err(conn_id: u16) -> Vec { + serialise_frame(&Frame { + flags: FLAG_ERR, + conn_id, + addr: None, + payload: vec![], + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn keepalive_round_trip() { + let frame = Frame { flags: FLAG_KEEPALIVE, conn_id: 42, addr: None, payload: vec![] }; + let bytes = serialise_frame(&frame); + let (parsed, consumed) = try_parse_frame(&bytes).unwrap().unwrap(); + assert_eq!(consumed, bytes.len()); + assert_eq!(parsed.flags, FLAG_KEEPALIVE); + assert_eq!(parsed.conn_id, 42); + assert!(parsed.addr.is_none()); + assert!(parsed.payload.is_empty()); + } + + #[test] + fn data_ipv4_round_trip() { + let frame = Frame { + flags: FLAG_DATA, + conn_id: 7, + addr: Some(DstAddr::V4(Ipv4Addr::new(8, 8, 8, 8), 53)), + payload: vec![1, 2, 3, 4], + }; + let bytes = serialise_frame(&frame); + let (parsed, consumed) = try_parse_frame(&bytes).unwrap().unwrap(); + assert_eq!(consumed, bytes.len()); + assert_eq!(parsed.flags, FLAG_DATA); + assert_eq!(parsed.conn_id, 7); + assert_eq!(parsed.payload, vec![1, 2, 3, 4]); + match parsed.addr.unwrap() { + DstAddr::V4(ip, port) => { + assert_eq!(ip, Ipv4Addr::new(8, 8, 8, 8)); + assert_eq!(port, 53); + } + _ => panic!("expected IPv4"), + } + } + + #[test] + fn data_ipv6_round_trip() { + let frame = Frame { + flags: FLAG_DATA, + conn_id: 100, + addr: Some(DstAddr::V6(Ipv6Addr::LOCALHOST, 443)), + payload: b"hello".to_vec(), + }; + let bytes = serialise_frame(&frame); + let (parsed, _) = try_parse_frame(&bytes).unwrap().unwrap(); + assert_eq!(parsed.conn_id, 100); + match parsed.addr.unwrap() { + DstAddr::V6(ip, port) => { + assert_eq!(ip, Ipv6Addr::LOCALHOST); + assert_eq!(port, 443); + } + _ => panic!("expected IPv6"), + } + } + + #[test] + fn data_domain_round_trip() { + let frame = Frame { + flags: FLAG_DATA, + conn_id: 5, + addr: Some(DstAddr::Domain("example.com".into(), 80)), + payload: b"GET /".to_vec(), + }; + let bytes = serialise_frame(&frame); + let (parsed, _) = try_parse_frame(&bytes).unwrap().unwrap(); + match parsed.addr.unwrap() { + DstAddr::Domain(name, port) => { + assert_eq!(name, "example.com"); + assert_eq!(port, 80); + } + _ => panic!("expected Domain"), + } + } + + #[test] + fn err_frame_round_trip() { + let bytes = serialise_err(99); + let (parsed, _) = try_parse_frame(&bytes).unwrap().unwrap(); + assert_eq!(parsed.flags, FLAG_ERR); + assert_eq!(parsed.conn_id, 99); + } + + #[test] + fn partial_frame_returns_none() { + let frame = Frame { flags: FLAG_KEEPALIVE, conn_id: 1, addr: None, payload: vec![] }; + let bytes = serialise_frame(&frame); + // Give it only half the bytes. + assert!(try_parse_frame(&bytes[..bytes.len() / 2]).unwrap().is_none()); + } + + #[test] + fn two_frames_in_buffer() { + let f1 = serialise_frame(&Frame { flags: FLAG_KEEPALIVE, conn_id: 1, addr: None, payload: vec![] }); + let f2 = serialise_frame(&Frame { flags: FLAG_KEEPALIVE, conn_id: 2, addr: None, payload: vec![] }); + let mut buf = f1.clone(); + buf.extend_from_slice(&f2); + + let (p1, c1) = try_parse_frame(&buf).unwrap().unwrap(); + assert_eq!(p1.conn_id, 1); + let (p2, _) = try_parse_frame(&buf[c1..]).unwrap().unwrap(); + assert_eq!(p2.conn_id, 2); + } + + #[test] + fn is_udpgw_dest_works() { + assert!(is_udpgw_dest("198.18.0.1", 7300)); + assert!(!is_udpgw_dest("198.18.0.1", 80)); + assert!(!is_udpgw_dest("8.8.8.8", 7300)); + } +}