Skip to content

Commit c84aee8

Browse files
committed
fix: address review feedback (concurrency, empty layers, helpers, docs)
1 parent 6c32a87 commit c84aee8

5 files changed

Lines changed: 73 additions & 18 deletions

File tree

docs/custom-scanners.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ type Scanner interface {
1717
| `Text` | `string` | Normalized text after decoding/normalization. |
1818
| `RawText` | `string` | Original text before normalization. |
1919
| `URL` | `string` | Source URL when available. |
20-
| `Mode` | `idpishield.Mode` | Current scan mode (`fast`, `balanced`, `deep`). |
20+
| `Mode` | `idpishield.Mode` | Current scan mode (`fast`, `balanced`, `deep`, `strict`). |
2121
| `IsOutputScan` | `bool` | True when called from `AssessOutput()`. |
2222
| `CurrentScore` | `int` | Score accumulated by built-ins before your scanner runs. |
2323

@@ -31,6 +31,9 @@ type Scanner interface {
3131
| `PatternID` | `string` | Optional pattern ID for audit trails. |
3232
| `Metadata` | `map[string]string` | Optional metadata for debugging. |
3333

34+
`ScanContext.Mode` controls analysis pipeline depth (including `strict` full-pipeline execution).
35+
This is different from `Config.StrictMode`, which only changes blocking thresholds.
36+
3437
## Writing Your First Scanner
3538
1. Define a scanner struct that implements `Name()` and `Scan()`.
3639
2. In `Scan()`, inspect `ctx.Text` using helper utilities.

idpishield.go

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -616,24 +616,26 @@ func (s *Shield) WithScanners(names ...string) *Shield {
616616
return s
617617
}
618618

619-
s.mu.Lock()
620-
defer s.mu.Unlock()
619+
s.mu.RLock()
620+
baseCfg := s.baseCfg
621+
registry := cloneScannerRegistry(s.scannerRegistry)
622+
s.mu.RUnlock()
621623

622624
selected := make([]Scanner, 0, len(names))
623625
for _, name := range names {
624626
key := strings.ToLower(strings.TrimSpace(name))
625627
if key == "" {
626628
continue
627629
}
628-
scanner, ok := s.scannerRegistry[key]
630+
scanner, ok := registry[key]
629631
if !ok || scanner == nil {
630632
continue
631633
}
632634
selected = append(selected, scanner)
633635
}
634636

635-
baseExtras := append([]Scanner(nil), s.baseCfg.ExtraScanners...)
636-
cfg := s.baseCfg
637+
baseExtras := append([]Scanner(nil), baseCfg.ExtraScanners...)
638+
cfg := baseCfg
637639
cfg.ExtraScanners = mergeScannersByName(baseExtras, selected)
638640

639641
resolvedCfg, err := engine.ResolveConfig(toEngineCfg(cfg))
@@ -644,8 +646,11 @@ func (s *Shield) WithScanners(names ...string) *Shield {
644646
return s
645647
}
646648

647-
s.engine = engine.New(resolvedCfg)
648-
return s
649+
return &Shield{
650+
engine: engine.New(resolvedCfg),
651+
baseCfg: cfg,
652+
scannerRegistry: registry,
653+
}
649654
}
650655

651656
// --- Functions ---
@@ -687,7 +692,11 @@ func Helpers() ScanHelpers { return ScanHelpers{} }
687692
func (h ScanHelpers) ContainsAny(text string, phrases []string) bool {
688693
lower := strings.ToLower(text)
689694
for _, phrase := range phrases {
690-
if strings.Contains(lower, strings.ToLower(strings.TrimSpace(phrase))) {
695+
trimmed := strings.ToLower(strings.TrimSpace(phrase))
696+
if trimmed == "" {
697+
continue
698+
}
699+
if strings.Contains(lower, trimmed) {
691700
return true
692701
}
693702
}
@@ -808,6 +817,17 @@ func snapshotGlobalScannerRegistry() map[string]Scanner {
808817
return out
809818
}
810819

820+
func cloneScannerRegistry(in map[string]Scanner) map[string]Scanner {
821+
if len(in) == 0 {
822+
return map[string]Scanner{}
823+
}
824+
out := make(map[string]Scanner, len(in))
825+
for k, v := range in {
826+
out[k] = v
827+
}
828+
return out
829+
}
830+
811831
func mergeScannersByName(base []Scanner, extras []Scanner) []Scanner {
812832
out := make([]Scanner, 0, len(base)+len(extras))
813833
seen := make(map[string]struct{}, len(base)+len(extras))
@@ -911,12 +931,12 @@ func (a *engineScannerAdapter) Scan(ctx engine.ExternalScanContext) engine.Exter
911931
}
912932

913933
publicResult := a.scanner.Scan(publicCtx)
914-
metadata := make(map[string]string, len(publicResult.Metadata))
915-
for k, v := range publicResult.Metadata {
916-
metadata[k] = v
917-
}
918-
if len(metadata) == 0 {
919-
metadata = nil
934+
var metadata map[string]string
935+
if len(publicResult.Metadata) > 0 {
936+
metadata = make(map[string]string, len(publicResult.Metadata))
937+
for k, v := range publicResult.Metadata {
938+
metadata[k] = v
939+
}
920940
}
921941

922942
return engine.ExternalScanResult{

internal/engine/engine.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,10 @@ func (e *Engine) AssessContext(ctx context.Context, text, sourceURL string) Risk
142142
result := buildResultWithSignalsWithDebiasAndBan(matches, analysisText, normSignals, e.banListCfg, e.cfg.DebiasTriggers != nil && *e.cfg.DebiasTriggers, e.cfg.StrictMode, e.cfg.BlockThreshold)
143143

144144
if len(e.customScanners) > 0 {
145-
result.Layers = append(result.Layers, heuristicLayerResult(result))
145+
heuristics := heuristicLayerResult(result)
146+
if !isEmptyLayerResult(heuristics) {
147+
result.Layers = append(result.Layers, heuristics)
148+
}
146149
fullPipeline := e.cfg.Mode == ModeStrict
147150
customCtx := internalScanContext{
148151
Text: analysisText,

internal/engine/output_engine.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,17 @@ func assessOutput(text, originalPrompt string, cfg Config, customScanners ...Lay
8080

8181
if len(customScanners) > 0 {
8282
layers = make([]LayerResult, 0, 1+len(scannerLayerExecutionOrder))
83-
layers = append(layers, LayerResult{
83+
heuristics := LayerResult{
8484
Layer: ScannerLayerHeuristics,
8585
Score: score,
8686
ScannersRun: 1,
8787
Matched: score > 0,
8888
Categories: append([]string(nil), categories...),
8989
Patterns: append([]string(nil), patterns...),
90-
})
90+
}
91+
if !isEmptyLayerResult(heuristics) {
92+
layers = append(layers, heuristics)
93+
}
9194

9295
fullPipeline := cfg.Mode == ModeStrict
9396
customCtx := internalScanContext{

shield_test.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,32 @@ func TestGlobalRegisterScanner_AvailableToNewShield(t *testing.T) {
504504
}
505505
}
506506

507+
func TestWithScanners_ReturnsCloneWithoutMutatingOriginal(t *testing.T) {
508+
shield := mustNewShield(t, Config{Mode: ModeBalanced})
509+
shield.RegisterScanner(&apiKeywordScanner{
510+
name: "clone-only-risk",
511+
trigger: "clone-trigger",
512+
score: 9,
513+
category: "clone-only",
514+
reason: "clone scanner matched",
515+
})
516+
517+
cloned := shield.WithScanners("clone-only-risk")
518+
if cloned == shield {
519+
t.Fatal("expected WithScanners to return a cloned shield instance")
520+
}
521+
522+
original := shield.Assess("contains clone-trigger", "")
523+
if containsString(original.Categories, "clone-only") {
524+
t.Fatalf("original shield should not be mutated, got categories=%v", original.Categories)
525+
}
526+
527+
updated := cloned.Assess("contains clone-trigger", "")
528+
if !containsString(updated.Categories, "clone-only") {
529+
t.Fatalf("cloned shield should include selected scanner, got categories=%v", updated.Categories)
530+
}
531+
}
532+
507533
func containsString(values []string, needle string) bool {
508534
for _, value := range values {
509535
if value == needle {

0 commit comments

Comments
 (0)