diff --git a/vpn_slice/__main__.py b/vpn_slice/__main__.py index 3768ef7..cbbbbc8 100755 --- a/vpn_slice/__main__.py +++ b/vpn_slice/__main__.py @@ -25,7 +25,7 @@ def get_default_providers(): DNSPythonProvider = None if platform.startswith('linux'): - from .linux import CheckTunDevProvider, Iproute2Provider, IptablesProvider, ProcfsProvider + from .linux import CheckTunDevProvider, Iproute2Provider, IptablesProvider, ProcfsProvider, LinuxSplitDNSProvider from .posix import DigProvider, PosixHostsFileProvider return dict( process = ProcfsProvider, @@ -34,6 +34,7 @@ def get_default_providers(): dns = DNSPythonProvider or DigProvider, hosts = PosixHostsFileProvider, prep = CheckTunDevProvider, + domain_vpn_dns = LinuxSplitDNSProvider, ) elif platform.startswith('darwin'): from platform import release @@ -150,7 +151,7 @@ def do_disconnect(env, args): if args.vpn_domains is not None: try: - providers.domain_vpn_dns.deconfigure_domain_vpn_dns(args.vpn_domains, env.dns) + providers.domain_vpn_dns.deconfigure_domain_vpn_dns(args.vpn_domains, env.dns, env.tundev) except OSError: print("WARNING: failed to deconfigure domains vpn dns", file=stderr) @@ -244,7 +245,11 @@ def do_connect(env, args): if 'domain_vpn_dns' not in providers: print("WARNING: no split dns provider available; can't split dns", file=stderr) else: - providers.domain_vpn_dns.configure_domain_vpn_dns(args.vpn_domains, env.dns) + try: + providers.domain_vpn_dns.configure_domain_vpn_dns(args.vpn_domains, env.dns, env.tundev) + print(f"Configured split DNS for domains {' '.join(args.vpn_domains)} to use VPN DNS servers.", file=stderr) + except OSError as e: + print(f"WARNING: Failed to configure split DNS: {e}", file=stderr) def do_post_connect(env, args): diff --git a/vpn_slice/linux.py b/vpn_slice/linux.py index fa3a325..5fa9be1 100644 --- a/vpn_slice/linux.py +++ b/vpn_slice/linux.py @@ -1,9 +1,10 @@ import os import stat import subprocess +import sys from .posix import PosixProcessProvider -from .provider import FirewallProvider, RouteProvider, TunnelPrepProvider +from .provider import FirewallProvider, RouteProvider, SplitDNSProvider, TunnelPrepProvider from .util import get_executable @@ -109,3 +110,50 @@ def create_tunnel(self): def prepare_tunnel(self): if not os.access('/dev/net/tun', os.R_OK | os.W_OK): raise OSError("can't read and write /dev/net/tun") + +class LinuxSplitDNSProvider(SplitDNSProvider): + def configure_domain_vpn_dns(self, domains, nameservers, dev): + try: + status = subprocess.check_output( + ['systemctl', 'is-active', 'systemd-resolved'], + universal_newlines=True, + stderr=subprocess.STDOUT + ).strip() + except subprocess.CalledProcessError as e: + raise OSError("systemd-resolved is not active; cannot configure DNS") from e + + if status != 'active': + raise OSError("systemd-resolved is not active; cannot configure DNS") + + try: + with open('/etc/resolv.conf', 'r') as f: + if 'nameserver 127.0.0.53' not in f.read(): + print("/etc/resolv.conf does not contain 127.0.0.53, are you sure you are using systemd-resolved?") + except FileNotFoundError: + raise OSError("/etc/resolv.conf not found") + + resolvectl = get_executable('/sbin/resolvectl') + try: + # Configure nameservers + subprocess.check_call([resolvectl, 'dns', dev] + [str(ns) for ns in nameservers]) + except subprocess.CalledProcessError as e: + raise OSError(f"Failed to configure DNS: {e}") + try: + # Configure search domains + subprocess.check_call([resolvectl, 'domain', dev] + [f"~{domain}" for domain in domains]) + except subprocess.CalledProcessError as e: + raise OSError(f"Failed to configure domain: {e}") + try: + # Remove default route + subprocess.check_call([resolvectl, 'default-route', dev, "false"]) + except subprocess.CalledProcessError as e: + raise OSError(f"Failed to configure default route: {e}") + + def deconfigure_domain_vpn_dns(self, domains, nameservers, dev): + resolvectl = get_executable('/sbin/resolvectl') + try: + subprocess.check_call([resolvectl, 'revert', dev]) + except subprocess.CalledProcessError as e: + raise OSError(f"Failed to revert DNS configuration: {e}") + + \ No newline at end of file diff --git a/vpn_slice/mac.py b/vpn_slice/mac.py index c9176b0..94db024 100644 --- a/vpn_slice/mac.py +++ b/vpn_slice/mac.py @@ -115,7 +115,7 @@ def add_address(self, device, address): class MacSplitDNSProvider(SplitDNSProvider): - def configure_domain_vpn_dns(self, domains, nameservers): + def configure_domain_vpn_dns(self, domains, nameservers, dev): if not os.path.exists('/etc/resolver'): os.makedirs('/etc/resolver') for domain in domains: @@ -124,7 +124,7 @@ def configure_domain_vpn_dns(self, domains, nameservers): for nameserver in nameservers: resolver_file.write(f"nameserver {nameserver}\n") - def deconfigure_domain_vpn_dns(self, domains, nameservers): + def deconfigure_domain_vpn_dns(self, domains, nameservers, dev): for domain in domains: resolver_file_name = f"/etc/resolver/{domain}" if os.path.exists(resolver_file_name): diff --git a/vpn_slice/provider.py b/vpn_slice/provider.py index c9c43f7..4b3236f 100644 --- a/vpn_slice/provider.py +++ b/vpn_slice/provider.py @@ -160,14 +160,14 @@ def prepare_tunnel(self): """ class SplitDNSProvider: - def configure_domain_vpn_dns(self, domains, nameservers): + def configure_domain_vpn_dns(self, domains, nameservers, dev): """Configure domain vpn dns. Base class behavior is to do nothing. """ - def deconfigure_domain_vpn_dns(self, domains, nameservers): + def deconfigure_domain_vpn_dns(self, domains, nameservers, dev): """Remove domain vpn dns. Base class behavior is to do nothing.