Skip to content

Fallible Wasm Module for testing #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/wasm/src/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ use tokio::io::AsyncRead;

// A trait for a decoder, developers should implement this trait and pass it to _read_from_outbound
pub trait Decoder {
fn decode(&self, input: &[u8], output: &mut [u8]) -> Result<u32, anyhow::Error>;
fn decode(&mut self, input: &[u8], output: &mut [u8]) -> Result<u32, anyhow::Error>;
}

// A default decoder that does just copy + paste
pub struct DefaultDecoder;

impl Decoder for DefaultDecoder {
fn decode(&self, input: &[u8], output: &mut [u8]) -> Result<u32, anyhow::Error> {
fn decode(&mut self, input: &[u8], output: &mut [u8]) -> Result<u32, anyhow::Error> {
let len = input.len();
output[..len].copy_from_slice(&input[..len]);
Ok(len as u32)
Expand Down
2 changes: 0 additions & 2 deletions crates/wasm/src/dialer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ impl Dialer {

pub fn dial(&mut self) -> Result<i32, anyhow::Error> {
info!("[WASM] running in dial func...");

// FIXME: hardcoded the filename for now, make it a config later
let fd: i32 = self.tcp_connect()?;

if fd < 0 {
Expand Down
4 changes: 2 additions & 2 deletions crates/wasm/src/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ use tokio::io::AsyncWrite;

// A trait for a encoder, developers should implement this trait and pass it to _write_to_outbound
pub trait Encoder {
fn encode(&self, input: &[u8], output: &mut [u8]) -> Result<u32, anyhow::Error>;
fn encode(&mut self, input: &[u8], output: &mut [u8]) -> Result<u32, anyhow::Error>;
}

// A default encoder that does just copy + paste
pub struct DefaultEncoder;

impl Encoder for DefaultEncoder {
fn encode(&self, input: &[u8], output: &mut [u8]) -> Result<u32, anyhow::Error> {
fn encode(&mut self, input: &[u8], output: &mut [u8]) -> Result<u32, anyhow::Error> {
let len = input.len();
output[..len].copy_from_slice(&input[..len]);
Ok(len as u32)
Expand Down
5 changes: 5 additions & 0 deletions examples/water_bins/fallible/.cargo/config
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[build]
target = "wasm32-wasi"

[target.wasm32-wasi]
rustflags = [ "--cfg", "tokio_unstable"]
31 changes: 31 additions & 0 deletions examples/water_bins/fallible/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
[package]
name = "fallible"
version = "0.1.0"
authors.workspace = true
description.workspace = true
edition.workspace = true
publish = false

[lib]
name = "fallible"
path = "src/lib.rs"
crate-type = ["cdylib"]

[dependencies]
tokio = { version = "1.24.2", default-features = false, features = ["net", "rt", "macros", "io-util", "io-std", "time", "sync"] }
tokio-util = { version = "0.7.1", features = ["codec"] }

serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0.107"
bincode = "1.3"

anyhow = "1.0.7"
tracing = "0.1"
tracing-subscriber = "0.3.17"
toml = "0.5.9"
lazy_static = "1.4"
url = { version = "2.2.2", features = ["serde"] }
libc = "0.2.147"

# water wasm lib import
water-wasm = { path = "../../../crates/wasm/", version = "0.1.0" }
297 changes: 297 additions & 0 deletions examples/water_bins/fallible/src/async_socks5_listener.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,297 @@
use super::*;

use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream},
};

use std::net::{SocketAddr, ToSocketAddrs};

// ----------------------- Listener methods -----------------------
#[export_name = "v1_listen"]
fn listen() {
wrapper().unwrap();
}

fn _listener_creation() -> Result<i32, std::io::Error> {
let global_conn = match DIALER.lock() {
Ok(conf) => conf,
Err(e) => {
eprintln!("[WASM] > ERROR: {}", e);
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"failed to lock config",
));
}
};

// FIXME: hardcoded the filename for now, make it a config later
let stream = StreamConfigV1::init(
global_conn.config.local_address.clone(),
global_conn.config.local_port,
"LISTEN".to_string(),
);

let encoded: Vec<u8> = bincode::serialize(&stream).expect("Failed to serialize");

let address = encoded.as_ptr() as u32;
let size = encoded.len() as u32;

let fd = unsafe { create_listen(address, size) };

if fd < 0 {
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"failed to create listener",
));
}

info!(
"[WASM] ready to start listening at {}:{}",
global_conn.config.local_address, global_conn.config.local_port
);

Ok(fd)
}

#[tokio::main(flavor = "current_thread")]
async fn wrapper() -> std::io::Result<()> {
let fd = _listener_creation().unwrap();

// Set up pre-established listening socket.
let standard = unsafe { std::net::TcpListener::from_raw_fd(fd) };
// standard.set_nonblocking(true).unwrap();
let listener = TcpListener::from_std(standard)?;

info!("[WASM] Starting to listen...");

loop {
// Accept new sockets in a loop.
let socket = match listener.accept().await {
Ok(s) => s.0,
Err(e) => {
eprintln!("[WASM] > ERROR: {}", e);
continue;
}
};

// Spawn a background task for each new connection.
tokio::spawn(async move {
eprintln!("[WASM] > CONNECTED");
match handle_incoming(socket).await {
Ok(()) => eprintln!("[WASM] > DISCONNECTED"),
Err(e) => eprintln!("[WASM] > ERROR: {}", e),
}
});
}
}

// SS handle incoming connections
async fn handle_incoming(mut stream: TcpStream) -> std::io::Result<()> {
let mut buffer = [0; 512];

// Read the SOCKS5 greeting
let nbytes = stream
.read(&mut buffer)
.await
.expect("Failed to read from stream");

println!("Received {} bytes: {:?}", nbytes, buffer[..nbytes].to_vec());

if nbytes < 2 || buffer[0] != 0x05 {
eprintln!("Not a SOCKS5 request");
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"Not a SOCKS5 request",
));
}

let nmethods = buffer[1] as usize;
if nbytes < 2 + nmethods {
eprintln!("Incomplete SOCKS5 greeting");
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"Incomplete SOCKS5 greeting",
));
}

// For simplicity, always use "NO AUTHENTICATION REQUIRED"
stream
.write_all(&[0x05, 0x00])
.await
.expect("Failed to write to stream");

// Read the actual request
let nbytes = stream
.read(&mut buffer)
.await
.expect("Failed to read from stream");

println!("Received {} bytes: {:?}", nbytes, buffer[..nbytes].to_vec());

if nbytes < 7 || buffer[0] != 0x05 || buffer[1] != 0x01 {
println!("Invalid SOCKS5 request");
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"Invalid SOCKS5 request",
));
}

// Extract address and port
let addr: SocketAddr = match buffer[3] {
0x01 => {
// IPv4
if nbytes < 10 {
eprintln!("Incomplete request for IPv4 address");
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"Incomplete request for IPv4 address",
));
}
let addr = std::net::Ipv4Addr::new(buffer[4], buffer[5], buffer[6], buffer[7]);
let port = u16::from_be_bytes([buffer[8], buffer[9]]);
SocketAddr::from((addr, port))
}
0x03 => {
// Domain name
let domain_length = buffer[4] as usize;
if nbytes < domain_length + 5 {
eprintln!("Incomplete request for domain name");
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"Incomplete request for domain name",
));
}
let domain =
std::str::from_utf8(&buffer[5..5 + domain_length]).expect("Invalid domain string");

println!("Domain: {}", domain);

let port =
u16::from_be_bytes([buffer[5 + domain_length], buffer[5 + domain_length + 1]]);

println!("Port: {}", port);

let domain_with_port = format!("{}:443", domain); // Assuming HTTPS

// domain.to_socket_addrs().unwrap().next().unwrap()
match domain_with_port.to_socket_addrs() {
Ok(mut addrs) => match addrs.next() {
Some(addr) => addr,
None => {
eprintln!("Domain resolved, but no addresses found for {}", domain);
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!("Domain resolved, but no addresses found for {}", domain),
));
}
},
Err(e) => {
eprintln!("Failed to resolve domain {}: {}", domain, e);
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!("Failed to resolve domain {}: {}", domain, e),
));
}
}
}
_ => {
eprintln!("Address type not supported");
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"Address type not supported",
));
}
};

// NOTE: create a new Dialer to dial any target address as it wants to
// Add more features later -- connect to target thru rules (direct / server)

// Connect to target address
let mut tcp_dialer = Dialer::new();
tcp_dialer.config.remote_address = addr.ip().to_string();
tcp_dialer.config.remote_port = addr.port() as u32;

let _tcp_fd = tcp_dialer.dial().expect("Failed to dial");

let target_stream = match tcp_dialer.file_conn.outbound_conn.file.unwrap() {
ConnStream::TcpStream(s) => s,
_ => {
eprintln!("Failed to get outbound tcp stream");
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"Failed to get outbound tcp stream",
));
}
};

target_stream
.set_nonblocking(true)
.expect("Failed to set non-blocking");

let target_stream =
TcpStream::from_std(target_stream).expect("Failed to convert to tokio stream");

// Construct the response based on the target address
let response = match addr {
SocketAddr::V4(a) => {
let mut r = vec![0x05, 0x00, 0x00, 0x01];
r.extend_from_slice(&a.ip().octets());
r.extend_from_slice(&a.port().to_be_bytes());
r
}
SocketAddr::V6(a) => {
let mut r = vec![0x05, 0x00, 0x00, 0x04];
r.extend_from_slice(&a.ip().octets());
r.extend_from_slice(&a.port().to_be_bytes());
r
}
};

stream
.write_all(&response)
.await
.expect("Failed to write to stream");

let (mut client_read, mut client_write) = tokio::io::split(stream);
let (mut target_read, mut target_write) = tokio::io::split(target_stream);

let client_to_target = async move {
let mut buffer = vec![0; 4096];
loop {
match client_read.read(&mut buffer).await {
Ok(0) => {
break;
}
Ok(n) => {
if (target_write.write_all(&buffer[0..n]).await).is_err() {
break;
}
}
Err(_) => break,
}
}
};

let target_to_client = async move {
let mut buffer = vec![0; 4096];
loop {
match target_read.read(&mut buffer).await {
Ok(0) => {
break;
}
Ok(n) => {
if (client_write.write_all(&buffer[0..n]).await).is_err() {
break;
}
}
Err(_) => break,
}
}
};

// Run both handlers concurrently
tokio::join!(client_to_target, target_to_client);

Ok(())
}
Loading