Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 91 additions & 64 deletions src/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ use windows_sys::{
pub struct Adapter {
adapter: UnsafeHandle<wintun_raw::WINTUN_ADAPTER_HANDLE>,
pub(crate) wintun: Wintun,
guid: u128,
index: u32,
requested_guid: Option<u128>,
guid: OnceLock<u128>,
index: OnceLock<u32>,
luid: NET_LUID_LH,
}

Expand Down Expand Up @@ -60,7 +61,35 @@ impl Adapter {
}

pub fn get_guid(&self) -> u128 {
self.guid
if let Some(guid) = self.guid.get() {
return *guid;
}

let real_guid = match resolve_with_retry(|| crate::ffi::luid_to_guid(&self.luid)) {
Ok(g) => util::win_guid_to_u128(&g),
Err(_) => return self.requested_guid.unwrap_or(0),
};

if let Some(req) = self.requested_guid
&& req != real_guid
&& let (Ok(real_s), Ok(req_s), Ok((major, minor, build))) = (
util::guid_to_win_style_string(&GUID::from_u128(real_guid)),
util::guid_to_win_style_string(&GUID::from_u128(req)),
util::get_windows_version(),
)
{
log::warn!(
"Windows {major}.{minor}.{build}: an internal bug causes the GUID mismatch: Expected {req_s}, got {real_s}"
);
}

match self.guid.set(real_guid) {
Ok(()) => real_guid,
Err(_) => *self
.guid
.get()
.expect("guid should be initialized by this thread or another"),
}
}

/// Creates a new wintun adapter inside the name `name` with tunnel type `tunnel_type`
Expand All @@ -70,51 +99,32 @@ impl Adapter {
let name_utf16: Vec<_> = name.encode_utf16().chain(std::iter::once(0)).collect();
let tunnel_type_utf16: Vec<u16> = tunnel_type.encode_utf16().chain(std::iter::once(0)).collect();

let mut guid = match guid {
Some(guid) => guid,
None => {
let mut guid: GUID = unsafe { std::mem::zeroed() };
unsafe { windows_sys::Win32::System::Rpc::UuidCreate(&mut guid as *mut GUID) };
util::win_guid_to_u128(&guid)
}
};
let requested_guid = guid.unwrap_or_else(|| {
let mut guid: GUID = unsafe { std::mem::zeroed() };
unsafe { windows_sys::Win32::System::Rpc::UuidCreate(&mut guid as *mut GUID) };
util::win_guid_to_u128(&guid)
});

crate::log::set_default_logger_if_unset(wintun);

let guid_s: GUID = GUID::from_u128(guid);
let guid_s: GUID = GUID::from_u128(requested_guid);
let result = unsafe { wintun.WintunCreateAdapter(name_utf16.as_ptr(), tunnel_type_utf16.as_ptr(), &guid_s) };

if result.is_null() {
return crate::log::extract_wintun_log_error("WintunCreateAdapter failed")?;
}
let mut call = || -> Result<Arc<Adapter>, Error> {
let luid = crate::ffi::alias_to_luid(name)?;
let index = crate::ffi::luid_to_index(&luid)?;
let real_guid = util::win_guid_to_u128(&crate::ffi::luid_to_guid(&luid)?);
if guid != real_guid {
let real_guid_s = util::guid_to_win_style_string(&GUID::from_u128(real_guid))?;
let guid_s = util::guid_to_win_style_string(&GUID::from_u128(guid))?;
let (major, minor, build) = util::get_windows_version()?;
log::warn!(
"Windows {major}.{minor}.{build} internal bug cause the GUID mismatch: Expected {guid_s}, got {real_guid_s}"
);
guid = real_guid;
}
Ok(Arc::new(Adapter {
adapter: UnsafeHandle(result),
wintun: wintun.clone(),
guid,
index,
luid,
}))
};
match call() {
Ok(adapter) => Ok(adapter),
Err(e) => {
unsafe { wintun.WintunCloseAdapter(result) };
Err(e)
}
}

let mut luid: NET_LUID_LH = unsafe { std::mem::zeroed() };
unsafe { wintun.WintunGetAdapterLUID(result, &mut luid) };

Ok(Arc::new(Adapter {
adapter: UnsafeHandle(result),
wintun: wintun.clone(),
luid,
requested_guid: Some(requested_guid),
guid: OnceLock::new(),
index: OnceLock::new(),
}))
}

/// Attempts to open an existing wintun interface name `name`.
Expand All @@ -128,26 +138,18 @@ impl Adapter {
if result.is_null() {
return crate::log::extract_wintun_log_error("WintunOpenAdapter failed")?;
}
let call = || -> Result<Arc<Adapter>, Error> {
let luid = crate::ffi::alias_to_luid(name)?;
let index = crate::ffi::luid_to_index(&luid)?;
let guid = crate::ffi::luid_to_guid(&luid)?;
let guid = util::win_guid_to_u128(&guid);
Ok(Arc::new(Adapter {
adapter: UnsafeHandle(result),
wintun: wintun.clone(),
guid,
index,
luid,
}))
};
match call() {
Ok(adapter) => Ok(adapter),
Err(e) => {
unsafe { wintun.WintunCloseAdapter(result) };
Err(e)
}
}

let mut luid: NET_LUID_LH = unsafe { std::mem::zeroed() };
unsafe { wintun.WintunGetAdapterLUID(result, &mut luid) };

Ok(Arc::new(Adapter {
adapter: UnsafeHandle(result),
wintun: wintun.clone(),
luid,
requested_guid: None,
guid: OnceLock::new(),
index: OnceLock::new(),
}))
}

/// Delete an adapter, consuming it in the process
Expand Down Expand Up @@ -214,7 +216,11 @@ impl Adapter {
/// Returns the Win32 interface index of this adapter. Useful for specifying the interface
/// when executing `netsh interface ip` commands
pub fn get_adapter_index(&self) -> Result<u32, Error> {
Ok(self.index)
if let Some(idx) = self.index.get() {
return Ok(*idx);
}
let idx = resolve_with_retry(|| crate::ffi::luid_to_index(&self.luid))?;
Ok(*self.index.get_or_init(|| idx))
}

/// Sets the IP address for this adapter, using command `netsh`.
Expand Down Expand Up @@ -312,7 +318,7 @@ impl Adapter {

/// Returns the IP addresses of this adapter, including IPv4 and IPv6 addresses
pub fn get_addresses(&self) -> Result<Vec<IpAddr>, Error> {
let name = util::guid_to_win_style_string(&GUID::from_u128(self.guid))?;
let name = util::guid_to_win_style_string(&GUID::from_u128(self.get_guid()))?;

let mut adapter_addresses = vec![];

Expand Down Expand Up @@ -345,7 +351,7 @@ impl Adapter {

/// Returns the gateway addresses of this adapter, including IPv4 and IPv6 addresses
pub fn get_gateways(&self) -> Result<Vec<IpAddr>, Error> {
let name = util::guid_to_win_style_string(&GUID::from_u128(self.guid))?;
let name = util::guid_to_win_style_string(&GUID::from_u128(self.get_guid()))?;
let mut gateways = vec![];
util::get_adapters_addresses(|adapter| {
let name_iter = match unsafe { util::win_pstr_to_string(adapter.AdapterName) } {
Expand Down Expand Up @@ -375,7 +381,7 @@ impl Adapter {

/// Returns the subnet mask of the given address
pub fn get_netmask_of_address(&self, target_address: &IpAddr) -> Result<IpAddr, Error> {
let name = util::guid_to_win_style_string(&GUID::from_u128(self.guid))?;
let name = util::guid_to_win_style_string(&GUID::from_u128(self.get_guid()))?;
let mut subnet_mask = None;
util::get_adapters_addresses(|adapter| {
let name_iter = match unsafe { util::win_pstr_to_string(adapter.AdapterName) } {
Expand Down Expand Up @@ -599,3 +605,24 @@ pub(crate) fn delete_adapter_info_from_reg(dev_name: &str) -> std::io::Result<()

Ok(())
}

fn resolve_with_retry<T>(mut f: impl FnMut() -> std::io::Result<T>) -> std::io::Result<T> {
const ERROR_NOT_FOUND: i32 = 1168;
const NSI_RETRY_ATTEMPTS: u32 = 3;
const NSI_RETRY_DELAY_MS: u64 = 25;

for attempt in 1..=NSI_RETRY_ATTEMPTS {
match f() {
Ok(v) => return Ok(v),
Err(e) if e.raw_os_error() == Some(ERROR_NOT_FOUND) => {
if attempt == NSI_RETRY_ATTEMPTS {
return Err(e);
}
log::warn!("NSI race, retry {attempt}/{NSI_RETRY_ATTEMPTS}");
std::thread::sleep(std::time::Duration::from_millis(NSI_RETRY_DELAY_MS));
}
Err(e) => return Err(e),
}
}
unreachable!();
}
11 changes: 1 addition & 10 deletions src/ffi.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use windows_sys::Win32::NetworkManagement::IpHelper::{
ConvertInterfaceAliasToLuid, ConvertInterfaceLuidToAlias, ConvertInterfaceLuidToGuid, ConvertInterfaceLuidToIndex,
ConvertInterfaceLuidToAlias, ConvertInterfaceLuidToGuid, ConvertInterfaceLuidToIndex,
};
use windows_sys::Win32::NetworkManagement::Ndis::{IF_MAX_STRING_SIZE, NET_LUID_LH};
use windows_sys::core::GUID;
Expand All @@ -14,15 +14,6 @@ pub fn luid_to_alias(luid: &NET_LUID_LH) -> std::io::Result<String> {
Ok(crate::util::decode_utf16(&r))
}

pub fn alias_to_luid(alias: &str) -> std::io::Result<NET_LUID_LH> {
let alias = alias.encode_utf16().chain(std::iter::once(0)).collect::<Vec<_>>();
let mut luid = unsafe { std::mem::zeroed() };

match unsafe { ConvertInterfaceAliasToLuid(alias.as_ptr(), &mut luid) } {
0 => Ok(luid),
err => Err(std::io::Error::from_raw_os_error(err as _)),
}
}
pub fn luid_to_index(luid: &NET_LUID_LH) -> std::io::Result<u32> {
let mut index = 0;

Expand Down
Loading