diff --git a/lib/common_test.go b/lib/common_test.go new file mode 100644 index 00000000000..8879ade0eae --- /dev/null +++ b/lib/common_test.go @@ -0,0 +1,148 @@ +package lib + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestGetRemoteURLContent(t *testing.T) { + // Success + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("hello world")) + })) + defer server.Close() + + content, err := GetRemoteURLContent(server.URL) + if err != nil { + t.Errorf("GetRemoteURLContent error = %v", err) + } + if string(content) != "hello world" { + t.Errorf("expected 'hello world', got %q", string(content)) + } + + // Non-200 status + server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer server2.Close() + + _, err = GetRemoteURLContent(server2.URL) + if err == nil { + t.Error("expected error for non-200 status") + } + + // Invalid URL + _, err = GetRemoteURLContent("http://invalid-host-that-does-not-exist.example.com") + if err == nil { + t.Error("expected error for invalid URL") + } +} + +func TestGetRemoteURLReader(t *testing.T) { + // Success + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("test content")) + })) + defer server.Close() + + reader, err := GetRemoteURLReader(server.URL) + if err != nil { + t.Errorf("GetRemoteURLReader error = %v", err) + } + if reader != nil { + reader.Close() + } + + // Non-200 status + server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server2.Close() + + _, err = GetRemoteURLReader(server2.URL) + if err == nil { + t.Error("expected error for non-200 status") + } + + // Invalid URL + _, err = GetRemoteURLReader("http://invalid-host-that-does-not-exist.example.com") + if err == nil { + t.Error("expected error for invalid URL") + } +} + +func TestGetIgnoreIPType(t *testing.T) { + // IPv4 -> IgnoreIPv6 + opt := GetIgnoreIPType(IPv4) + if opt == nil { + t.Fatal("expected non-nil option for IPv4") + } + if opt() != IPv6 { + t.Error("expected IgnoreIPv6 for IPv4 input") + } + + // IPv6 -> IgnoreIPv4 + opt = GetIgnoreIPType(IPv6) + if opt == nil { + t.Fatal("expected non-nil option for IPv6") + } + if opt() != IPv4 { + t.Error("expected IgnoreIPv4 for IPv6 input") + } + + // Other -> nil + opt = GetIgnoreIPType(IPType("other")) + if opt != nil { + t.Error("expected nil option for unknown IP type") + } + + // Empty -> nil + opt = GetIgnoreIPType(IPType("")) + if opt != nil { + t.Error("expected nil option for empty IP type") + } +} + +func TestWantedListExtendedUnmarshalJSON(t *testing.T) { + // Slice format + w := &WantedListExtended{} + data := []byte(`["type1", "type2"]`) + if err := json.Unmarshal(data, w); err != nil { + t.Errorf("UnmarshalJSON slice error = %v", err) + } + if len(w.TypeSlice) != 2 { + t.Errorf("expected 2 types, got %d", len(w.TypeSlice)) + } + if w.TypeSlice[0] != "type1" || w.TypeSlice[1] != "type2" { + t.Errorf("unexpected TypeSlice: %v", w.TypeSlice) + } + + // Map format + w2 := &WantedListExtended{} + data2 := []byte(`{"key1": ["val1", "val2"], "key2": ["val3"]}`) + if err := json.Unmarshal(data2, w2); err != nil { + t.Errorf("UnmarshalJSON map error = %v", err) + } + if len(w2.TypeMap) != 2 { + t.Errorf("expected 2 keys in map, got %d", len(w2.TypeMap)) + } + + // Empty data + w3 := &WantedListExtended{} + if err := w3.UnmarshalJSON(nil); err != nil { + t.Errorf("UnmarshalJSON empty error = %v", err) + } + if err := w3.UnmarshalJSON([]byte{}); err != nil { + t.Errorf("UnmarshalJSON empty bytes error = %v", err) + } + + // Invalid JSON + w4 := &WantedListExtended{} + if err := w4.UnmarshalJSON([]byte(`{invalid}`)); err == nil { + t.Error("expected error for invalid JSON") + } +} diff --git a/lib/config_test.go b/lib/config_test.go new file mode 100644 index 00000000000..74f91199091 --- /dev/null +++ b/lib/config_test.go @@ -0,0 +1,247 @@ +package lib + +import ( + "encoding/json" + "testing" +) + +// mockInputConverter implements InputConverter for testing +type mockInputConverter struct { + typeName string + action Action + description string + inputFn func(Container) (Container, error) +} + +func (m *mockInputConverter) GetType() string { return m.typeName } +func (m *mockInputConverter) GetAction() Action { return m.action } +func (m *mockInputConverter) GetDescription() string { return m.description } +func (m *mockInputConverter) Input(c Container) (Container, error) { + if m.inputFn != nil { + return m.inputFn(c) + } + return c, nil +} + +// mockOutputConverter implements OutputConverter for testing +type mockOutputConverter struct { + typeName string + action Action + description string + outputFn func(Container) error +} + +func (m *mockOutputConverter) GetType() string { return m.typeName } +func (m *mockOutputConverter) GetAction() Action { return m.action } +func (m *mockOutputConverter) GetDescription() string { return m.description } +func (m *mockOutputConverter) Output(c Container) error { + if m.outputFn != nil { + return m.outputFn(c) + } + return nil +} + +func TestRegisterInputConfigCreator(t *testing.T) { + // Save and restore original cache + origCache := inputConfigCreatorCache + inputConfigCreatorCache = make(map[string]inputConfigCreator) + defer func() { inputConfigCreatorCache = origCache }() + + creator := func(action Action, data json.RawMessage) (InputConverter, error) { + return &mockInputConverter{typeName: "test-input", action: action, description: "test"}, nil + } + + // Register successfully + if err := RegisterInputConfigCreator("test-input", creator); err != nil { + t.Errorf("RegisterInputConfigCreator error = %v", err) + } + + // Duplicate registration + if err := RegisterInputConfigCreator("test-input", creator); err == nil { + t.Error("expected error for duplicate registration") + } + + // Case insensitive + if err := RegisterInputConfigCreator("TEST-INPUT", creator); err == nil { + t.Error("expected error for case-insensitive duplicate") + } +} + +func TestCreateInputConfig(t *testing.T) { + origCache := inputConfigCreatorCache + inputConfigCreatorCache = make(map[string]inputConfigCreator) + defer func() { inputConfigCreatorCache = origCache }() + + creator := func(action Action, data json.RawMessage) (InputConverter, error) { + return &mockInputConverter{typeName: "test-input", action: action, description: "test"}, nil + } + RegisterInputConfigCreator("test-input", creator) + + // Successful creation + ic, err := createInputConfig("test-input", ActionAdd, nil) + if err != nil { + t.Errorf("createInputConfig error = %v", err) + } + if ic.GetType() != "test-input" { + t.Errorf("expected type 'test-input', got %q", ic.GetType()) + } + + // Unknown type + _, err = createInputConfig("unknown", ActionAdd, nil) + if err == nil { + t.Error("expected error for unknown config type") + } + + // Case insensitive + ic, err = createInputConfig("TEST-INPUT", ActionAdd, nil) + if err != nil { + t.Errorf("createInputConfig case insensitive error = %v", err) + } + if ic.GetType() != "test-input" { + t.Errorf("expected type 'test-input', got %q", ic.GetType()) + } +} + +func TestRegisterOutputConfigCreator(t *testing.T) { + origCache := outputConfigCreatorCache + outputConfigCreatorCache = make(map[string]outputConfigCreator) + defer func() { outputConfigCreatorCache = origCache }() + + creator := func(action Action, data json.RawMessage) (OutputConverter, error) { + return &mockOutputConverter{typeName: "test-output", action: action, description: "test"}, nil + } + + // Register successfully + if err := RegisterOutputConfigCreator("test-output", creator); err != nil { + t.Errorf("RegisterOutputConfigCreator error = %v", err) + } + + // Duplicate registration + if err := RegisterOutputConfigCreator("test-output", creator); err == nil { + t.Error("expected error for duplicate registration") + } +} + +func TestCreateOutputConfig(t *testing.T) { + origCache := outputConfigCreatorCache + outputConfigCreatorCache = make(map[string]outputConfigCreator) + defer func() { outputConfigCreatorCache = origCache }() + + creator := func(action Action, data json.RawMessage) (OutputConverter, error) { + return &mockOutputConverter{typeName: "test-output", action: action, description: "test"}, nil + } + RegisterOutputConfigCreator("test-output", creator) + + // Successful creation + oc, err := createOutputConfig("test-output", ActionOutput, nil) + if err != nil { + t.Errorf("createOutputConfig error = %v", err) + } + if oc.GetType() != "test-output" { + t.Errorf("expected type 'test-output', got %q", oc.GetType()) + } + + // Unknown type + _, err = createOutputConfig("unknown", ActionOutput, nil) + if err == nil { + t.Error("expected error for unknown config type") + } +} + +func TestInputConvConfigUnmarshalJSON(t *testing.T) { + origCache := inputConfigCreatorCache + inputConfigCreatorCache = make(map[string]inputConfigCreator) + defer func() { inputConfigCreatorCache = origCache }() + + creator := func(action Action, data json.RawMessage) (InputConverter, error) { + return &mockInputConverter{typeName: "test-input", action: action, description: "test"}, nil + } + RegisterInputConfigCreator("test-input", creator) + + // Valid JSON + data := []byte(`{"type":"test-input","action":"add","args":{}}`) + icc := &inputConvConfig{} + if err := icc.UnmarshalJSON(data); err != nil { + t.Errorf("UnmarshalJSON error = %v", err) + } + if icc.iType != "test-input" { + t.Errorf("expected type 'test-input', got %q", icc.iType) + } + if icc.action != ActionAdd { + t.Errorf("expected action 'add', got %q", icc.action) + } + + // Invalid action + data2 := []byte(`{"type":"test-input","action":"invalid","args":{}}`) + icc2 := &inputConvConfig{} + if err := icc2.UnmarshalJSON(data2); err == nil { + t.Error("expected error for invalid action") + } + + // Invalid JSON + icc3 := &inputConvConfig{} + if err := icc3.UnmarshalJSON([]byte(`{invalid}`)); err == nil { + t.Error("expected error for invalid JSON") + } + + // Unknown type + data4 := []byte(`{"type":"unknown","action":"add","args":{}}`) + icc4 := &inputConvConfig{} + if err := icc4.UnmarshalJSON(data4); err == nil { + t.Error("expected error for unknown type") + } +} + +func TestOutputConvConfigUnmarshalJSON(t *testing.T) { + origCache := outputConfigCreatorCache + outputConfigCreatorCache = make(map[string]outputConfigCreator) + defer func() { outputConfigCreatorCache = origCache }() + + creator := func(action Action, data json.RawMessage) (OutputConverter, error) { + return &mockOutputConverter{typeName: "test-output", action: action, description: "test"}, nil + } + RegisterOutputConfigCreator("test-output", creator) + + // Valid JSON + data := []byte(`{"type":"test-output","action":"output","args":{}}`) + occ := &outputConvConfig{} + if err := occ.UnmarshalJSON(data); err != nil { + t.Errorf("UnmarshalJSON error = %v", err) + } + if occ.iType != "test-output" { + t.Errorf("expected type 'test-output', got %q", occ.iType) + } + if occ.action != ActionOutput { + t.Errorf("expected action 'output', got %q", occ.action) + } + + // Default action (empty action defaults to "output") + data2 := []byte(`{"type":"test-output","args":{}}`) + occ2 := &outputConvConfig{} + if err := occ2.UnmarshalJSON(data2); err != nil { + t.Errorf("UnmarshalJSON default action error = %v", err) + } + if occ2.action != ActionOutput { + t.Errorf("expected default action 'output', got %q", occ2.action) + } + + // Invalid action + data3 := []byte(`{"type":"test-output","action":"invalid","args":{}}`) + occ3 := &outputConvConfig{} + if err := occ3.UnmarshalJSON(data3); err == nil { + t.Error("expected error for invalid action") + } + + // Invalid JSON + occ4 := &outputConvConfig{} + if err := occ4.UnmarshalJSON([]byte(`{invalid}`)); err == nil { + t.Error("expected error for invalid JSON") + } + + // Unknown type + data5 := []byte(`{"type":"unknown","action":"output","args":{}}`) + occ5 := &outputConvConfig{} + if err := occ5.UnmarshalJSON(data5); err == nil { + t.Error("expected error for unknown type") + } +} diff --git a/lib/container_test.go b/lib/container_test.go new file mode 100644 index 00000000000..8787722a572 --- /dev/null +++ b/lib/container_test.go @@ -0,0 +1,1174 @@ +package lib + +import ( + "sort" + "testing" +) + +func TestNewContainer(t *testing.T) { + c := NewContainer() + if c == nil { + t.Fatal("NewContainer returned nil") + } + if c.Len() != 0 { + t.Errorf("expected Len() = 0, got %d", c.Len()) + } +} + +func TestContainerGetEntry(t *testing.T) { + c := NewContainer() + + // Entry not found + _, found := c.GetEntry("US") + if found { + t.Error("expected entry not found") + } + + // Add entry and retrieve + entry := NewEntry("us") + if err := entry.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := c.Add(entry); err != nil { + t.Fatal(err) + } + + val, found := c.GetEntry("US") + if !found { + t.Error("expected entry to be found") + } + if val.GetName() != "US" { + t.Errorf("expected name US, got %q", val.GetName()) + } + + // Case insensitive with spaces + val, found = c.GetEntry(" us ") + if !found { + t.Error("expected entry to be found with spaces and lowercase") + } + if val.GetName() != "US" { + t.Errorf("expected name US, got %q", val.GetName()) + } +} + +func TestContainerGetEntryInvalid(t *testing.T) { + c := &container{entries: nil} + _, found := c.GetEntry("US") + if found { + t.Error("expected entry not found on invalid container") + } +} + +func TestContainerLen(t *testing.T) { + c := NewContainer() + if c.Len() != 0 { + t.Errorf("expected Len() = 0, got %d", c.Len()) + } + + entry := NewEntry("us") + if err := entry.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := c.Add(entry); err != nil { + t.Fatal(err) + } + if c.Len() != 1 { + t.Errorf("expected Len() = 1, got %d", c.Len()) + } +} + +func TestContainerLenInvalid(t *testing.T) { + c := &container{entries: nil} + if c.Len() != 0 { + t.Errorf("expected Len() = 0 on invalid container, got %d", c.Len()) + } +} + +func TestContainerLoop(t *testing.T) { + c := NewContainer() + + entry1 := NewEntry("us") + if err := entry1.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + entry2 := NewEntry("cn") + if err := entry2.AddPrefix("2.0.0.0/24"); err != nil { + t.Fatal(err) + } + + if err := c.Add(entry1); err != nil { + t.Fatal(err) + } + if err := c.Add(entry2); err != nil { + t.Fatal(err) + } + + names := make([]string, 0) + for entry := range c.Loop() { + names = append(names, entry.GetName()) + } + + sort.Strings(names) + if len(names) != 2 { + t.Errorf("expected 2 entries, got %d", len(names)) + } + if names[0] != "CN" || names[1] != "US" { + t.Errorf("unexpected names: %v", names) + } +} + +func TestContainerAdd(t *testing.T) { + c := NewContainer() + + // Add new entry + entry := NewEntry("us") + if err := entry.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := entry.AddPrefix("2001:db8::/32"); err != nil { + t.Fatal(err) + } + if err := c.Add(entry); err != nil { + t.Fatal(err) + } + + // Add to existing entry (merge) + entry2 := NewEntry("us") + if err := entry2.AddPrefix("2.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := entry2.AddPrefix("2001:db9::/32"); err != nil { + t.Fatal(err) + } + if err := c.Add(entry2); err != nil { + t.Fatal(err) + } + + if c.Len() != 1 { + t.Errorf("expected 1 entry after merge, got %d", c.Len()) + } + + val, found := c.GetEntry("US") + if !found { + t.Fatal("entry US not found") + } + prefixes, err := val.MarshalPrefix() + if err != nil { + t.Fatal(err) + } + // 1.0.0.0/24, 2.0.0.0/24 are separate, 2001:db8::/32 and 2001:db9::/32 are adjacent + // so they get merged into 2001:db8::/31, resulting in 3 prefixes total + if len(prefixes) != 3 { + t.Errorf("expected 3 prefixes after merge, got %d", len(prefixes)) + } +} + +func TestContainerAddIgnoreIPv4(t *testing.T) { + c := NewContainer() + + entry := NewEntry("us") + if err := entry.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := entry.AddPrefix("2001:db8::/32"); err != nil { + t.Fatal(err) + } + + // Add with IgnoreIPv4 - should only have IPv6 + if err := c.Add(entry, IgnoreIPv4); err != nil { + t.Fatal(err) + } + + val, _ := c.GetEntry("US") + if val.hasIPv4Builder() { + t.Error("entry should not have IPv4 builder when ignoring IPv4") + } +} + +func TestContainerAddIgnoreIPv6(t *testing.T) { + c := NewContainer() + + entry := NewEntry("us") + if err := entry.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := entry.AddPrefix("2001:db8::/32"); err != nil { + t.Fatal(err) + } + + // Add with IgnoreIPv6 - should only have IPv4 + if err := c.Add(entry, IgnoreIPv6); err != nil { + t.Fatal(err) + } + + val, _ := c.GetEntry("US") + if val.hasIPv6Builder() { + t.Error("entry should not have IPv6 builder when ignoring IPv6") + } +} + +func TestContainerAddMergeIgnoreIPv4(t *testing.T) { + c := NewContainer() + + // First add normally + entry1 := NewEntry("us") + if err := entry1.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := c.Add(entry1); err != nil { + t.Fatal(err) + } + + // Add with IgnoreIPv4 to existing entry + entry2 := NewEntry("us") + if err := entry2.AddPrefix("2.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := entry2.AddPrefix("2001:db8::/32"); err != nil { + t.Fatal(err) + } + if err := c.Add(entry2, IgnoreIPv4); err != nil { + t.Fatal(err) + } + + // Verify: original IPv4 should remain, IPv6 should be merged + val, found := c.GetEntry("US") + if !found { + t.Fatal("entry US not found") + } + prefixes, err := val.MarshalPrefix() + if err != nil { + t.Fatal(err) + } + gotIPv4, gotIPv6 := 0, 0 + for _, p := range prefixes { + if p.Addr().Is4() { + gotIPv4++ + if p.String() != "1.0.0.0/24" { + t.Errorf("expected original IPv4 prefix 1.0.0.0/24, got %s", p.String()) + } + } else { + gotIPv6++ + if p.String() != "2001:db8::/32" { + t.Errorf("expected merged IPv6 prefix 2001:db8::/32, got %s", p.String()) + } + } + } + if gotIPv4 != 1 { + t.Errorf("expected 1 IPv4 prefix, got %d", gotIPv4) + } + if gotIPv6 != 1 { + t.Errorf("expected 1 IPv6 prefix, got %d", gotIPv6) + } +} + +func TestContainerAddMergeIgnoreIPv6(t *testing.T) { + c := NewContainer() + + // First add normally + entry1 := NewEntry("us") + if err := entry1.AddPrefix("2001:db8::/32"); err != nil { + t.Fatal(err) + } + if err := c.Add(entry1); err != nil { + t.Fatal(err) + } + + // Add with IgnoreIPv6 to existing entry + entry2 := NewEntry("us") + if err := entry2.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := entry2.AddPrefix("2001:db9::/32"); err != nil { + t.Fatal(err) + } + if err := c.Add(entry2, IgnoreIPv6); err != nil { + t.Fatal(err) + } + + // Verify: IPv4 should be merged, original IPv6 should remain + val, found := c.GetEntry("US") + if !found { + t.Fatal("entry US not found") + } + prefixes, err := val.MarshalPrefix() + if err != nil { + t.Fatal(err) + } + gotIPv4, gotIPv6 := 0, 0 + for _, p := range prefixes { + if p.Addr().Is4() { + gotIPv4++ + if p.String() != "1.0.0.0/24" { + t.Errorf("expected merged IPv4 prefix 1.0.0.0/24, got %s", p.String()) + } + } else { + gotIPv6++ + if p.String() != "2001:db8::/32" { + t.Errorf("expected original IPv6 prefix 2001:db8::/32, got %s", p.String()) + } + } + } + if gotIPv4 != 1 { + t.Errorf("expected 1 IPv4 prefix, got %d", gotIPv4) + } + if gotIPv6 != 1 { + t.Errorf("expected 1 IPv6 prefix, got %d", gotIPv6) + } +} + +func TestContainerAddNilOption(t *testing.T) { + c := NewContainer() + + entry := NewEntry("us") + if err := entry.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + // Pass nil option + if err := c.Add(entry, nil); err != nil { + t.Fatal(err) + } +} + +func TestContainerRemoveCaseRemovePrefix(t *testing.T) { + c := NewContainer() + + entry := NewEntry("us") + if err := entry.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := entry.AddPrefix("2.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := entry.AddPrefix("2001:db8::/32"); err != nil { + t.Fatal(err) + } + if err := c.Add(entry); err != nil { + t.Fatal(err) + } + + // Remove prefix + removeEntry := NewEntry("us") + if err := removeEntry.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := removeEntry.AddPrefix("2001:db8::/32"); err != nil { + t.Fatal(err) + } + if err := c.Remove(removeEntry, CaseRemovePrefix); err != nil { + t.Errorf("Remove error = %v", err) + } + + // Verify: only 2.0.0.0/24 should remain + val, found := c.GetEntry("US") + if !found { + t.Fatal("entry US not found") + } + prefixes, err := val.MarshalPrefix() + if err != nil { + t.Fatal(err) + } + gotIPv4, gotIPv6 := 0, 0 + for _, p := range prefixes { + if p.Addr().Is4() { + gotIPv4++ + if p.String() != "2.0.0.0/24" { + t.Errorf("expected remaining IPv4 prefix 2.0.0.0/24, got %s", p.String()) + } + } else { + gotIPv6++ + } + } + if gotIPv4 != 1 { + t.Errorf("expected 1 IPv4 prefix, got %d", gotIPv4) + } + if gotIPv6 != 0 { + t.Errorf("expected 0 IPv6 prefixes, got %d", gotIPv6) + } +} + +func TestContainerRemoveCaseRemovePrefixIgnoreIPv4(t *testing.T) { + c := NewContainer() + + entry := NewEntry("us") + if err := entry.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := entry.AddPrefix("2001:db8::/32"); err != nil { + t.Fatal(err) + } + if err := c.Add(entry); err != nil { + t.Fatal(err) + } + + removeEntry := NewEntry("us") + if err := removeEntry.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := removeEntry.AddPrefix("2001:db8::/32"); err != nil { + t.Fatal(err) + } + if err := c.Remove(removeEntry, CaseRemovePrefix, IgnoreIPv4); err != nil { + t.Errorf("Remove with IgnoreIPv4 error = %v", err) + } + + // Verify: IPv4 should be untouched (1.0.0.0/24 remains), IPv6 should be removed + val, found := c.GetEntry("US") + if !found { + t.Fatal("entry US not found") + } + prefixes, err := val.MarshalPrefix() + if err != nil { + t.Fatal(err) + } + gotIPv4, gotIPv6 := 0, 0 + for _, p := range prefixes { + if p.Addr().Is4() { + gotIPv4++ + if p.String() != "1.0.0.0/24" { + t.Errorf("expected remaining IPv4 prefix 1.0.0.0/24, got %s", p.String()) + } + } else { + gotIPv6++ + } + } + if gotIPv4 != 1 { + t.Errorf("expected 1 IPv4 prefix, got %d", gotIPv4) + } + if gotIPv6 != 0 { + t.Errorf("expected 0 IPv6 prefixes, got %d", gotIPv6) + } +} + +func TestContainerRemoveCaseRemovePrefixIgnoreIPv6(t *testing.T) { + c := NewContainer() + + entry := NewEntry("us") + if err := entry.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := entry.AddPrefix("2001:db8::/32"); err != nil { + t.Fatal(err) + } + if err := c.Add(entry); err != nil { + t.Fatal(err) + } + + removeEntry := NewEntry("us") + if err := removeEntry.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := removeEntry.AddPrefix("2001:db8::/32"); err != nil { + t.Fatal(err) + } + if err := c.Remove(removeEntry, CaseRemovePrefix, IgnoreIPv6); err != nil { + t.Errorf("Remove with IgnoreIPv6 error = %v", err) + } + + // Verify: IPv4 should be removed, IPv6 should be untouched (2001:db8::/32 remains) + val, found := c.GetEntry("US") + if !found { + t.Fatal("entry US not found") + } + prefixes, err := val.MarshalPrefix() + if err != nil { + t.Fatal(err) + } + gotIPv4, gotIPv6 := 0, 0 + for _, p := range prefixes { + if p.Addr().Is4() { + gotIPv4++ + } else { + gotIPv6++ + if p.String() != "2001:db8::/32" { + t.Errorf("expected remaining IPv6 prefix 2001:db8::/32, got %s", p.String()) + } + } + } + if gotIPv4 != 0 { + t.Errorf("expected 0 IPv4 prefixes, got %d", gotIPv4) + } + if gotIPv6 != 1 { + t.Errorf("expected 1 IPv6 prefix, got %d", gotIPv6) + } +} + +func TestContainerRemoveCaseRemoveEntry(t *testing.T) { + c := NewContainer() + + entry := NewEntry("us") + if err := entry.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := c.Add(entry); err != nil { + t.Fatal(err) + } + + // Remove entire entry + removeEntry := NewEntry("us") + if err := c.Remove(removeEntry, CaseRemoveEntry); err != nil { + t.Errorf("Remove error = %v", err) + } + + if c.Len() != 0 { + t.Errorf("expected 0 entries after remove, got %d", c.Len()) + } +} + +func TestContainerRemoveCaseRemoveEntryIgnoreIPv4(t *testing.T) { + c := NewContainer() + + entry := NewEntry("us") + if err := entry.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := entry.AddPrefix("2001:db8::/32"); err != nil { + t.Fatal(err) + } + if err := c.Add(entry); err != nil { + t.Fatal(err) + } + + removeEntry := NewEntry("us") + if err := c.Remove(removeEntry, CaseRemoveEntry, IgnoreIPv4); err != nil { + t.Errorf("Remove error = %v", err) + } + + // Entry should still exist (only ipv6 builder was set to nil) + val, found := c.GetEntry("US") + if !found { + t.Error("entry should still exist after CaseRemoveEntry with IgnoreIPv4") + } + if val.hasIPv6Builder() { + t.Error("ipv6Builder should be nil after CaseRemoveEntry with IgnoreIPv4") + } +} + +func TestContainerRemoveCaseRemoveEntryIgnoreIPv6(t *testing.T) { + c := NewContainer() + + entry := NewEntry("us") + if err := entry.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := entry.AddPrefix("2001:db8::/32"); err != nil { + t.Fatal(err) + } + if err := c.Add(entry); err != nil { + t.Fatal(err) + } + + removeEntry := NewEntry("us") + if err := c.Remove(removeEntry, CaseRemoveEntry, IgnoreIPv6); err != nil { + t.Errorf("Remove error = %v", err) + } + + val, found := c.GetEntry("US") + if !found { + t.Error("entry should still exist after CaseRemoveEntry with IgnoreIPv6") + } + if val.hasIPv4Builder() { + t.Error("ipv4Builder should be nil after CaseRemoveEntry with IgnoreIPv6") + } +} + +func TestContainerRemoveNotFound(t *testing.T) { + c := NewContainer() + removeEntry := NewEntry("us") + err := c.Remove(removeEntry, CaseRemovePrefix) + if err == nil { + t.Error("Remove on non-existent entry should return error") + } +} + +func TestContainerRemoveUnknownCase(t *testing.T) { + c := NewContainer() + + entry := NewEntry("us") + if err := entry.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := c.Add(entry); err != nil { + t.Fatal(err) + } + + removeEntry := NewEntry("us") + err := c.Remove(removeEntry, CaseRemove(99)) + if err == nil { + t.Error("Remove with unknown case should return error") + } +} + +func TestContainerRemoveNilOption(t *testing.T) { + c := NewContainer() + + entry := NewEntry("us") + if err := entry.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := c.Add(entry); err != nil { + t.Fatal(err) + } + + removeEntry := NewEntry("us") + if err := removeEntry.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := c.Remove(removeEntry, CaseRemovePrefix, nil); err != nil { + t.Errorf("Remove with nil option error = %v", err) + } +} + +func TestContainerLookupIP(t *testing.T) { + c := NewContainer() + + entry := NewEntry("us") + if err := entry.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := entry.AddPrefix("2001:db8::/32"); err != nil { + t.Fatal(err) + } + if err := c.Add(entry); err != nil { + t.Fatal(err) + } + + entry2 := NewEntry("cn") + if err := entry2.AddPrefix("2.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := entry2.AddPrefix("2001:db9::/32"); err != nil { + t.Fatal(err) + } + if err := c.Add(entry2); err != nil { + t.Fatal(err) + } + + // Lookup IPv4 + result, found, err := c.Lookup("1.0.0.1") + if err != nil { + t.Fatal(err) + } + if !found { + t.Error("expected to find 1.0.0.1") + } + if len(result) != 1 || result[0] != "US" { + t.Errorf("expected [US], got %v", result) + } + + // Lookup IPv6 + result, found, err = c.Lookup("2001:db8::1") + if err != nil { + t.Fatal(err) + } + if !found { + t.Error("expected to find 2001:db8::1") + } + if len(result) != 1 || result[0] != "US" { + t.Errorf("expected [US], got %v", result) + } + + // Lookup not found + _, found, err = c.Lookup("3.0.0.1") + if err != nil { + t.Fatal(err) + } + if found { + t.Error("expected not to find 3.0.0.1") + } + + // Lookup invalid IP + _, _, err = c.Lookup("invalid") + if err == nil { + t.Error("expected error for invalid IP") + } +} + +func TestContainerLookupCIDR(t *testing.T) { + c := NewContainer() + + entry := NewEntry("us") + if err := entry.AddPrefix("1.0.0.0/16"); err != nil { + t.Fatal(err) + } + if err := entry.AddPrefix("2001:db8::/32"); err != nil { + t.Fatal(err) + } + if err := c.Add(entry); err != nil { + t.Fatal(err) + } + + // Lookup IPv4 CIDR + result, found, err := c.Lookup("1.0.0.0/24") + if err != nil { + t.Fatal(err) + } + if !found { + t.Error("expected to find 1.0.0.0/24") + } + if len(result) != 1 || result[0] != "US" { + t.Errorf("expected [US], got %v", result) + } + + // Lookup IPv6 CIDR + result, found, err = c.Lookup("2001:db8::/48") + if err != nil { + t.Fatal(err) + } + if !found { + t.Error("expected to find 2001:db8::/48") + } + if len(result) != 1 || result[0] != "US" { + t.Errorf("expected [US], got %v", result) + } + + // Lookup not found CIDR + _, found, err = c.Lookup("3.0.0.0/24") + if err != nil { + t.Fatal(err) + } + if found { + t.Error("expected not to find 3.0.0.0/24") + } + + // Lookup invalid CIDR + _, _, err = c.Lookup("invalid/24") + if err == nil { + t.Error("expected error for invalid CIDR") + } +} + +func TestContainerLookupWithSearchList(t *testing.T) { + c := NewContainer() + + entry := NewEntry("us") + if err := entry.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := c.Add(entry); err != nil { + t.Fatal(err) + } + + entry2 := NewEntry("cn") + if err := entry2.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := c.Add(entry2); err != nil { + t.Fatal(err) + } + + // Search with specific list + result, found, err := c.Lookup("1.0.0.1", "us") + if err != nil { + t.Fatal(err) + } + if !found { + t.Error("expected to find 1.0.0.1 in US") + } + if len(result) != 1 || result[0] != "US" { + t.Errorf("expected [US], got %v", result) + } + + // Search with empty string in list (should be skipped) + result, found, err = c.Lookup("1.0.0.1", " ") + if err != nil { + t.Fatal(err) + } + if !found { + t.Error("expected to find 1.0.0.1 with empty search list") + } + if len(result) != 2 { + t.Errorf("expected 2 results, got %d", len(result)) + } + + // Search with non-existent entry + _, found, err = c.Lookup("1.0.0.1", "jp") + if err != nil { + t.Fatal(err) + } + if found { + t.Error("expected not to find 1.0.0.1 in JP") + } +} + +func TestContainerRemoveCaseRemovePrefixNoBuilderOnExisting(t *testing.T) { + c := NewContainer() + + // Add entry with only IPv4 + entry := NewEntry("us") + if err := entry.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := c.Add(entry); err != nil { + t.Fatal(err) + } + + // Try to remove IPv6 from entry that only has IPv4 - should initialize builder + removeEntry := NewEntry("us") + if err := removeEntry.AddPrefix("2001:db8::/32"); err != nil { + t.Fatal(err) + } + if err := c.Remove(removeEntry, CaseRemovePrefix); err != nil { + t.Errorf("Remove error = %v", err) + } + + // Verify: IPv4 1.0.0.0/24 should remain, no IPv6 prefixes + val, found := c.GetEntry("US") + if !found { + t.Fatal("entry US not found") + } + prefixes, err := val.MarshalPrefix() + if err != nil { + t.Fatal(err) + } + gotIPv4, gotIPv6 := 0, 0 + for _, p := range prefixes { + if p.Addr().Is4() { + gotIPv4++ + if p.String() != "1.0.0.0/24" { + t.Errorf("expected remaining IPv4 prefix 1.0.0.0/24, got %s", p.String()) + } + } else { + gotIPv6++ + } + } + if gotIPv4 != 1 { + t.Errorf("expected 1 IPv4 prefix, got %d", gotIPv4) + } + if gotIPv6 != 0 { + t.Errorf("expected 0 IPv6 prefixes, got %d", gotIPv6) + } +} + +func TestContainerAddMergeExistingWithoutBuilders(t *testing.T) { + c := NewContainer() + + // Add entry with only IPv4 + entry1 := NewEntry("us") + if err := entry1.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := c.Add(entry1); err != nil { + t.Fatal(err) + } + + // Add entry with only IPv6 to merge + entry2 := NewEntry("us") + if err := entry2.AddPrefix("2001:db8::/32"); err != nil { + t.Fatal(err) + } + if err := c.Add(entry2); err != nil { + t.Fatal(err) + } + + // Entry should now have both + val, _ := c.GetEntry("US") + if !val.hasIPv4Builder() { + t.Error("entry should have IPv4 builder") + } + if !val.hasIPv6Builder() { + t.Error("entry should have IPv6 builder") + } +} + +func TestContainerAddMergeExistingOnlyIPv6DefaultIgnore(t *testing.T) { + c := NewContainer() + + // First entry has only IPv6 + entry1 := NewEntry("us") + if err := entry1.AddPrefix("2001:db8::/32"); err != nil { + t.Fatal(err) + } + if err := c.Add(entry1); err != nil { + t.Fatal(err) + } + + // Add new entry with both IPv4 and IPv6 - existing lacks IPv4 builder + entry2 := NewEntry("us") + if err := entry2.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := entry2.AddPrefix("2001:db9::/32"); err != nil { + t.Fatal(err) + } + if err := c.Add(entry2); err != nil { + t.Fatal(err) + } + + // val should now have both builders + val, _ := c.GetEntry("US") + if !val.hasIPv4Builder() { + t.Error("entry should have IPv4 builder after merge") + } + if !val.hasIPv6Builder() { + t.Error("entry should have IPv6 builder after merge") + } +} + +func TestContainerRemoveCaseRemovePrefixNoIPv6BuilderIgnoreIPv4(t *testing.T) { + c := NewContainer() + + // Add entry with only IPv4 + entry := NewEntry("us") + if err := entry.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := c.Add(entry); err != nil { + t.Fatal(err) + } + + // Remove with IgnoreIPv4 from entry that only has IPv4 (no IPv6 builder on val) + removeEntry := NewEntry("us") + if err := removeEntry.AddPrefix("2001:db8::/32"); err != nil { + t.Fatal(err) + } + if err := c.Remove(removeEntry, CaseRemovePrefix, IgnoreIPv4); err != nil { + t.Errorf("Remove error = %v", err) + } + + // Verify: IPv4 1.0.0.0/24 should remain (ignored), no IPv6 prefixes + val, found := c.GetEntry("US") + if !found { + t.Fatal("entry US not found") + } + prefixes, err := val.MarshalPrefix() + if err != nil { + t.Fatal(err) + } + gotIPv4, gotIPv6 := 0, 0 + for _, p := range prefixes { + if p.Addr().Is4() { + gotIPv4++ + if p.String() != "1.0.0.0/24" { + t.Errorf("expected remaining IPv4 prefix 1.0.0.0/24, got %s", p.String()) + } + } else { + gotIPv6++ + } + } + if gotIPv4 != 1 { + t.Errorf("expected 1 IPv4 prefix, got %d", gotIPv4) + } + if gotIPv6 != 0 { + t.Errorf("expected 0 IPv6 prefixes, got %d", gotIPv6) + } +} + +func TestContainerRemoveCaseRemovePrefixNoIPv4BuilderIgnoreIPv6(t *testing.T) { + c := NewContainer() + + // Add entry with only IPv6 + entry := NewEntry("us") + if err := entry.AddPrefix("2001:db8::/32"); err != nil { + t.Fatal(err) + } + if err := c.Add(entry); err != nil { + t.Fatal(err) + } + + // Remove with IgnoreIPv6 from entry that only has IPv6 (no IPv4 builder on val) + removeEntry := NewEntry("us") + if err := removeEntry.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := c.Remove(removeEntry, CaseRemovePrefix, IgnoreIPv6); err != nil { + t.Errorf("Remove error = %v", err) + } + + // Verify: no IPv4 prefixes, IPv6 2001:db8::/32 should remain (ignored) + val, found := c.GetEntry("US") + if !found { + t.Fatal("entry US not found") + } + prefixes, err := val.MarshalPrefix() + if err != nil { + t.Fatal(err) + } + gotIPv4, gotIPv6 := 0, 0 + for _, p := range prefixes { + if p.Addr().Is4() { + gotIPv4++ + } else { + gotIPv6++ + if p.String() != "2001:db8::/32" { + t.Errorf("expected remaining IPv6 prefix 2001:db8::/32, got %s", p.String()) + } + } + } + if gotIPv4 != 0 { + t.Errorf("expected 0 IPv4 prefixes, got %d", gotIPv4) + } + if gotIPv6 != 1 { + t.Errorf("expected 1 IPv6 prefix, got %d", gotIPv6) + } +} + +func TestContainerRemoveCaseRemovePrefixNoBuilderDefault(t *testing.T) { + c := NewContainer() + + // Add entry with only IPv6 (no IPv4 builder on val) + entry := NewEntry("us") + if err := entry.AddPrefix("2001:db8::/32"); err != nil { + t.Fatal(err) + } + if err := c.Add(entry); err != nil { + t.Fatal(err) + } + + // Remove default (no ignore) from entry that lacks IPv4 builder + removeEntry := NewEntry("us") + if err := removeEntry.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := removeEntry.AddPrefix("2001:db8::/48"); err != nil { + t.Fatal(err) + } + if err := c.Remove(removeEntry, CaseRemovePrefix); err != nil { + t.Errorf("Remove error = %v", err) + } + + // Verify: no IPv4 prefixes, IPv6 should have 2001:db8::/48 removed from 2001:db8::/32 + val, found := c.GetEntry("US") + if !found { + t.Fatal("entry US not found") + } + prefixes, err := val.MarshalPrefix() + if err != nil { + t.Fatal(err) + } + gotIPv4, gotIPv6 := 0, 0 + for _, p := range prefixes { + if p.Addr().Is4() { + gotIPv4++ + } else { + gotIPv6++ + } + } + if gotIPv4 != 0 { + t.Errorf("expected 0 IPv4 prefixes, got %d", gotIPv4) + } + if gotIPv6 == 0 { + t.Error("expected remaining IPv6 prefixes after removing /48 from /32") + } +} + +func TestContainerRemoveCaseRemovePrefixOnlyIPv4NoIPv6Builder(t *testing.T) { + c := NewContainer() + + // Add entry with only IPv4 (no IPv6 builder on val) + entry := NewEntry("us") + if err := entry.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := c.Add(entry); err != nil { + t.Fatal(err) + } + + // Remove default from entry that lacks IPv6 builder + removeEntry := NewEntry("us") + if err := removeEntry.AddPrefix("1.0.0.0/25"); err != nil { + t.Fatal(err) + } + if err := removeEntry.AddPrefix("2001:db8::/32"); err != nil { + t.Fatal(err) + } + if err := c.Remove(removeEntry, CaseRemovePrefix); err != nil { + t.Errorf("Remove error = %v", err) + } + + // Verify: IPv4 should have 1.0.0.0/25 removed from 1.0.0.0/24, leaving 1.0.0.128/25; no IPv6 prefixes + val, found := c.GetEntry("US") + if !found { + t.Fatal("entry US not found") + } + prefixes, err := val.MarshalPrefix() + if err != nil { + t.Fatal(err) + } + gotIPv4, gotIPv6 := 0, 0 + for _, p := range prefixes { + if p.Addr().Is4() { + gotIPv4++ + if p.String() != "1.0.0.128/25" { + t.Errorf("expected remaining IPv4 prefix 1.0.0.128/25, got %s", p.String()) + } + } else { + gotIPv6++ + } + } + if gotIPv4 != 1 { + t.Errorf("expected 1 IPv4 prefix, got %d", gotIPv4) + } + if gotIPv6 != 0 { + t.Errorf("expected 0 IPv6 prefixes, got %d", gotIPv6) + } +} + +func TestContainerLookupIPv6NotFoundInEntry(t *testing.T) { + c := NewContainer() + + // Add entry with both IPv4 and IPv6 so lookup doesn't fail + entry := NewEntry("us") + if err := entry.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := entry.AddPrefix("2001:db8::/32"); err != nil { + t.Fatal(err) + } + if err := c.Add(entry); err != nil { + t.Fatal(err) + } + + // Lookup IPv6 not found in any entry + _, found, err := c.Lookup("fd00::1") + if err != nil { + t.Fatal(err) + } + if found { + t.Error("expected not to find fd00::1") + } +} + +func TestContainerLookupIPv4ErrorEntryMissingIPSet(t *testing.T) { + c := NewContainer() + + // Add entry with only IPv6 (no IPv4 data) + entry := NewEntry("us") + if err := entry.AddPrefix("2001:db8::/32"); err != nil { + t.Fatal(err) + } + if err := c.Add(entry); err != nil { + t.Fatal(err) + } + + // Lookup IPv4 address - will fail because US entry has no IPv4 set + _, _, err := c.Lookup("1.0.0.1") + if err == nil { + t.Error("expected error when looking up IPv4 in entry with no IPv4 data") + } +} + +func TestContainerLookupIPv6ErrorEntryMissingIPSet(t *testing.T) { + c := NewContainer() + + // Add entry with only IPv4 (no IPv6 data) + entry := NewEntry("us") + if err := entry.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := c.Add(entry); err != nil { + t.Fatal(err) + } + + // Lookup IPv6 address - will fail because US entry has no IPv6 set + _, _, err := c.Lookup("2001:db8::1") + if err == nil { + t.Error("expected error when looking up IPv6 in entry with no IPv6 data") + } + + // Similarly for CIDR lookup + _, _, err = c.Lookup("2001:db8::/32") + if err == nil { + t.Error("expected error when looking up IPv6 CIDR in entry with no IPv6 data") + } +} diff --git a/lib/converter_test.go b/lib/converter_test.go new file mode 100644 index 00000000000..a2ba102b5e2 --- /dev/null +++ b/lib/converter_test.go @@ -0,0 +1,99 @@ +package lib + +import ( + "bytes" + "os" + "testing" +) + +func TestRegisterInputConverter(t *testing.T) { + origMap := inputConverterMap + inputConverterMap = make(map[string]InputConverter) + defer func() { inputConverterMap = origMap }() + + mock := &mockInputConverter{typeName: "test-ic", action: ActionAdd, description: "Test Input"} + + // Register successfully + if err := RegisterInputConverter("test-ic", mock); err != nil { + t.Errorf("RegisterInputConverter error = %v", err) + } + + // Duplicate registration + if err := RegisterInputConverter("test-ic", mock); err != ErrDuplicatedConverter { + t.Errorf("expected ErrDuplicatedConverter, got %v", err) + } +} + +func TestRegisterOutputConverter(t *testing.T) { + origMap := outputConverterMap + outputConverterMap = make(map[string]OutputConverter) + defer func() { outputConverterMap = origMap }() + + mock := &mockOutputConverter{typeName: "test-oc", action: ActionOutput, description: "Test Output"} + + // Register successfully + if err := RegisterOutputConverter("test-oc", mock); err != nil { + t.Errorf("RegisterOutputConverter error = %v", err) + } + + // Duplicate registration + if err := RegisterOutputConverter("test-oc", mock); err != ErrDuplicatedConverter { + t.Errorf("expected ErrDuplicatedConverter, got %v", err) + } +} + +func TestListInputConverter(t *testing.T) { + origMap := inputConverterMap + inputConverterMap = make(map[string]InputConverter) + defer func() { inputConverterMap = origMap }() + + mock := &mockInputConverter{typeName: "test-ic", action: ActionAdd, description: "Test Input"} + RegisterInputConverter("test-ic", mock) + + // Capture stdout + old := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + ListInputConverter() + + w.Close() + os.Stdout = old + + var buf bytes.Buffer + buf.ReadFrom(r) + output := buf.String() + + expected := "All available input formats:\n - test-ic (Test Input)\n" + if output != expected { + t.Errorf("ListInputConverter output = %q, want %q", output, expected) + } +} + +func TestListOutputConverter(t *testing.T) { + origMap := outputConverterMap + outputConverterMap = make(map[string]OutputConverter) + defer func() { outputConverterMap = origMap }() + + mock := &mockOutputConverter{typeName: "test-oc", action: ActionOutput, description: "Test Output"} + RegisterOutputConverter("test-oc", mock) + + // Capture stdout + old := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + ListOutputConverter() + + w.Close() + os.Stdout = old + + var buf bytes.Buffer + buf.ReadFrom(r) + output := buf.String() + + expected := "All available output formats:\n - test-oc (Test Output)\n" + if output != expected { + t.Errorf("ListOutputConverter output = %q, want %q", output, expected) + } +} diff --git a/lib/entry_test.go b/lib/entry_test.go new file mode 100644 index 00000000000..881418b386f --- /dev/null +++ b/lib/entry_test.go @@ -0,0 +1,689 @@ +package lib + +import ( + "net" + "net/netip" + "testing" +) + +func TestNewEntry(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"us", "US"}, + {" cn ", "CN"}, + {"JP", "JP"}, + {"", ""}, + } + for _, tt := range tests { + e := NewEntry(tt.input) + if e.GetName() != tt.expected { + t.Errorf("NewEntry(%q).GetName() = %q, want %q", tt.input, e.GetName(), tt.expected) + } + } +} + +func TestEntryHasBuilderAndSet(t *testing.T) { + e := NewEntry("test") + if e.hasIPv4Builder() { + t.Error("new entry should not have ipv4 builder") + } + if e.hasIPv6Builder() { + t.Error("new entry should not have ipv6 builder") + } + if e.hasIPv4Set() { + t.Error("new entry should not have ipv4 set") + } + if e.hasIPv6Set() { + t.Error("new entry should not have ipv6 set") + } +} + +func TestEntryAddPrefixString(t *testing.T) { + tests := []struct { + name string + cidr string + wantErr bool + }{ + {"ipv4 cidr", "1.0.0.0/24", false}, + {"ipv6 cidr", "2001:db8::/32", false}, + {"ipv4 address", "8.8.8.8", false}, + {"ipv6 address", "2001:db8::1", false}, + {"comment #", "# comment", true}, + {"comment //", "// comment", true}, + {"comment /*", "/* comment */", true}, + {"empty string", "", true}, + {"invalid string", "not-an-ip", true}, + {"invalid cidr", "999.999.999.999/24", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := NewEntry("test") + err := e.AddPrefix(tt.cidr) + if (err != nil) != tt.wantErr { + t.Errorf("AddPrefix(%q) error = %v, wantErr %v", tt.cidr, err, tt.wantErr) + } + }) + } +} + +func TestEntryAddPrefixNetIP(t *testing.T) { + e := NewEntry("test") + + // net.IP IPv4 + ipv4 := net.ParseIP("1.2.3.4") + if err := e.AddPrefix(ipv4); err != nil { + t.Errorf("AddPrefix(net.IP IPv4) error = %v", err) + } + + // net.IP IPv6 + ipv6 := net.ParseIP("2001:db8::1") + if err := e.AddPrefix(ipv6); err != nil { + t.Errorf("AddPrefix(net.IP IPv6) error = %v", err) + } + + // Invalid net.IP + invalidIP := net.IP{} + if err := e.AddPrefix(invalidIP); err == nil { + t.Error("AddPrefix(invalid net.IP) should return error") + } +} + +func TestEntryAddPrefixNetIPNet(t *testing.T) { + e := NewEntry("test") + + // IPv4 net + _, ipNet4, _ := net.ParseCIDR("10.0.0.0/8") + if err := e.AddPrefix(ipNet4); err != nil { + t.Errorf("AddPrefix(*net.IPNet IPv4) error = %v", err) + } + + // IPv6 net + _, ipNet6, _ := net.ParseCIDR("2001:db8::/32") + if err := e.AddPrefix(ipNet6); err != nil { + t.Errorf("AddPrefix(*net.IPNet IPv6) error = %v", err) + } +} + +func TestEntryAddPrefixNetipAddr(t *testing.T) { + e := NewEntry("test") + + // netip.Addr IPv4 + addr4 := netip.MustParseAddr("1.2.3.4") + if err := e.AddPrefix(addr4); err != nil { + t.Errorf("AddPrefix(netip.Addr IPv4) error = %v", err) + } + + // netip.Addr IPv6 + addr6 := netip.MustParseAddr("2001:db8::1") + if err := e.AddPrefix(addr6); err != nil { + t.Errorf("AddPrefix(netip.Addr IPv6) error = %v", err) + } + + // *netip.Addr IPv4 + a4 := netip.MustParseAddr("5.6.7.8") + if err := e.AddPrefix(&a4); err != nil { + t.Errorf("AddPrefix(*netip.Addr IPv4) error = %v", err) + } + + // *netip.Addr IPv6 + a6 := netip.MustParseAddr("2001:db8::2") + if err := e.AddPrefix(&a6); err != nil { + t.Errorf("AddPrefix(*netip.Addr IPv6) error = %v", err) + } +} + +func TestEntryAddPrefixNetipPrefix(t *testing.T) { + e := NewEntry("test") + + // netip.Prefix IPv4 + p4 := netip.MustParsePrefix("10.0.0.0/8") + if err := e.AddPrefix(p4); err != nil { + t.Errorf("AddPrefix(netip.Prefix IPv4) error = %v", err) + } + + // netip.Prefix IPv6 + p6 := netip.MustParsePrefix("2001:db8::/32") + if err := e.AddPrefix(p6); err != nil { + t.Errorf("AddPrefix(netip.Prefix IPv6) error = %v", err) + } + + // *netip.Prefix IPv4 + pp4 := netip.MustParsePrefix("172.16.0.0/12") + if err := e.AddPrefix(&pp4); err != nil { + t.Errorf("AddPrefix(*netip.Prefix IPv4) error = %v", err) + } + + // *netip.Prefix IPv6 + pp6 := netip.MustParsePrefix("fd00::/8") + if err := e.AddPrefix(&pp6); err != nil { + t.Errorf("AddPrefix(*netip.Prefix IPv6) error = %v", err) + } +} + +func TestEntryAddPrefixInvalidType(t *testing.T) { + e := NewEntry("test") + err := e.AddPrefix(12345) + if err != ErrInvalidPrefixType { + t.Errorf("AddPrefix(int) error = %v, want ErrInvalidPrefixType", err) + } +} + +func TestEntryRemovePrefix(t *testing.T) { + e := NewEntry("test") + + // Add first + if err := e.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := e.AddPrefix("2001:db8::/32"); err != nil { + t.Fatal(err) + } + + // Remove + if err := e.RemovePrefix("1.0.0.0/25"); err != nil { + t.Errorf("RemovePrefix error = %v", err) + } + if err := e.RemovePrefix("2001:db8::/48"); err != nil { + t.Errorf("RemovePrefix error = %v", err) + } + + // Remove with comment line - returns error because processPrefix returns ErrCommentLine + // which is skipped, then remove() is called with nil prefix and empty IPType + if err := e.RemovePrefix("# comment"); err == nil { + t.Error("RemovePrefix comment line should return error") + } + + // Remove invalid + if err := e.RemovePrefix("invalid"); err == nil { + t.Error("RemovePrefix(invalid) should return error") + } +} + +func TestEntryRemovePrefixNoBuilder(t *testing.T) { + e := NewEntry("test") + // Remove from entry without builders should not fail + if err := e.RemovePrefix("1.0.0.0/24"); err != nil { + t.Errorf("RemovePrefix without builder error = %v", err) + } + if err := e.RemovePrefix("2001:db8::/32"); err != nil { + t.Errorf("RemovePrefix without builder error = %v", err) + } +} + +func TestEntryGetIPv4Set(t *testing.T) { + e := NewEntry("test") + + // No builder -> error + _, err := e.GetIPv4Set() + if err == nil { + t.Error("GetIPv4Set without builder should return error") + } + + // With data + if err := e.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + set, err := e.GetIPv4Set() + if err != nil { + t.Errorf("GetIPv4Set error = %v", err) + } + if set == nil { + t.Error("GetIPv4Set returned nil set") + } +} + +func TestEntryGetIPv6Set(t *testing.T) { + e := NewEntry("test") + + // No builder -> error + _, err := e.GetIPv6Set() + if err == nil { + t.Error("GetIPv6Set without builder should return error") + } + + // With data + if err := e.AddPrefix("2001:db8::/32"); err != nil { + t.Fatal(err) + } + set, err := e.GetIPv6Set() + if err != nil { + t.Errorf("GetIPv6Set error = %v", err) + } + if set == nil { + t.Error("GetIPv6Set returned nil set") + } +} + +func TestEntryMarshalPrefix(t *testing.T) { + e := NewEntry("test") + if err := e.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := e.AddPrefix("2001:db8::/32"); err != nil { + t.Fatal(err) + } + + // All prefixes + prefixes, err := e.MarshalPrefix() + if err != nil { + t.Errorf("MarshalPrefix error = %v", err) + } + if len(prefixes) != 2 { + t.Errorf("expected 2 prefixes, got %d", len(prefixes)) + } + + // Ignore IPv4 + prefixes, err = e.MarshalPrefix(IgnoreIPv4) + if err != nil { + t.Errorf("MarshalPrefix(IgnoreIPv4) error = %v", err) + } + for _, p := range prefixes { + if p.Addr().Is4() { + t.Error("should not contain IPv4 prefix when ignoring IPv4") + } + } + + // Ignore IPv6 + prefixes, err = e.MarshalPrefix(IgnoreIPv6) + if err != nil { + t.Errorf("MarshalPrefix(IgnoreIPv6) error = %v", err) + } + for _, p := range prefixes { + if p.Addr().Is6() && !p.Addr().Is4In6() { + t.Error("should not contain IPv6 prefix when ignoring IPv6") + } + } + + // With nil option + prefixes, err = e.MarshalPrefix(nil) + if err != nil { + t.Errorf("MarshalPrefix(nil) error = %v", err) + } + if len(prefixes) != 2 { + t.Errorf("expected 2 prefixes with nil option, got %d", len(prefixes)) + } +} + +func TestEntryMarshalPrefixEmpty(t *testing.T) { + e := NewEntry("test") + _, err := e.MarshalPrefix() + if err == nil { + t.Error("MarshalPrefix on empty entry should return error") + } +} + +func TestEntryMarshalIPRange(t *testing.T) { + e := NewEntry("test") + if err := e.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := e.AddPrefix("2001:db8::/32"); err != nil { + t.Fatal(err) + } + + // All ranges + ranges, err := e.MarshalIPRange() + if err != nil { + t.Errorf("MarshalIPRange error = %v", err) + } + if len(ranges) != 2 { + t.Errorf("expected 2 ranges, got %d", len(ranges)) + } + + // Ignore IPv4 + ranges, err = e.MarshalIPRange(IgnoreIPv4) + if err != nil { + t.Errorf("MarshalIPRange(IgnoreIPv4) error = %v", err) + } + if len(ranges) != 1 { + t.Errorf("expected 1 range, got %d", len(ranges)) + } + + // Ignore IPv6 + ranges, err = e.MarshalIPRange(IgnoreIPv6) + if err != nil { + t.Errorf("MarshalIPRange(IgnoreIPv6) error = %v", err) + } + if len(ranges) != 1 { + t.Errorf("expected 1 range, got %d", len(ranges)) + } + + // With nil option + ranges, err = e.MarshalIPRange(nil) + if err != nil { + t.Errorf("MarshalIPRange(nil) error = %v", err) + } + if len(ranges) != 2 { + t.Errorf("expected 2 ranges with nil option, got %d", len(ranges)) + } +} + +func TestEntryMarshalIPRangeEmpty(t *testing.T) { + e := NewEntry("test") + _, err := e.MarshalIPRange() + if err == nil { + t.Error("MarshalIPRange on empty entry should return error") + } +} + +func TestEntryMarshalText(t *testing.T) { + e := NewEntry("test") + if err := e.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := e.AddPrefix("2001:db8::/32"); err != nil { + t.Fatal(err) + } + + // All text + text, err := e.MarshalText() + if err != nil { + t.Errorf("MarshalText error = %v", err) + } + if len(text) != 2 { + t.Errorf("expected 2 text entries, got %d", len(text)) + } + + // Ignore IPv4 + text, err = e.MarshalText(IgnoreIPv4) + if err != nil { + t.Errorf("MarshalText(IgnoreIPv4) error = %v", err) + } + if len(text) != 1 { + t.Errorf("expected 1 text entry, got %d", len(text)) + } + + // Ignore IPv6 + text, err = e.MarshalText(IgnoreIPv6) + if err != nil { + t.Errorf("MarshalText(IgnoreIPv6) error = %v", err) + } + if len(text) != 1 { + t.Errorf("expected 1 text entry, got %d", len(text)) + } + + // With nil option + text, err = e.MarshalText(nil) + if err != nil { + t.Errorf("MarshalText(nil) error = %v", err) + } + if len(text) != 2 { + t.Errorf("expected 2 text entries with nil option, got %d", len(text)) + } +} + +func TestEntryMarshalTextEmpty(t *testing.T) { + e := NewEntry("test") + _, err := e.MarshalText() + if err == nil { + t.Error("MarshalText on empty entry should return error") + } +} + +func TestEntryAddInvalidIPType(t *testing.T) { + e := NewEntry("test") + prefix := netip.MustParsePrefix("1.0.0.0/24") + err := e.add(&prefix, IPType("invalid")) + if err != ErrInvalidIPType { + t.Errorf("add with invalid IPType error = %v, want ErrInvalidIPType", err) + } +} + +func TestEntryRemoveInvalidIPType(t *testing.T) { + e := NewEntry("test") + prefix := netip.MustParsePrefix("1.0.0.0/24") + err := e.remove(&prefix, IPType("invalid")) + if err != ErrInvalidIPType { + t.Errorf("remove with invalid IPType error = %v, want ErrInvalidIPType", err) + } +} + +func TestProcessPrefixStringCIDRWithComment(t *testing.T) { + e := NewEntry("test") + + // CIDR with trailing comment + prefix, ipType, err := e.processPrefix("10.0.0.0/8 # comment") + if err != nil { + t.Errorf("processPrefix CIDR with comment error = %v", err) + } + if ipType != IPv4 { + t.Errorf("expected IPv4, got %q", ipType) + } + if prefix == nil { + t.Error("prefix should not be nil") + } + + // IP with trailing comment + prefix, ipType, err = e.processPrefix("10.0.0.1 // comment") + if err != nil { + t.Errorf("processPrefix IP with comment error = %v", err) + } + if ipType != IPv4 { + t.Errorf("expected IPv4, got %q", ipType) + } + if prefix == nil { + t.Error("prefix should not be nil") + } +} + +func TestProcessPrefixIPv4MappedIPv6(t *testing.T) { + e := NewEntry("test") + + // netip.Addr - IPv4-mapped IPv6 + mapped := netip.MustParseAddr("::ffff:1.2.3.4") + prefix, ipType, err := e.processPrefix(mapped) + if err != nil { + t.Errorf("processPrefix mapped addr error = %v", err) + } + if ipType != IPv4 { + t.Errorf("expected IPv4 for mapped address, got %q", ipType) + } + if prefix == nil { + t.Error("prefix should not be nil") + } + + // *netip.Addr - IPv4-mapped IPv6 + mapped2 := netip.MustParseAddr("::ffff:5.6.7.8") + prefix, ipType, err = e.processPrefix(&mapped2) + if err != nil { + t.Errorf("processPrefix *mapped addr error = %v", err) + } + if ipType != IPv4 { + t.Errorf("expected IPv4 for *mapped address, got %q", ipType) + } + if prefix == nil { + t.Error("prefix should not be nil") + } +} + +func TestProcessPrefixNetipPrefixIs4In6(t *testing.T) { + e := NewEntry("test") + + // netip.Prefix with IPv4-in-IPv6 address + p := netip.MustParsePrefix("::ffff:10.0.0.0/104") + prefix, ipType, err := e.processPrefix(p) + if err != nil { + t.Errorf("processPrefix 4in6 prefix error = %v", err) + } + if ipType != IPv4 { + t.Errorf("expected IPv4, got %q", ipType) + } + if prefix == nil { + t.Error("prefix should not be nil") + } + + // *netip.Prefix with IPv4-in-IPv6 address + pp := netip.MustParsePrefix("::ffff:10.0.0.0/104") + prefix, ipType, err = e.processPrefix(&pp) + if err != nil { + t.Errorf("processPrefix *4in6 prefix error = %v", err) + } + if ipType != IPv4 { + t.Errorf("expected IPv4, got %q", ipType) + } + if prefix == nil { + t.Error("prefix should not be nil") + } +} + +func TestProcessPrefixNetipPrefixIs4In6InvalidBits(t *testing.T) { + e := NewEntry("test") + + // netip.Prefix with IPv4-in-IPv6 address and bits < 96 + p := netip.MustParsePrefix("::ffff:0.0.0.0/80") + _, _, err := e.processPrefix(p) + if err != ErrInvalidPrefix { + t.Errorf("expected ErrInvalidPrefix for 4in6 with bits < 96, got %v", err) + } + + // *netip.Prefix with IPv4-in-IPv6 address and bits < 96 + pp := netip.MustParsePrefix("::ffff:0.0.0.0/80") + _, _, err = e.processPrefix(&pp) + if err != ErrInvalidPrefix { + t.Errorf("expected ErrInvalidPrefix for *4in6 with bits < 96, got %v", err) + } +} + +func TestProcessPrefixStringCIDRIPv6(t *testing.T) { + e := NewEntry("test") + // Test normal IPv6 CIDR + prefix, ipType, err := e.processPrefix("fe80::/10") + if err != nil { + t.Errorf("processPrefix IPv6 CIDR error = %v", err) + } + if ipType != IPv6 { + t.Errorf("expected IPv6, got %q", ipType) + } + if prefix == nil { + t.Error("prefix should not be nil") + } +} + +func TestProcessPrefixNetIPIPv4(t *testing.T) { + e := NewEntry("test") + + // net.IP with IPv4 that is 16-byte form (IPv4-mapped IPv6) + ip := net.ParseIP("1.2.3.4") + prefix, ipType, err := e.processPrefix(ip) + if err != nil { + t.Errorf("processPrefix(net.IP v4) error = %v", err) + } + if ipType != IPv4 { + t.Errorf("expected IPv4, got %q", ipType) + } + if prefix == nil { + t.Error("prefix should not be nil") + } +} + +func TestEntryBuildIPSetIdempotent(t *testing.T) { + e := NewEntry("test") + if err := e.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + if err := e.AddPrefix("2001:db8::/32"); err != nil { + t.Fatal(err) + } + + // Build once + if err := e.buildIPSet(); err != nil { + t.Fatal(err) + } + set4 := e.ipv4Set + set6 := e.ipv6Set + + // Build again - should reuse + if err := e.buildIPSet(); err != nil { + t.Fatal(err) + } + if e.ipv4Set != set4 { + t.Error("ipv4Set should be reused on second build") + } + if e.ipv6Set != set6 { + t.Error("ipv6Set should be reused on second build") + } +} + +func TestEntryMarshalPrefixOnlyIPv4(t *testing.T) { + e := NewEntry("test") + if err := e.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + + prefixes, err := e.MarshalPrefix() + if err != nil { + t.Errorf("MarshalPrefix error = %v", err) + } + if len(prefixes) != 1 { + t.Errorf("expected 1 prefix, got %d", len(prefixes)) + } + + // Ignore IPv4 -> should fail because only IPv4 exists + _, err = e.MarshalPrefix(IgnoreIPv4) + if err == nil { + t.Error("MarshalPrefix(IgnoreIPv4) should error when only IPv4 data exists") + } +} + +func TestEntryMarshalPrefixOnlyIPv6(t *testing.T) { + e := NewEntry("test") + if err := e.AddPrefix("2001:db8::/32"); err != nil { + t.Fatal(err) + } + + prefixes, err := e.MarshalPrefix() + if err != nil { + t.Errorf("MarshalPrefix error = %v", err) + } + if len(prefixes) != 1 { + t.Errorf("expected 1 prefix, got %d", len(prefixes)) + } + + // Ignore IPv6 -> should fail because only IPv6 exists + _, err = e.MarshalPrefix(IgnoreIPv6) + if err == nil { + t.Error("MarshalPrefix(IgnoreIPv6) should error when only IPv6 data exists") + } +} + +func TestEntryMarshalIPRangeOnlyIPv4(t *testing.T) { + e := NewEntry("test") + if err := e.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + + // Ignore IPv4 -> should fail + _, err := e.MarshalIPRange(IgnoreIPv4) + if err == nil { + t.Error("MarshalIPRange(IgnoreIPv4) should error when only IPv4 data exists") + } +} + +func TestEntryMarshalTextOnlyIPv4(t *testing.T) { + e := NewEntry("test") + if err := e.AddPrefix("1.0.0.0/24"); err != nil { + t.Fatal(err) + } + + // Ignore IPv4 -> should fail + _, err := e.MarshalText(IgnoreIPv4) + if err == nil { + t.Error("MarshalText(IgnoreIPv4) should error when only IPv4 data exists") + } +} + +func TestProcessPrefixInvalidNetIPNet(t *testing.T) { + e := NewEntry("test") + // Create an invalid *net.IPNet + invalidIPNet := &net.IPNet{ + IP: nil, + Mask: nil, + } + _, _, err := e.processPrefix(invalidIPNet) + if err == nil { + t.Error("processPrefix with invalid *net.IPNet should return error") + } +} diff --git a/lib/error_test.go b/lib/error_test.go new file mode 100644 index 00000000000..1de19d1a156 --- /dev/null +++ b/lib/error_test.go @@ -0,0 +1,36 @@ +package lib + +import ( + "testing" +) + +func TestErrorMessages(t *testing.T) { + tests := []struct { + name string + err error + expected string + }{ + {"ErrDuplicatedConverter", ErrDuplicatedConverter, "duplicated converter"}, + {"ErrUnknownAction", ErrUnknownAction, "unknown action"}, + {"ErrNotSupportedFormat", ErrNotSupportedFormat, "not supported format"}, + {"ErrInvalidIPType", ErrInvalidIPType, "invalid IP type"}, + {"ErrInvalidIP", ErrInvalidIP, "invalid IP address"}, + {"ErrInvalidIPLength", ErrInvalidIPLength, "invalid IP address length"}, + {"ErrInvalidIPNet", ErrInvalidIPNet, "invalid IPNet address"}, + {"ErrInvalidCIDR", ErrInvalidCIDR, "invalid CIDR"}, + {"ErrInvalidPrefix", ErrInvalidPrefix, "invalid prefix"}, + {"ErrInvalidPrefixType", ErrInvalidPrefixType, "invalid prefix type"}, + {"ErrCommentLine", ErrCommentLine, "comment line"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.err == nil { + t.Fatal("error should not be nil") + } + if tt.err.Error() != tt.expected { + t.Errorf("expected %q, got %q", tt.expected, tt.err.Error()) + } + }) + } +} diff --git a/lib/instance_test.go b/lib/instance_test.go new file mode 100644 index 00000000000..d60982ba551 --- /dev/null +++ b/lib/instance_test.go @@ -0,0 +1,403 @@ +package lib + +import ( + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" +) + +func TestNewInstance(t *testing.T) { + inst, err := NewInstance() + if err != nil { + t.Fatalf("NewInstance error = %v", err) + } + if inst == nil { + t.Fatal("NewInstance returned nil") + } +} + +func TestInstanceAddInputOutput(t *testing.T) { + inst, _ := NewInstance() + + ic := &mockInputConverter{typeName: "test", action: ActionAdd, description: "test"} + oc := &mockOutputConverter{typeName: "test", action: ActionOutput, description: "test"} + + inst.AddInput(ic) + inst.AddOutput(oc) + + // Verify through Run (it should work since both exist) + err := inst.Run() + if err != nil { + t.Errorf("Run error = %v", err) + } +} + +func TestInstanceResetInputOutput(t *testing.T) { + inst, _ := NewInstance() + + ic := &mockInputConverter{typeName: "test", action: ActionAdd, description: "test"} + oc := &mockOutputConverter{typeName: "test", action: ActionOutput, description: "test"} + + inst.AddInput(ic) + inst.AddOutput(oc) + + inst.ResetInput() + inst.ResetOutput() + + // Should fail because both are now empty + err := inst.Run() + if err == nil { + t.Error("expected error after reset") + } +} + +func TestInstanceRunNoInput(t *testing.T) { + inst, _ := NewInstance() + oc := &mockOutputConverter{typeName: "test", action: ActionOutput, description: "test"} + inst.AddOutput(oc) + + err := inst.Run() + if err == nil { + t.Error("expected error when no input") + } +} + +func TestInstanceRunNoOutput(t *testing.T) { + inst, _ := NewInstance() + ic := &mockInputConverter{typeName: "test", action: ActionAdd, description: "test"} + inst.AddInput(ic) + + err := inst.Run() + if err == nil { + t.Error("expected error when no output") + } +} + +func TestInstanceRunInputError(t *testing.T) { + inst, _ := NewInstance() + + ic := &mockInputConverter{ + typeName: "test", action: ActionAdd, description: "test", + inputFn: func(c Container) (Container, error) { + return nil, errors.New("input error") + }, + } + oc := &mockOutputConverter{typeName: "test", action: ActionOutput, description: "test"} + + inst.AddInput(ic) + inst.AddOutput(oc) + + err := inst.Run() + if err == nil { + t.Error("expected input error") + } + if err.Error() != "input error" { + t.Errorf("unexpected error: %v", err) + } +} + +func TestInstanceRunOutputError(t *testing.T) { + inst, _ := NewInstance() + + ic := &mockInputConverter{typeName: "test", action: ActionAdd, description: "test"} + oc := &mockOutputConverter{ + typeName: "test", action: ActionOutput, description: "test", + outputFn: func(c Container) error { + return errors.New("output error") + }, + } + + inst.AddInput(ic) + inst.AddOutput(oc) + + err := inst.Run() + if err == nil { + t.Error("expected output error") + } + if err.Error() != "output error" { + t.Errorf("unexpected error: %v", err) + } +} + +func TestInstanceRunInput(t *testing.T) { + inst, _ := NewInstance() + + ic := &mockInputConverter{typeName: "test", action: ActionAdd, description: "test"} + inst.AddInput(ic) + + container := NewContainer() + if err := inst.RunInput(container); err != nil { + t.Errorf("RunInput error = %v", err) + } +} + +func TestInstanceRunOutput(t *testing.T) { + inst, _ := NewInstance() + + oc := &mockOutputConverter{typeName: "test", action: ActionOutput, description: "test"} + inst.AddOutput(oc) + + container := NewContainer() + if err := inst.RunOutput(container); err != nil { + t.Errorf("RunOutput error = %v", err) + } +} + +func TestInstanceInitConfigFromBytes(t *testing.T) { + origInputCache := inputConfigCreatorCache + origOutputCache := outputConfigCreatorCache + inputConfigCreatorCache = make(map[string]inputConfigCreator) + outputConfigCreatorCache = make(map[string]outputConfigCreator) + defer func() { + inputConfigCreatorCache = origInputCache + outputConfigCreatorCache = origOutputCache + }() + + RegisterInputConfigCreator("test-input", func(action Action, data json.RawMessage) (InputConverter, error) { + return &mockInputConverter{typeName: "test-input", action: action, description: "test"}, nil + }) + RegisterOutputConfigCreator("test-output", func(action Action, data json.RawMessage) (OutputConverter, error) { + return &mockOutputConverter{typeName: "test-output", action: action, description: "test"}, nil + }) + + inst, _ := NewInstance() + + configJSON := `{ + "input": [{"type": "test-input", "action": "add", "args": {}}], + "output": [{"type": "test-output", "action": "output", "args": {}}] + }` + + if err := inst.InitConfigFromBytes([]byte(configJSON)); err != nil { + t.Errorf("InitConfigFromBytes error = %v", err) + } + + if err := inst.Run(); err != nil { + t.Errorf("Run after InitConfigFromBytes error = %v", err) + } +} + +func TestInstanceInitConfigFromBytesWithComments(t *testing.T) { + origInputCache := inputConfigCreatorCache + origOutputCache := outputConfigCreatorCache + inputConfigCreatorCache = make(map[string]inputConfigCreator) + outputConfigCreatorCache = make(map[string]outputConfigCreator) + defer func() { + inputConfigCreatorCache = origInputCache + outputConfigCreatorCache = origOutputCache + }() + + RegisterInputConfigCreator("test-input", func(action Action, data json.RawMessage) (InputConverter, error) { + return &mockInputConverter{typeName: "test-input", action: action, description: "test"}, nil + }) + RegisterOutputConfigCreator("test-output", func(action Action, data json.RawMessage) (OutputConverter, error) { + return &mockOutputConverter{typeName: "test-output", action: action, description: "test"}, nil + }) + + inst, _ := NewInstance() + + // JSON with comments and trailing commas (hujson format) + configJSON := `{ + // This is a comment + "input": [{"type": "test-input", "action": "add", "args": {}}], + "output": [{"type": "test-output", "action": "output", "args": {}},], + }` + + if err := inst.InitConfigFromBytes([]byte(configJSON)); err != nil { + t.Errorf("InitConfigFromBytes with comments error = %v", err) + } +} + +func TestInstanceInitConfigFromBytesInvalidJSON(t *testing.T) { + inst, _ := NewInstance() + if err := inst.InitConfigFromBytes([]byte(`{invalid json`)); err == nil { + t.Error("expected error for invalid JSON") + } +} + +func TestInstanceInitConfigFromFile(t *testing.T) { + origInputCache := inputConfigCreatorCache + origOutputCache := outputConfigCreatorCache + inputConfigCreatorCache = make(map[string]inputConfigCreator) + outputConfigCreatorCache = make(map[string]outputConfigCreator) + defer func() { + inputConfigCreatorCache = origInputCache + outputConfigCreatorCache = origOutputCache + }() + + RegisterInputConfigCreator("test-input", func(action Action, data json.RawMessage) (InputConverter, error) { + return &mockInputConverter{typeName: "test-input", action: action, description: "test"}, nil + }) + RegisterOutputConfigCreator("test-output", func(action Action, data json.RawMessage) (OutputConverter, error) { + return &mockOutputConverter{typeName: "test-output", action: action, description: "test"}, nil + }) + + // Write config to temp file + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.json") + configJSON := `{ + "input": [{"type": "test-input", "action": "add", "args": {}}], + "output": [{"type": "test-output", "action": "output", "args": {}}] + }` + if err := os.WriteFile(configPath, []byte(configJSON), 0644); err != nil { + t.Fatal(err) + } + + inst, _ := NewInstance() + if err := inst.InitConfig(configPath); err != nil { + t.Errorf("InitConfig error = %v", err) + } +} + +func TestInstanceInitConfigFromFileNotFound(t *testing.T) { + inst, _ := NewInstance() + if err := inst.InitConfig("/nonexistent/path/config.json"); err == nil { + t.Error("expected error for nonexistent file") + } +} + +func TestInstanceInitConfigFromURL(t *testing.T) { + origInputCache := inputConfigCreatorCache + origOutputCache := outputConfigCreatorCache + inputConfigCreatorCache = make(map[string]inputConfigCreator) + outputConfigCreatorCache = make(map[string]outputConfigCreator) + defer func() { + inputConfigCreatorCache = origInputCache + outputConfigCreatorCache = origOutputCache + }() + + RegisterInputConfigCreator("test-input", func(action Action, data json.RawMessage) (InputConverter, error) { + return &mockInputConverter{typeName: "test-input", action: action, description: "test"}, nil + }) + RegisterOutputConfigCreator("test-output", func(action Action, data json.RawMessage) (OutputConverter, error) { + return &mockOutputConverter{typeName: "test-output", action: action, description: "test"}, nil + }) + + configJSON := `{ + "input": [{"type": "test-input", "action": "add", "args": {}}], + "output": [{"type": "test-output", "action": "output", "args": {}}] + }` + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(configJSON)) + })) + defer server.Close() + + inst, _ := NewInstance() + if err := inst.InitConfig(server.URL); err != nil { + t.Errorf("InitConfig from URL error = %v", err) + } +} + +func TestInstanceInitConfigFromURLError(t *testing.T) { + inst, _ := NewInstance() + if err := inst.InitConfig("http://invalid-host-that-does-not-exist.example.com"); err == nil { + t.Error("expected error for invalid URL") + } +} + +func TestInstanceRunInputMultiple(t *testing.T) { + inst, _ := NewInstance() + + callOrder := make([]string, 0) + + ic1 := &mockInputConverter{ + typeName: "test1", action: ActionAdd, description: "test1", + inputFn: func(c Container) (Container, error) { + callOrder = append(callOrder, "input1") + return c, nil + }, + } + ic2 := &mockInputConverter{ + typeName: "test2", action: ActionAdd, description: "test2", + inputFn: func(c Container) (Container, error) { + callOrder = append(callOrder, "input2") + return c, nil + }, + } + + inst.AddInput(ic1) + inst.AddInput(ic2) + + container := NewContainer() + if err := inst.RunInput(container); err != nil { + t.Errorf("RunInput error = %v", err) + } + + if len(callOrder) != 2 || callOrder[0] != "input1" || callOrder[1] != "input2" { + t.Errorf("unexpected call order: %v", callOrder) + } +} + +func TestInstanceRunOutputMultiple(t *testing.T) { + inst, _ := NewInstance() + + callOrder := make([]string, 0) + + oc1 := &mockOutputConverter{ + typeName: "test1", action: ActionOutput, description: "test1", + outputFn: func(c Container) error { + callOrder = append(callOrder, "output1") + return nil + }, + } + oc2 := &mockOutputConverter{ + typeName: "test2", action: ActionOutput, description: "test2", + outputFn: func(c Container) error { + callOrder = append(callOrder, "output2") + return nil + }, + } + + inst.AddOutput(oc1) + inst.AddOutput(oc2) + + container := NewContainer() + if err := inst.RunOutput(container); err != nil { + t.Errorf("RunOutput error = %v", err) + } + + if len(callOrder) != 2 || callOrder[0] != "output1" || callOrder[1] != "output2" { + t.Errorf("unexpected call order: %v", callOrder) + } +} + +func TestInstanceRunInputDirectError(t *testing.T) { + inst, _ := NewInstance() + + ic := &mockInputConverter{ + typeName: "test", action: ActionAdd, description: "test", + inputFn: func(c Container) (Container, error) { + return nil, errors.New("run input error") + }, + } + + inst.AddInput(ic) + + container := NewContainer() + if err := inst.RunInput(container); err == nil { + t.Error("expected RunInput error") + } +} + +func TestInstanceRunOutputDirectError(t *testing.T) { + inst, _ := NewInstance() + + oc := &mockOutputConverter{ + typeName: "test", action: ActionOutput, description: "test", + outputFn: func(c Container) error { + return errors.New("run output error") + }, + } + + inst.AddOutput(oc) + + container := NewContainer() + if err := inst.RunOutput(container); err == nil { + t.Error("expected RunOutput error") + } +} diff --git a/lib/lib_test.go b/lib/lib_test.go new file mode 100644 index 00000000000..0f2f5cecbab --- /dev/null +++ b/lib/lib_test.go @@ -0,0 +1,70 @@ +package lib + +import ( + "testing" +) + +func TestConstants(t *testing.T) { + if ActionAdd != "add" { + t.Errorf("expected ActionAdd to be 'add', got %q", ActionAdd) + } + if ActionRemove != "remove" { + t.Errorf("expected ActionRemove to be 'remove', got %q", ActionRemove) + } + if ActionOutput != "output" { + t.Errorf("expected ActionOutput to be 'output', got %q", ActionOutput) + } + if IPv4 != "ipv4" { + t.Errorf("expected IPv4 to be 'ipv4', got %q", IPv4) + } + if IPv6 != "ipv6" { + t.Errorf("expected IPv6 to be 'ipv6', got %q", IPv6) + } + if CaseRemovePrefix != 0 { + t.Errorf("expected CaseRemovePrefix to be 0, got %d", CaseRemovePrefix) + } + if CaseRemoveEntry != 1 { + t.Errorf("expected CaseRemoveEntry to be 1, got %d", CaseRemoveEntry) + } +} + +func TestActionsRegistry(t *testing.T) { + expected := map[Action]bool{ + ActionAdd: true, + ActionRemove: true, + ActionOutput: true, + } + for action, val := range expected { + if ActionsRegistry[action] != val { + t.Errorf("expected ActionsRegistry[%q] to be %v", action, val) + } + } + if len(ActionsRegistry) != len(expected) { + t.Errorf("expected ActionsRegistry to have %d entries, got %d", len(expected), len(ActionsRegistry)) + } +} + +func TestIgnoreIPv4(t *testing.T) { + result := IgnoreIPv4() + if result != IPv4 { + t.Errorf("expected IgnoreIPv4() to return IPv4, got %q", result) + } +} + +func TestIgnoreIPv6(t *testing.T) { + result := IgnoreIPv6() + if result != IPv6 { + t.Errorf("expected IgnoreIPv6() to return IPv6, got %q", result) + } +} + +func TestIgnoreIPOptionType(t *testing.T) { + var opt IgnoreIPOption = IgnoreIPv4 + if opt() != IPv4 { + t.Error("IgnoreIPOption function should return IPv4") + } + opt = IgnoreIPv6 + if opt() != IPv6 { + t.Error("IgnoreIPOption function should return IPv6") + } +}