diff --git a/cmd/discover.go b/cmd/discover.go index 1c40111..77c4673 100644 --- a/cmd/discover.go +++ b/cmd/discover.go @@ -10,7 +10,6 @@ import ( "os/signal" "syscall" - "github.com/libp2p/go-libp2p/core/host" "github.com/spf13/cobra" "github.com/substantialcattle5/sietch/internal/config" @@ -40,6 +39,7 @@ Example: port, _ := cmd.Flags().GetInt("port") verbose, _ := cmd.Flags().GetBool("verbose") vaultPath, _ := cmd.Flags().GetString("vault-path") + allAddresses, _ := cmd.Flags().GetBool("all-addresses") // If no vault path specified, use current directory if vaultPath == "" { @@ -72,7 +72,7 @@ Example: fmt.Printf("šŸ” Starting peer discovery with node ID: %s\n", host.ID().String()) if verbose { - displayHostAddresses(host) + discover.DisplayHostAddresses(host, allAddresses) } // Create a vault manager @@ -101,17 +101,11 @@ Example: defer func() { _ = discovery.Stop() }() // Run the discovery loop - return discover.RunDiscoveryLoop(ctx, host, syncService, peerChan, timeout, continuous) + return discover.RunDiscoveryLoop(ctx, host, syncService, peerChan, timeout, continuous, allAddresses) }, } -// displayHostAddresses prints the addresses the host is listening on -func displayHostAddresses(h host.Host) { - fmt.Println("Listening on:") - for _, addr := range h.Addrs() { - fmt.Printf(" %s/p2p/%s\n", addr, h.ID().String()) - } -} + func init() { rootCmd.AddCommand(discoverCmd) @@ -122,4 +116,5 @@ func init() { discoverCmd.Flags().IntP("port", "p", 0, "Port to use for libp2p (0 for random port)") discoverCmd.Flags().BoolP("verbose", "v", false, "Enable verbose output") discoverCmd.Flags().StringP("vault-path", "V", "", "Path to the vault directory (defaults to current directory)") + discoverCmd.Flags().Bool("all-addresses", false, "Show all network addresses including Docker, VPN, and virtual interfaces") } diff --git a/internal/discover/address_display.go b/internal/discover/address_display.go new file mode 100644 index 0000000..c991255 --- /dev/null +++ b/internal/discover/address_display.go @@ -0,0 +1,86 @@ +package discover + +import ( + "fmt" + + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/peer" +) + +// DisplayHostAddresses prints the addresses the host is listening on with filtering +func DisplayHostAddresses(h host.Host, showAll bool) { + filter := NewAddressFilter(showAll) + filtered := filter.FilterAddresses(h.Addrs()) + + fmt.Println("Listening on:") + + if len(filtered) == 0 { + fmt.Println(" No suitable addresses found") + return + } + + for _, addr := range filtered { + if showAll { + // Show full multiaddr format with peer ID + fmt.Printf(" %s/p2p/%s", addr.Original, h.ID().String()) + if addr.Label != "" { + fmt.Printf(" %s", addr.Label) + } + fmt.Println() + } else { + // Show simplified format + fmt.Printf(" %s", addr.DisplayAddr) + if addr.Label != "" { + fmt.Printf(" %s", addr.Label) + } + fmt.Println() + } + } + + // Show summary of filtered addresses if any were hidden + if !showAll { + hiddenCount := filter.CountFilteredAddresses(h.Addrs()) + if hiddenCount > 0 { + fmt.Printf(" [+%d more, use --all-addresses to show]\n", hiddenCount) + } + } +} + +// DisplayPeerAddresses prints the addresses for a discovered peer with filtering +func DisplayPeerAddresses(peerInfo peer.AddrInfo, showAll bool) { + filter := NewAddressFilter(showAll) + filtered := filter.FilterAddresses(peerInfo.Addrs) + + fmt.Println(" Addresses:") + + if len(filtered) == 0 { + fmt.Println(" No suitable addresses found") + return + } + + for _, addr := range filtered { + if showAll { + // Show full multiaddr format + fmt.Printf(" - %s", addr.Original) + if addr.Label != "" { + fmt.Printf(" %s", addr.Label) + } + fmt.Println() + } else { + // Show simplified format + fmt.Printf(" - %s", addr.DisplayAddr) + if addr.Label != "" { + fmt.Printf(" %s", addr.Label) + } + fmt.Println() + } + } + + // Show summary of filtered addresses if any were hidden + if !showAll { + hiddenCount := filter.CountFilteredAddresses(peerInfo.Addrs) + if hiddenCount > 0 { + fmt.Printf(" [+%d more, use --all-addresses to show]\n", hiddenCount) + } + } +} \ No newline at end of file diff --git a/internal/discover/address_filter.go b/internal/discover/address_filter.go new file mode 100644 index 0000000..397c5d8 --- /dev/null +++ b/internal/discover/address_filter.go @@ -0,0 +1,292 @@ +package discover + +import ( + "net" + "regexp" + "sort" + "strings" + + "github.com/multiformats/go-multiaddr" +) + +// AddressPriority defines the priority levels for different address types +type AddressPriority int + +const ( + PriorityVirtual AddressPriority = 0 // Filtered out (Docker, VPN, etc.) + PriorityOther AddressPriority = 1 // Shown only with --all-addresses + PriorityIPv6 AddressPriority = 2 // Limited to 1 address + PriorityLocalhost AddressPriority = 3 // Always shown + PriorityLAN AddressPriority = 4 // Primary LAN addresses +) + +// FilteredAddress represents an address with its priority and display information +type FilteredAddress struct { + Original multiaddr.Multiaddr + Priority AddressPriority + DisplayAddr string + Label string +} + +// AddressFilter handles filtering and prioritizing network addresses +type AddressFilter struct { + showAll bool +} + +// NewAddressFilter creates a new address filter +func NewAddressFilter(showAll bool) *AddressFilter { + return &AddressFilter{showAll: showAll} +} + +// FilterAddresses filters and prioritizes a list of multiaddresses +func (af *AddressFilter) FilterAddresses(addrs []multiaddr.Multiaddr) []FilteredAddress { + var filtered []FilteredAddress + + for _, addr := range addrs { + if fa := af.categorizeAddress(addr); fa != nil { + filtered = append(filtered, *fa) + } + } + + // Sort by priority (highest first) + sort.Slice(filtered, func(i, j int) bool { + if filtered[i].Priority != filtered[j].Priority { + return filtered[i].Priority > filtered[j].Priority + } + // Secondary sort by address string for consistency + return filtered[i].DisplayAddr < filtered[j].DisplayAddr + }) + + if af.showAll { + return filtered + } + + return af.applyFiltering(filtered) +} + +// categorizeAddress determines the priority and display format for an address +func (af *AddressFilter) categorizeAddress(addr multiaddr.Multiaddr) *FilteredAddress { + // Extract IP and port from multiaddr + ip, port := af.extractIPAndPort(addr) + if ip == "" { + return nil + } + + fa := &FilteredAddress{ + Original: addr, + } + + // Check for localhost + if af.isLocalhost(ip) { + fa.Priority = PriorityLocalhost + fa.DisplayAddr = "localhost:" + port + fa.Label = "" + return fa + } + + // Check for virtual interfaces (Docker, VPN, etc.) + if af.isVirtualInterface(ip) { + fa.Priority = PriorityVirtual + fa.DisplayAddr = ip + ":" + port + fa.Label = af.getVirtualInterfaceLabel(ip) + return fa + } + + // Check for private LAN addresses + if af.isPrivateLAN(ip) { + fa.Priority = PriorityLAN + fa.DisplayAddr = ip + ":" + port + fa.Label = "(primary)" + return fa + } + + // Check for IPv6 + if af.isIPv6(ip) { + fa.Priority = PriorityIPv6 + fa.DisplayAddr = "[" + ip + "]:" + port + fa.Label = "" + return fa + } + + // Everything else + fa.Priority = PriorityOther + fa.DisplayAddr = ip + ":" + port + fa.Label = "" + return fa +} + +// applyFiltering applies the filtering rules when showAll is false +func (af *AddressFilter) applyFiltering(addresses []FilteredAddress) []FilteredAddress { + var result []FilteredAddress + ipv6Count := 0 + + for _, addr := range addresses { + switch addr.Priority { + case PriorityLAN, PriorityLocalhost: + // Always include LAN and localhost + result = append(result, addr) + case PriorityIPv6: + // Include only one IPv6 address + if ipv6Count == 0 { + result = append(result, addr) + ipv6Count++ + } + case PriorityVirtual, PriorityOther: + // Skip virtual and other addresses in filtered mode + continue + } + } + + // If we have no addresses, include at least one non-virtual address + if len(result) == 0 { + for _, addr := range addresses { + if addr.Priority > PriorityVirtual { + result = append(result, addr) + break + } + } + } + + return result +} + +// extractIPAndPort extracts IP and port from a multiaddr +func (af *AddressFilter) extractIPAndPort(addr multiaddr.Multiaddr) (string, string) { + var ip, port string + + multiaddr.ForEach(addr, func(c multiaddr.Component) bool { + switch c.Protocol().Code { + case multiaddr.P_IP4, multiaddr.P_IP6: + ip = c.Value() + case multiaddr.P_TCP, multiaddr.P_UDP: + port = c.Value() + } + return true + }) + + return ip, port +} + +// isLocalhost checks if an IP is localhost +func (af *AddressFilter) isLocalhost(ip string) bool { + return ip == "127.0.0.1" || ip == "::1" +} + +// isPrivateLAN checks if an IP is a private LAN address +func (af *AddressFilter) isPrivateLAN(ip string) bool { + parsedIP := net.ParseIP(ip) + if parsedIP == nil { + return false + } + + // Check for private IPv4 ranges + private := []string{ + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + } + + for _, cidr := range private { + _, network, err := net.ParseCIDR(cidr) + if err != nil { + continue + } + if network.Contains(parsedIP) { + // Exclude Docker ranges from LAN classification + if af.isDockerRange(ip) { + return false + } + return true + } + } + + return false +} + +// isVirtualInterface checks if an IP belongs to a virtual interface +func (af *AddressFilter) isVirtualInterface(ip string) bool { + return af.isDockerRange(ip) || af.isVPNRange(ip) +} + +// isDockerRange checks if an IP is in a Docker range +func (af *AddressFilter) isDockerRange(ip string) bool { + // Common Docker ranges + dockerRanges := []string{ + "172.17.0.0/16", // Default Docker bridge + "172.18.0.0/16", // Docker custom networks + "172.19.0.0/16", + "172.20.0.0/16", + "172.21.0.0/16", + "172.22.0.0/16", + "172.23.0.0/16", + "172.24.0.0/16", + "172.25.0.0/16", + } + + parsedIP := net.ParseIP(ip) + if parsedIP == nil { + return false + } + + for _, cidr := range dockerRanges { + _, network, err := net.ParseCIDR(cidr) + if err != nil { + continue + } + if network.Contains(parsedIP) { + return true + } + } + + // Also check for Docker-like patterns in 192.168.x.1 ranges + if matched, _ := regexp.MatchString(`^192\.168\.(224|240|208|176|144|112|80|48|16)\.1$`, ip); matched { + return true + } + + return false +} + +// isVPNRange checks if an IP is in a VPN range +func (af *AddressFilter) isVPNRange(ip string) bool { + // Common VPN ranges - this is a basic implementation + // You might want to expand this based on your specific VPN software + vpnPatterns := []string{ + `^10\.8\.0\.`, // OpenVPN default + `^10\.9\.0\.`, // WireGuard common + `^192\.168\.122\.`, // libvirt/KVM default + } + + for _, pattern := range vpnPatterns { + if matched, _ := regexp.MatchString(pattern, ip); matched { + return true + } + } + + return false +} + +// isIPv6 checks if an IP is IPv6 +func (af *AddressFilter) isIPv6(ip string) bool { + return strings.Contains(ip, ":") +} + +// getVirtualInterfaceLabel returns a label for virtual interfaces +func (af *AddressFilter) getVirtualInterfaceLabel(ip string) string { + if af.isDockerRange(ip) { + return "(docker)" + } + if af.isVPNRange(ip) { + return "(vpn)" + } + return "(virtual)" +} + +// CountFilteredAddresses returns the count of addresses that would be filtered out +func (af *AddressFilter) CountFilteredAddresses(addrs []multiaddr.Multiaddr) int { + if af.showAll { + return 0 + } + + filtered := af.FilterAddresses(addrs) + return len(addrs) - len(filtered) +} \ No newline at end of file diff --git a/internal/discover/address_filter_test.go b/internal/discover/address_filter_test.go new file mode 100644 index 0000000..bf4bdff --- /dev/null +++ b/internal/discover/address_filter_test.go @@ -0,0 +1,111 @@ +package discover + +import ( + "testing" + + "github.com/multiformats/go-multiaddr" +) + +func TestAddressFilter(t *testing.T) { + // Create test addresses + testAddrs := []string{ + "/ip4/127.0.0.1/tcp/39295", // localhost + "/ip4/192.168.0.133/tcp/39295", // LAN + "/ip4/172.17.0.1/tcp/39295", // Docker + "/ip4/172.21.0.1/tcp/39295", // Docker + "/ip4/10.8.0.1/tcp/39295", // VPN + "/ip6/::1/tcp/36104", // IPv6 localhost + "/ip6/2001:db8::1/tcp/36104", // IPv6 + "/ip4/8.8.8.8/tcp/39295", // Public IP + } + + var addrs []multiaddr.Multiaddr + for _, addrStr := range testAddrs { + addr, err := multiaddr.NewMultiaddr(addrStr) + if err != nil { + t.Fatalf("Failed to create multiaddr %s: %v", addrStr, err) + } + addrs = append(addrs, addr) + } + + t.Run("ShowAll=true", func(t *testing.T) { + filter := NewAddressFilter(true) + filtered := filter.FilterAddresses(addrs) + + // Should return all addresses + if len(filtered) != len(addrs) { + t.Errorf("Expected %d addresses, got %d", len(addrs), len(filtered)) + } + }) + + t.Run("ShowAll=false", func(t *testing.T) { + filter := NewAddressFilter(false) + filtered := filter.FilterAddresses(addrs) + + // Should filter out Docker and VPN addresses + expectedCount := 4 // localhost, LAN, IPv6 localhost (limited to 1), public IP + if len(filtered) > expectedCount { + t.Errorf("Expected at most %d addresses, got %d", expectedCount, len(filtered)) + } + + // Check that localhost and LAN are included + hasLocalhost := false + hasLAN := false + for _, addr := range filtered { + if addr.DisplayAddr == "localhost:39295" { + hasLocalhost = true + } + if addr.DisplayAddr == "192.168.0.133:39295" && addr.Label == "(primary)" { + hasLAN = true + } + } + + if !hasLocalhost { + t.Error("Expected localhost address to be included") + } + if !hasLAN { + t.Error("Expected LAN address to be included with (primary) label") + } + }) + + t.Run("CountFilteredAddresses", func(t *testing.T) { + filter := NewAddressFilter(false) + count := filter.CountFilteredAddresses(addrs) + + // Should count Docker and VPN addresses as filtered + if count < 3 { // At least Docker and VPN addresses + t.Errorf("Expected at least 3 filtered addresses, got %d", count) + } + }) +} + +func TestAddressPrioritization(t *testing.T) { + filter := NewAddressFilter(false) + + // Test localhost detection + localhostAddr, _ := multiaddr.NewMultiaddr("/ip4/127.0.0.1/tcp/39295") + fa := filter.categorizeAddress(localhostAddr) + if fa.Priority != PriorityLocalhost { + t.Errorf("Expected localhost priority, got %d", fa.Priority) + } + if fa.DisplayAddr != "localhost:39295" { + t.Errorf("Expected localhost:39295, got %s", fa.DisplayAddr) + } + + // Test LAN detection + lanAddr, _ := multiaddr.NewMultiaddr("/ip4/192.168.0.133/tcp/39295") + fa = filter.categorizeAddress(lanAddr) + if fa.Priority != PriorityLAN { + t.Errorf("Expected LAN priority, got %d", fa.Priority) + } + if fa.Label != "(primary)" { + t.Errorf("Expected (primary) label, got %s", fa.Label) + } + + // Test Docker detection + dockerAddr, _ := multiaddr.NewMultiaddr("/ip4/172.17.0.1/tcp/39295") + fa = filter.categorizeAddress(dockerAddr) + if fa.Priority != PriorityVirtual { + t.Errorf("Expected virtual priority for Docker, got %d", fa.Priority) + } +} \ No newline at end of file diff --git a/internal/discover/discovery_util.go b/internal/discover/discovery_util.go index 33d8276..e86d35c 100644 --- a/internal/discover/discovery_util.go +++ b/internal/discover/discovery_util.go @@ -65,7 +65,7 @@ func SetupDiscovery(ctx context.Context, h host.Host) (*p2p.MDNSDiscovery, <-cha // runDiscoveryLoop processes discovered peers until timeout or interrupted func RunDiscoveryLoop(ctx context.Context, h host.Host, syncService *p2p.SyncService, - peerChan <-chan peer.AddrInfo, timeout int, continuous bool, + peerChan <-chan peer.AddrInfo, timeout int, continuous bool, allAddresses bool, ) error { var timeoutChan <-chan time.Time if !continuous { @@ -93,7 +93,7 @@ func RunDiscoveryLoop(ctx context.Context, h host.Host, syncService *p2p.SyncSer discoveredPeers[p.ID.String()] = true peerCount++ - handleDiscoveredPeer(ctx, h, syncService, p, peerCount) + handleDiscoveredPeer(ctx, h, syncService, p, peerCount, allAddresses) case <-timeoutChan: fmt.Printf("\nāŒ› Discovery timeout reached after %d seconds.\n", timeout) @@ -117,14 +117,11 @@ func RunDiscoveryLoop(ctx context.Context, h host.Host, syncService *p2p.SyncSer // handleDiscoveredPeer processes a newly discovered peer func handleDiscoveredPeer(ctx context.Context, h host.Host, syncService *p2p.SyncService, - p peer.AddrInfo, peerCount int, + p peer.AddrInfo, peerCount int, allAddresses bool, ) { fmt.Printf("āœ… Discovered peer #%d\n", peerCount) fmt.Printf(" ID: %s\n", p.ID.String()) - fmt.Println(" Addresses:") - for _, addr := range p.Addrs { - fmt.Printf(" - %s\n", addr.String()) - } + DisplayPeerAddresses(p, allAddresses) fmt.Printf(" Connecting and exchanging keys... ") diff --git a/internal/discover/handle_discovered_peer_test.go b/internal/discover/handle_discovered_peer_test.go index 735e959..79c3f94 100644 --- a/internal/discover/handle_discovered_peer_test.go +++ b/internal/discover/handle_discovered_peer_test.go @@ -38,9 +38,9 @@ func TestHandleDiscoveredPeerExercisesPaths(t *testing.T) { p := peer.AddrInfo{ID: h.ID(), Addrs: h.Addrs()} // Case 1: peer not present in trustedPeers -> AddTrustedPeer will fail - handleDiscoveredPeer(context.Background(), h, svc, p, 1) + handleDiscoveredPeer(context.Background(), h, svc, p, 1, false) // We cannot access unexported fields of SyncService from here; ensure // the function returns without panic when called a second time. - handleDiscoveredPeer(context.Background(), h, svc, p, 2) + handleDiscoveredPeer(context.Background(), h, svc, p, 2, false) }