Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 14 additions & 19 deletions cli/cmd/import/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@ import (
"github.com/agntcy/dir/cli/presenter"
ctxUtils "github.com/agntcy/dir/cli/util/context"
"github.com/agntcy/dir/importer/config"
_ "github.com/agntcy/dir/importer/mcp" // Import MCP importer to trigger its init() function for auto-registration.
"github.com/agntcy/dir/importer/factory"
"github.com/agntcy/dir/importer/types"
"github.com/agntcy/dir/importer/types/factory"
"github.com/spf13/cobra"
)

const cidOutputFilePerm = 0o600

var Command = &cobra.Command{
Use: "import",
Short: "Import records from external registries",
Expand All @@ -32,7 +33,7 @@ The import command fetches records from the specified registry and pushes
them to DIR.

Examples:
# Import from MCP registry
# Import from MCP registry with default enrichment configuration
dirctl import --type=mcp --url=https://registry.modelcontextprotocol.io

# Import with filters
Expand All @@ -42,9 +43,6 @@ Examples:
# Preview without importing
dirctl import --type=mcp --url=https://registry.modelcontextprotocol.io --dry-run

# Import with default enrichment configuration
dirctl import --type=mcp --url=https://registry.modelcontextprotocol.io

# Use custom MCPHost configuration and prompt templates
dirctl import --type=mcp --url=https://registry.modelcontextprotocol.io \
--enrich-skills-prompt=/path/to/custom-skills-prompt.md \
Expand Down Expand Up @@ -84,7 +82,7 @@ func runImport(cmd *cobra.Command) error {
}

// Create importer instance from pre-initialized factory
importer, err := factory.Create(c, opts.Config)
importer, err := factory.Create(cmd.Context(), c, opts.Config)
if err != nil {
return fmt.Errorf("failed to create importer: %w", err)
}
Expand All @@ -102,17 +100,21 @@ func runImport(cmd *cobra.Command) error {

presenter.Printf(cmd, "\n")

result, err := importer.Run(cmd.Context(), opts.Config)
if err != nil {
return fmt.Errorf("import failed: %w", err)
var result *types.ImportResult

if opts.DryRun {
result = importer.DryRun(cmd.Context())
} else {
result = importer.Run(cmd.Context())
}

// Print summary
printSummary(cmd, result)

// Write CIDs to output file if specified
if opts.OutputCIDFile != "" && len(result.ImportedCIDs) > 0 {
if err := writeCIDsToFile(opts.OutputCIDFile, result.ImportedCIDs); err != nil {
content := strings.Join(result.ImportedCIDs, "\n") + "\n"

if err := os.WriteFile(opts.OutputCIDFile, []byte(content), cidOutputFilePerm); err != nil {
return fmt.Errorf("failed to write CIDs to file: %w", err)
}

Expand All @@ -122,13 +124,6 @@ func runImport(cmd *cobra.Command) error {
return nil
}

// writeCIDsToFile writes a list of CIDs to a file, one per line.
func writeCIDsToFile(path string, cids []string) error {
content := strings.Join(cids, "\n") + "\n"

return os.WriteFile(path, []byte(content), 0o600) //nolint:mnd,wrapcheck
}

func printSummary(cmd *cobra.Command, result *types.ImportResult) {
maxErrors := 10

Expand Down
16 changes: 7 additions & 9 deletions cli/cmd/import/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ package importcmd
import (
signcmd "github.com/agntcy/dir/cli/cmd/sign"
"github.com/agntcy/dir/importer/config"
"github.com/agntcy/dir/importer/enricher"
enricherconfig "github.com/agntcy/dir/importer/enricher/config"
scannerconfig "github.com/agntcy/dir/importer/scanner/config"
)

Expand All @@ -31,16 +31,14 @@ func init() {
flags.BoolVar(&opts.Force, "force", false, "Force push even if record already exists")
flags.BoolVar(&opts.Debug, "debug", false, "Enable debug output for deduplication and validation failures")

// Enrichment is mandatory - these flags configure the enrichment process
flags.StringVar(&opts.EnricherConfigFile, "enrich-config", enricher.DefaultConfigFile, "Path to MCPHost configuration file (mcphost.json)")
flags.StringVar(&opts.EnricherSkillsPromptTemplate, "enrich-skills-prompt", "", "Optional: path to custom skills prompt template file or inline prompt (empty = use default)")
flags.StringVar(&opts.EnricherDomainsPromptTemplate, "enrich-domains-prompt", "", "Optional: path to custom domains prompt template file or inline prompt (empty = use default)")

// Rate limiting for LLM API calls
flags.IntVar(&opts.EnricherRequestsPerMinute, "enrich-rate-limit", enricher.DefaultRequestsPerMinute, "Maximum LLM API requests per minute (to avoid rate limit errors)")
// Enrichment flags
flags.StringVar(&opts.Enricher.ConfigFile, "enrich-config", enricherconfig.DefaultConfigFile, "Path to MCPHost configuration file (mcphost.json)")
flags.StringVar(&opts.Enricher.SkillsPromptTemplate, "enrich-skills-prompt", "", "Path to custom skills prompt template file")
flags.StringVar(&opts.Enricher.DomainsPromptTemplate, "enrich-domains-prompt", "", "Path to custom domains prompt template file")
flags.IntVar(&opts.Enricher.RequestsPerMinute, "enrich-rate-limit", enricherconfig.DefaultRequestsPerMinute, "Maximum LLM API requests per minute (to avoid rate limit errors)")

// Scanner flags
flags.BoolVar(&opts.Scanner.Enabled, "scanner-enabled", false, "Run all registered security scanners on each record")
flags.BoolVar(&opts.Scanner.Enabled, "scanner-enabled", scannerconfig.DefaultScannerEnabled, "Run all registered security scanners on each record")
flags.DurationVar(&opts.Scanner.Timeout, "scanner-timeout", scannerconfig.DefaultTimeout, "Timeout per record scan")
flags.StringVar(&opts.Scanner.CLIPath, "scanner-cli-path", scannerconfig.DefaultCLIPath, "Path to mcp-scanner binary (default: mcp-scanner from PATH)")
flags.BoolVar(&opts.Scanner.FailOnError, "scanner-fail-on-error", scannerconfig.DefaultFailOnError, "Do not import records that have error-severity scanner findings")
Expand Down
13 changes: 2 additions & 11 deletions importer/Taskfile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,6 @@

version: "3"

vars:
# TODO: Switch to upstream cisco-ai-defense/mcp-scanner once the JSON fix is merged
MCP_SCANNER_REPO: "adamtagscherer/mcp-scanner"
MCP_SCANNER_BRANCH: "fix/handle-markdown-json-in-classifier"

tasks:
default:
cmd: echo "Run the main Taskfile instead of this one."
Expand All @@ -17,10 +12,6 @@ tasks:
deps:
- task: deps:uv
cmds:
- cmd: echo "Installing mcp-scanner from {{.MCP_SCANNER_REPO}}@{{.MCP_SCANNER_BRANCH}}..."
- cmd: echo "Installing mcp-scanner"
- cmd: >-
{{.UV_BIN}} tool install
--force
"cisco-ai-mcp-scanner @ git+https://github.com/{{.MCP_SCANNER_REPO}}.git@{{.MCP_SCANNER_BRANCH}}"
--with yara-python
--python-preference only-managed
{{.UV_BIN}} tool install --python 3.13 cisco-ai-mcp-scanner
23 changes: 5 additions & 18 deletions importer/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
corev1 "github.com/agntcy/dir/api/core/v1"
searchv1 "github.com/agntcy/dir/api/search/v1"
"github.com/agntcy/dir/client/streaming"
enricherconfig "github.com/agntcy/dir/importer/enricher/config"
scannerconfig "github.com/agntcy/dir/importer/scanner/config"
)

Expand All @@ -20,12 +21,6 @@ type RegistryType string
const (
// RegistryTypeMCP represents the Model Context Protocol registry.
RegistryTypeMCP RegistryType = "mcp"

// FUTURE: RegistryTypeNANDA represents the NANDA registry.
// RegistryTypeNANDA RegistryType = "nanda".

// FUTURE:RegistryTypeA2A represents the Agent-to-Agent protocol registry.
// RegistryTypeA2A RegistryType = "a2a".
)

// ClientInterface defines the interface for the DIR client used by importers.
Expand All @@ -45,22 +40,14 @@ type Config struct {
RegistryURL string // Base URL of the registry
Filters map[string]string // Registry-specific filters
Limit int // Number of records to import (default: 0 for all)
Concurrency int // Number of concurrent workers (default: 1)
DryRun bool // If true, preview without actually importing
SignFunc SignFunc // Function to sign records (if set, signing is enabled)

// Enrichment is mandatory - these fields are always used
EnricherConfigFile string // Path to MCPHost configuration file (e.g., mcphost.json)
EnricherSkillsPromptTemplate string // Optional: path to custom skills prompt template or inline prompt (empty = use default)
EnricherDomainsPromptTemplate string // Optional: path to custom domains prompt template or inline prompt (empty = use default)

// Rate limiting for LLM API calls to avoid provider rate limit errors
EnricherRequestsPerMinute int // Maximum LLM requests per minute (0 = use default of 10)

Force bool // If true, push even if record already exists
Debug bool // If true, enable verbose debug output

Scanner scannerconfig.Config // Scanner configuration
Enricher enricherconfig.Config // Configuration for the enricher pipeline stage
Scanner scannerconfig.Config // Configuration for the scanner pipeline stage
}

// Validate checks if the configuration is valid.
Expand All @@ -73,8 +60,8 @@ func (c *Config) Validate() error {
return errors.New("registry URL is required")
}

if c.Concurrency <= 0 {
c.Concurrency = 1 // Set default concurrency
if err := c.Enricher.Validate(); err != nil {
return fmt.Errorf("enricher configuration is invalid: %w", err)
}

if err := c.Scanner.Validate(); err != nil {
Expand Down
101 changes: 40 additions & 61 deletions importer/config/config_test.go
Original file line number Diff line number Diff line change
@@ -1,77 +1,56 @@
// Copyright AGNTCY Contributors (https://github.com/agntcy)
// SPDX-License-Identifier: Apache-2.0

//nolint:nilnil
package config

import (
"os"
"path/filepath"
"testing"

enricherconfig "github.com/agntcy/dir/importer/enricher/config"
scannerconfig "github.com/agntcy/dir/importer/scanner/config"
)

func TestConfig_Validate(t *testing.T) {
tests := []struct {
name string
config Config
wantErr bool
errMsg string
}{
{
name: "valid config",
config: Config{
RegistryType: RegistryTypeMCP,
RegistryURL: "https://registry.example.com",
Concurrency: 10,
},
wantErr: false,
},
{
name: "missing registry type",
config: Config{
RegistryURL: "https://registry.example.com",
Concurrency: 10,
},
wantErr: true,
errMsg: "registry type is required",
},
{
name: "missing registry URL",
config: Config{
RegistryType: RegistryTypeMCP,
Concurrency: 10,
},
wantErr: true,
errMsg: "registry URL is required",
},
{
name: "zero concurrency sets default",
config: Config{
RegistryType: RegistryTypeMCP,
RegistryURL: "https://registry.example.com",
Concurrency: 0,
},
wantErr: false,
},
func TestConfig_Validate_MissingRegistryType(t *testing.T) {
t.Parallel()

c := Config{RegistryURL: "https://x.com"}
if err := c.Validate(); err == nil {
t.Fatal("expected error for empty RegistryType")
}
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.config.Validate()
if (err != nil) != tt.wantErr {
t.Errorf("Config.Validate() error = %v, wantErr %v", err, tt.wantErr)
func TestConfig_Validate_MissingURL(t *testing.T) {
t.Parallel()

return
}
c := Config{RegistryType: RegistryTypeMCP}
if err := c.Validate(); err == nil {
t.Fatal("expected error for empty RegistryURL")
}
}

func TestConfig_Validate_OK(t *testing.T) {
t.Parallel()

if tt.wantErr && err.Error() != tt.errMsg {
t.Errorf("Config.Validate() error message = %v, want %v", err.Error(), tt.errMsg)
}
dir := t.TempDir()

cfgPath := filepath.Join(dir, "mcphost.json")
if err := os.WriteFile(cfgPath, []byte(`{}`), 0o600); err != nil {
t.Fatal(err)
}

c := Config{
RegistryType: RegistryTypeMCP,
RegistryURL: "https://registry.example.com",
Enricher: enricherconfig.Config{
ConfigFile: cfgPath,
RequestsPerMinute: 1,
},
Scanner: scannerconfig.Config{Enabled: false},
}

// Check that default concurrency is set when invalid
if !tt.wantErr && tt.name == "zero concurrency sets default" {
if tt.config.Concurrency != 1 {
t.Errorf("Config.Validate() did not set default concurrency, got %d, want 1", tt.config.Concurrency)
}
}
})
if err := c.Validate(); err != nil {
t.Fatalf("Validate: %v", err)
}
}
14 changes: 8 additions & 6 deletions importer/pipeline/dedup.go → importer/dedup/dedup.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright AGNTCY Contributors (https://github.com/agntcy)
// SPDX-License-Identifier: Apache-2.0

package pipeline
package dedup

import (
"context"
Expand All @@ -12,6 +12,8 @@ import (
corev1 "github.com/agntcy/dir/api/core/v1"
searchv1 "github.com/agntcy/dir/api/search/v1"
"github.com/agntcy/dir/importer/config"
"github.com/agntcy/dir/importer/shared"
"github.com/agntcy/dir/importer/types"
"github.com/agntcy/dir/utils/logging"
mcpapiv0 "github.com/modelcontextprotocol/registry/pkg/api/v0"
)
Expand Down Expand Up @@ -127,7 +129,7 @@ func (c *MCPDuplicateChecker) buildCache(ctx context.Context) error {
c.mu.Lock()

for _, record := range records {
nameVersion, err := ExtractNameVersion(record)
nameVersion, err := shared.ExtractNameVersion(record)
if err != nil {
continue
}
Expand Down Expand Up @@ -171,8 +173,8 @@ func (c *MCPDuplicateChecker) buildCache(ctx context.Context) error {
// It filters out duplicate records from the input channel and returns a channel
// with only non-duplicate records. It tracks only the skipped (duplicate) count.
// The transform stage will track the total records that are actually processed.
func (c *MCPDuplicateChecker) FilterDuplicates(ctx context.Context, inputCh <-chan any, result *Result) <-chan any {
outputCh := make(chan any)
func (c *MCPDuplicateChecker) FilterDuplicates(ctx context.Context, inputCh <-chan mcpapiv0.ServerResponse, result *types.Result) <-chan mcpapiv0.ServerResponse {
outputCh := make(chan mcpapiv0.ServerResponse)

go func() {
defer close(outputCh)
Expand All @@ -188,10 +190,10 @@ func (c *MCPDuplicateChecker) FilterDuplicates(ctx context.Context, inputCh <-ch

// Check if duplicate
if c.isDuplicate(source) {
result.mu.Lock()
result.Mu.Lock()
result.TotalRecords++
result.SkippedCount++
result.mu.Unlock()
result.Mu.Unlock()

continue
}
Expand Down
Loading
Loading