Skip to content

Commit 540777d

Browse files
refactor: importer
Signed-off-by: Tagscherer Ádám <adam.tagscherer@gmail.com>
1 parent 39de7d6 commit 540777d

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

47 files changed

+2609
-5295
lines changed

cli/cmd/import/import.go

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@ import (
1414
"github.com/agntcy/dir/cli/presenter"
1515
ctxUtils "github.com/agntcy/dir/cli/util/context"
1616
"github.com/agntcy/dir/importer/config"
17-
_ "github.com/agntcy/dir/importer/mcp" // Import MCP importer to trigger its init() function for auto-registration.
17+
"github.com/agntcy/dir/importer/factory"
1818
"github.com/agntcy/dir/importer/types"
19-
"github.com/agntcy/dir/importer/types/factory"
2019
"github.com/spf13/cobra"
2120
)
2221

22+
const cidOutputFilePerm = 0o600
23+
2324
var Command = &cobra.Command{
2425
Use: "import",
2526
Short: "Import records from external registries",
@@ -32,7 +33,7 @@ The import command fetches records from the specified registry and pushes
3233
them to DIR.
3334
3435
Examples:
35-
# Import from MCP registry
36+
# Import from MCP registry with default enrichment configuration
3637
dirctl import --type=mcp --url=https://registry.modelcontextprotocol.io
3738
3839
# Import with filters
@@ -42,9 +43,6 @@ Examples:
4243
# Preview without importing
4344
dirctl import --type=mcp --url=https://registry.modelcontextprotocol.io --dry-run
4445
45-
# Import with default enrichment configuration
46-
dirctl import --type=mcp --url=https://registry.modelcontextprotocol.io
47-
4846
# Use custom MCPHost configuration and prompt templates
4947
dirctl import --type=mcp --url=https://registry.modelcontextprotocol.io \
5048
--enrich-skills-prompt=/path/to/custom-skills-prompt.md \
@@ -84,7 +82,7 @@ func runImport(cmd *cobra.Command) error {
8482
}
8583

8684
// Create importer instance from pre-initialized factory
87-
importer, err := factory.Create(c, opts.Config)
85+
importer, err := factory.Create(cmd.Context(), c, opts.Config)
8886
if err != nil {
8987
return fmt.Errorf("failed to create importer: %w", err)
9088
}
@@ -102,17 +100,21 @@ func runImport(cmd *cobra.Command) error {
102100

103101
presenter.Printf(cmd, "\n")
104102

105-
result, err := importer.Run(cmd.Context(), opts.Config)
106-
if err != nil {
107-
return fmt.Errorf("import failed: %w", err)
103+
var result *types.ImportResult
104+
105+
if opts.DryRun {
106+
result = importer.DryRun(cmd.Context())
107+
} else {
108+
result = importer.Run(cmd.Context())
108109
}
109110

110-
// Print summary
111111
printSummary(cmd, result)
112112

113113
// Write CIDs to output file if specified
114114
if opts.OutputCIDFile != "" && len(result.ImportedCIDs) > 0 {
115-
if err := writeCIDsToFile(opts.OutputCIDFile, result.ImportedCIDs); err != nil {
115+
content := strings.Join(result.ImportedCIDs, "\n") + "\n"
116+
117+
if err := os.WriteFile(opts.OutputCIDFile, []byte(content), cidOutputFilePerm); err != nil {
116118
return fmt.Errorf("failed to write CIDs to file: %w", err)
117119
}
118120

@@ -122,13 +124,6 @@ func runImport(cmd *cobra.Command) error {
122124
return nil
123125
}
124126

125-
// writeCIDsToFile writes a list of CIDs to a file, one per line.
126-
func writeCIDsToFile(path string, cids []string) error {
127-
content := strings.Join(cids, "\n") + "\n"
128-
129-
return os.WriteFile(path, []byte(content), 0o600) //nolint:mnd,wrapcheck
130-
}
131-
132127
func printSummary(cmd *cobra.Command, result *types.ImportResult) {
133128
maxErrors := 10
134129

cli/cmd/import/options.go

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ package importcmd
66
import (
77
signcmd "github.com/agntcy/dir/cli/cmd/sign"
88
"github.com/agntcy/dir/importer/config"
9-
"github.com/agntcy/dir/importer/enricher"
9+
enricherconfig "github.com/agntcy/dir/importer/enricher/config"
1010
scannerconfig "github.com/agntcy/dir/importer/scanner/config"
1111
)
1212

@@ -31,16 +31,14 @@ func init() {
3131
flags.BoolVar(&opts.Force, "force", false, "Force push even if record already exists")
3232
flags.BoolVar(&opts.Debug, "debug", false, "Enable debug output for deduplication and validation failures")
3333

34-
// Enrichment is mandatory - these flags configure the enrichment process
35-
flags.StringVar(&opts.EnricherConfigFile, "enrich-config", enricher.DefaultConfigFile, "Path to MCPHost configuration file (mcphost.json)")
36-
flags.StringVar(&opts.EnricherSkillsPromptTemplate, "enrich-skills-prompt", "", "Optional: path to custom skills prompt template file or inline prompt (empty = use default)")
37-
flags.StringVar(&opts.EnricherDomainsPromptTemplate, "enrich-domains-prompt", "", "Optional: path to custom domains prompt template file or inline prompt (empty = use default)")
38-
39-
// Rate limiting for LLM API calls
40-
flags.IntVar(&opts.EnricherRequestsPerMinute, "enrich-rate-limit", enricher.DefaultRequestsPerMinute, "Maximum LLM API requests per minute (to avoid rate limit errors)")
34+
// Enrichment flags
35+
flags.StringVar(&opts.Enricher.ConfigFile, "enrich-config", enricherconfig.DefaultConfigFile, "Path to MCPHost configuration file (mcphost.json)")
36+
flags.StringVar(&opts.Enricher.SkillsPromptTemplate, "enrich-skills-prompt", "", "Path to custom skills prompt template file")
37+
flags.StringVar(&opts.Enricher.DomainsPromptTemplate, "enrich-domains-prompt", "", "Path to custom domains prompt template file")
38+
flags.IntVar(&opts.Enricher.RequestsPerMinute, "enrich-rate-limit", enricherconfig.DefaultRequestsPerMinute, "Maximum LLM API requests per minute (to avoid rate limit errors)")
4139

4240
// Scanner flags
43-
flags.BoolVar(&opts.Scanner.Enabled, "scanner-enabled", false, "Run all registered security scanners on each record")
41+
flags.BoolVar(&opts.Scanner.Enabled, "scanner-enabled", scannerconfig.DefaultScannerEnabled, "Run all registered security scanners on each record")
4442
flags.DurationVar(&opts.Scanner.Timeout, "scanner-timeout", scannerconfig.DefaultTimeout, "Timeout per record scan")
4543
flags.StringVar(&opts.Scanner.CLIPath, "scanner-cli-path", scannerconfig.DefaultCLIPath, "Path to mcp-scanner binary (default: mcp-scanner from PATH)")
4644
flags.BoolVar(&opts.Scanner.FailOnError, "scanner-fail-on-error", scannerconfig.DefaultFailOnError, "Do not import records that have error-severity scanner findings")

importer/Taskfile.yml

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,6 @@
33

44
version: "3"
55

6-
vars:
7-
# TODO: Switch to upstream cisco-ai-defense/mcp-scanner once the JSON fix is merged
8-
MCP_SCANNER_REPO: "adamtagscherer/mcp-scanner"
9-
MCP_SCANNER_BRANCH: "fix/handle-markdown-json-in-classifier"
10-
116
tasks:
127
default:
138
cmd: echo "Run the main Taskfile instead of this one."
@@ -17,10 +12,6 @@ tasks:
1712
deps:
1813
- task: deps:uv
1914
cmds:
20-
- cmd: echo "Installing mcp-scanner from {{.MCP_SCANNER_REPO}}@{{.MCP_SCANNER_BRANCH}}..."
15+
- cmd: echo "Installing mcp-scanner"
2116
- cmd: >-
22-
{{.UV_BIN}} tool install
23-
--force
24-
"cisco-ai-mcp-scanner @ git+https://github.com/{{.MCP_SCANNER_REPO}}.git@{{.MCP_SCANNER_BRANCH}}"
25-
--with yara-python
26-
--python-preference only-managed
17+
{{.UV_BIN}} tool install --python 3.13 cisco-ai-mcp-scanner

importer/config/config.go

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
corev1 "github.com/agntcy/dir/api/core/v1"
1212
searchv1 "github.com/agntcy/dir/api/search/v1"
1313
"github.com/agntcy/dir/client/streaming"
14+
enricherconfig "github.com/agntcy/dir/importer/enricher/config"
1415
scannerconfig "github.com/agntcy/dir/importer/scanner/config"
1516
)
1617

@@ -20,12 +21,6 @@ type RegistryType string
2021
const (
2122
// RegistryTypeMCP represents the Model Context Protocol registry.
2223
RegistryTypeMCP RegistryType = "mcp"
23-
24-
// FUTURE: RegistryTypeNANDA represents the NANDA registry.
25-
// RegistryTypeNANDA RegistryType = "nanda".
26-
27-
// FUTURE:RegistryTypeA2A represents the Agent-to-Agent protocol registry.
28-
// RegistryTypeA2A RegistryType = "a2a".
2924
)
3025

3126
// ClientInterface defines the interface for the DIR client used by importers.
@@ -45,22 +40,14 @@ type Config struct {
4540
RegistryURL string // Base URL of the registry
4641
Filters map[string]string // Registry-specific filters
4742
Limit int // Number of records to import (default: 0 for all)
48-
Concurrency int // Number of concurrent workers (default: 1)
4943
DryRun bool // If true, preview without actually importing
5044
SignFunc SignFunc // Function to sign records (if set, signing is enabled)
5145

52-
// Enrichment is mandatory - these fields are always used
53-
EnricherConfigFile string // Path to MCPHost configuration file (e.g., mcphost.json)
54-
EnricherSkillsPromptTemplate string // Optional: path to custom skills prompt template or inline prompt (empty = use default)
55-
EnricherDomainsPromptTemplate string // Optional: path to custom domains prompt template or inline prompt (empty = use default)
56-
57-
// Rate limiting for LLM API calls to avoid provider rate limit errors
58-
EnricherRequestsPerMinute int // Maximum LLM requests per minute (0 = use default of 10)
59-
6046
Force bool // If true, push even if record already exists
6147
Debug bool // If true, enable verbose debug output
6248

63-
Scanner scannerconfig.Config // Scanner configuration
49+
Enricher enricherconfig.Config // Configuration for the enricher pipeline stage
50+
Scanner scannerconfig.Config // Configuration for the scanner pipeline stage
6451
}
6552

6653
// Validate checks if the configuration is valid.
@@ -73,8 +60,8 @@ func (c *Config) Validate() error {
7360
return errors.New("registry URL is required")
7461
}
7562

76-
if c.Concurrency <= 0 {
77-
c.Concurrency = 1 // Set default concurrency
63+
if err := c.Enricher.Validate(); err != nil {
64+
return fmt.Errorf("enricher configuration is invalid: %w", err)
7865
}
7966

8067
if err := c.Scanner.Validate(); err != nil {

importer/config/config_test.go

Lines changed: 40 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,77 +1,56 @@
11
// Copyright AGNTCY Contributors (https://github.com/agntcy)
22
// SPDX-License-Identifier: Apache-2.0
33

4-
//nolint:nilnil
54
package config
65

76
import (
7+
"os"
8+
"path/filepath"
89
"testing"
10+
11+
enricherconfig "github.com/agntcy/dir/importer/enricher/config"
12+
scannerconfig "github.com/agntcy/dir/importer/scanner/config"
913
)
1014

11-
func TestConfig_Validate(t *testing.T) {
12-
tests := []struct {
13-
name string
14-
config Config
15-
wantErr bool
16-
errMsg string
17-
}{
18-
{
19-
name: "valid config",
20-
config: Config{
21-
RegistryType: RegistryTypeMCP,
22-
RegistryURL: "https://registry.example.com",
23-
Concurrency: 10,
24-
},
25-
wantErr: false,
26-
},
27-
{
28-
name: "missing registry type",
29-
config: Config{
30-
RegistryURL: "https://registry.example.com",
31-
Concurrency: 10,
32-
},
33-
wantErr: true,
34-
errMsg: "registry type is required",
35-
},
36-
{
37-
name: "missing registry URL",
38-
config: Config{
39-
RegistryType: RegistryTypeMCP,
40-
Concurrency: 10,
41-
},
42-
wantErr: true,
43-
errMsg: "registry URL is required",
44-
},
45-
{
46-
name: "zero concurrency sets default",
47-
config: Config{
48-
RegistryType: RegistryTypeMCP,
49-
RegistryURL: "https://registry.example.com",
50-
Concurrency: 0,
51-
},
52-
wantErr: false,
53-
},
15+
func TestConfig_Validate_MissingRegistryType(t *testing.T) {
16+
t.Parallel()
17+
18+
c := Config{RegistryURL: "https://x.com"}
19+
if err := c.Validate(); err == nil {
20+
t.Fatal("expected error for empty RegistryType")
5421
}
22+
}
5523

56-
for _, tt := range tests {
57-
t.Run(tt.name, func(t *testing.T) {
58-
err := tt.config.Validate()
59-
if (err != nil) != tt.wantErr {
60-
t.Errorf("Config.Validate() error = %v, wantErr %v", err, tt.wantErr)
24+
func TestConfig_Validate_MissingURL(t *testing.T) {
25+
t.Parallel()
6126

62-
return
63-
}
27+
c := Config{RegistryType: RegistryTypeMCP}
28+
if err := c.Validate(); err == nil {
29+
t.Fatal("expected error for empty RegistryURL")
30+
}
31+
}
32+
33+
func TestConfig_Validate_OK(t *testing.T) {
34+
t.Parallel()
6435

65-
if tt.wantErr && err.Error() != tt.errMsg {
66-
t.Errorf("Config.Validate() error message = %v, want %v", err.Error(), tt.errMsg)
67-
}
36+
dir := t.TempDir()
37+
38+
cfgPath := filepath.Join(dir, "mcphost.json")
39+
if err := os.WriteFile(cfgPath, []byte(`{}`), 0o600); err != nil {
40+
t.Fatal(err)
41+
}
42+
43+
c := Config{
44+
RegistryType: RegistryTypeMCP,
45+
RegistryURL: "https://registry.example.com",
46+
Enricher: enricherconfig.Config{
47+
ConfigFile: cfgPath,
48+
RequestsPerMinute: 1,
49+
},
50+
Scanner: scannerconfig.Config{Enabled: false},
51+
}
6852

69-
// Check that default concurrency is set when invalid
70-
if !tt.wantErr && tt.name == "zero concurrency sets default" {
71-
if tt.config.Concurrency != 1 {
72-
t.Errorf("Config.Validate() did not set default concurrency, got %d, want 1", tt.config.Concurrency)
73-
}
74-
}
75-
})
53+
if err := c.Validate(); err != nil {
54+
t.Fatalf("Validate: %v", err)
7655
}
7756
}
Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// Copyright AGNTCY Contributors (https://github.com/agntcy)
22
// SPDX-License-Identifier: Apache-2.0
33

4-
package pipeline
4+
package dedup
55

66
import (
77
"context"
@@ -12,6 +12,8 @@ import (
1212
corev1 "github.com/agntcy/dir/api/core/v1"
1313
searchv1 "github.com/agntcy/dir/api/search/v1"
1414
"github.com/agntcy/dir/importer/config"
15+
"github.com/agntcy/dir/importer/shared"
16+
"github.com/agntcy/dir/importer/types"
1517
"github.com/agntcy/dir/utils/logging"
1618
mcpapiv0 "github.com/modelcontextprotocol/registry/pkg/api/v0"
1719
)
@@ -127,7 +129,7 @@ func (c *MCPDuplicateChecker) buildCache(ctx context.Context) error {
127129
c.mu.Lock()
128130

129131
for _, record := range records {
130-
nameVersion, err := ExtractNameVersion(record)
132+
nameVersion, err := shared.ExtractNameVersion(record)
131133
if err != nil {
132134
continue
133135
}
@@ -171,8 +173,8 @@ func (c *MCPDuplicateChecker) buildCache(ctx context.Context) error {
171173
// It filters out duplicate records from the input channel and returns a channel
172174
// with only non-duplicate records. It tracks only the skipped (duplicate) count.
173175
// The transform stage will track the total records that are actually processed.
174-
func (c *MCPDuplicateChecker) FilterDuplicates(ctx context.Context, inputCh <-chan any, result *Result) <-chan any {
175-
outputCh := make(chan any)
176+
func (c *MCPDuplicateChecker) FilterDuplicates(ctx context.Context, inputCh <-chan mcpapiv0.ServerResponse, result *types.Result) <-chan mcpapiv0.ServerResponse {
177+
outputCh := make(chan mcpapiv0.ServerResponse)
176178

177179
go func() {
178180
defer close(outputCh)
@@ -188,10 +190,10 @@ func (c *MCPDuplicateChecker) FilterDuplicates(ctx context.Context, inputCh <-ch
188190

189191
// Check if duplicate
190192
if c.isDuplicate(source) {
191-
result.mu.Lock()
193+
result.Mu.Lock()
192194
result.TotalRecords++
193195
result.SkippedCount++
194-
result.mu.Unlock()
196+
result.Mu.Unlock()
195197

196198
continue
197199
}

0 commit comments

Comments
 (0)