Skip to content

Commit cd15be3

Browse files
Merge pull request #17 from PSchmitz-Valckenberg/feat/step-17-repetitions-ci
feat: Step 17 — repetitions and confidence intervals for benchmarks
2 parents a20a27d + 9e69091 commit cd15be3

11 files changed

Lines changed: 284 additions & 50 deletions

File tree

scripts/run_benchmark.sh

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env bash
2-
# run_benchmark.sh — runs a full 4-strategy benchmark campaign against a
2+
# run_benchmark.sh — runs a full 5-strategy benchmark campaign against a
33
# locally running SentinelCore instance and saves the report to results/.
44
#
55
# Prerequisites:
@@ -9,7 +9,12 @@
99
#
1010
# Usage:
1111
# ./scripts/run_benchmark.sh
12-
# ./scripts/run_benchmark.sh --label gemini-2.0-flash
12+
# ./scripts/run_benchmark.sh --label gemini-2.5-flash
13+
# ./scripts/run_benchmark.sh --label gemini-2.5-flash --repetitions 5
14+
#
15+
# --repetitions N How many times each strategy is run (default: 3).
16+
# Mean and stddev are reported per strategy. N=1 gives no
17+
# stddev (marked as null in the JSON report).
1318
#
1419
# Note: the active LLM provider is configured in application-local.yml
1520
# (sentinelcore.llm.provider / sentinelcore.llm.model). --label is a free-form
@@ -20,14 +25,18 @@
2025
set -euo pipefail
2126

2227
BASE_URL="http://localhost:8080"
23-
LABEL="gemini-2.0-flash"
28+
LABEL="gemini-2.5-flash"
29+
REPETITIONS=3
2430
RESULTS_DIR="$(dirname "$0")/../results"
2531

2632
while [[ $# -gt 0 ]]; do
2733
case "$1" in
2834
--label)
2935
[[ $# -ge 2 ]] || { echo "Error: --label requires an argument"; exit 1; }
3036
LABEL="$2"; shift 2 ;;
37+
--repetitions)
38+
[[ $# -ge 2 ]] || { echo "Error: --repetitions requires an argument"; exit 1; }
39+
REPETITIONS="$2"; shift 2 ;;
3140
*) echo "Unknown argument: $1"; exit 1 ;;
3241
esac
3342
done
@@ -42,15 +51,17 @@ OUT_DIR="$RESULTS_DIR/${TIMESTAMP}_${LABEL}"
4251
mkdir -p "$OUT_DIR"
4352

4453
echo "=== SentinelCore Benchmark Campaign ==="
45-
echo "Label: $LABEL"
46-
echo "Output dir: $OUT_DIR"
54+
echo "Label: $LABEL"
55+
echo "Repetitions: $REPETITIONS"
56+
echo "Output dir: $OUT_DIR"
4757
echo ""
4858

4959
# ── 1. Create benchmark ──────────────────────────────────────────────────────
5060
echo "[1/3] Creating benchmark..."
5161
CREATE_RESPONSE=$(curl -sfS -X POST "$BASE_URL/api/benchmarks" \
5262
-H "Content-Type: application/json" \
53-
-d "$(jq -n --arg model "$LABEL" '{model: $model, strategyTypes: ["INPUT_FILTER","INPUT_OUTPUT","PROMPT_HARDENING","RAG_CONTENT_FILTER"]}')")
63+
-d "$(jq -n --arg model "$LABEL" --argjson reps "$REPETITIONS" \
64+
'{model: $model, strategyTypes: ["INPUT_FILTER","INPUT_OUTPUT","PROMPT_HARDENING","RAG_CONTENT_FILTER"], repetitions: $reps}')")
5465

5566
BENCHMARK_ID=$(echo "$CREATE_RESPONSE" | jq -r '.benchmarkId')
5667
echo " Benchmark ID: $BENCHMARK_ID"
@@ -76,16 +87,30 @@ REPORT=$(curl -sfS "$BASE_URL/api/benchmarks/$BENCHMARK_ID/report")
7687
echo "$REPORT" | jq . > "$OUT_DIR/03_report.json"
7788

7889
# ── Summary table ─────────────────────────────────────────────────────────────
90+
REPS=$(echo "$REPORT" | jq '.repetitions')
91+
echo ""
92+
echo "=== Results (N=$REPS repetitions per strategy) ==="
93+
echo "Mean per strategy:"
94+
echo "$REPORT" | jq -r '
95+
["Strategy", "ASR", "FPR", "Refusal", "Latency(ms)"],
96+
(.runs[] | [
97+
.strategyType,
98+
(.aggregated.attackSuccessRateMean | tostring),
99+
(.aggregated.falsePositiveRateMean | tostring),
100+
(.aggregated.refusalRateMean | tostring),
101+
(.aggregated.avgLatencyMsMean | tostring)
102+
]) | @tsv' | column -t
103+
79104
echo ""
80-
echo "=== Results ==="
105+
echo "Stddev per strategy (null = N=1, not computable):"
81106
echo "$REPORT" | jq -r '
82-
["Strategy", "AttackSuccess", "FalsePositive", "Refusal", "AvgLatencyMs"],
107+
["Strategy", "ASR-stddev", "FPR-stddev", "Refusal-stddev", "Latency-stddev(ms)"],
83108
(.runs[] | [
84109
.strategyType,
85-
(.metrics.metrics.attackSuccessRate | tostring),
86-
(.metrics.metrics.falsePositiveRate | tostring),
87-
(.metrics.metrics.refusalRate | tostring),
88-
(.metrics.metrics.avgLatencyMs | tostring)
110+
(.aggregated.attackSuccessRateStddev // "null" | tostring),
111+
(.aggregated.falsePositiveRateStddev // "null" | tostring),
112+
(.aggregated.refusalRateStddev // "null" | tostring),
113+
(.aggregated.avgLatencyMsStddev // "null" | tostring)
89114
]) | @tsv' | column -t
90115

91116
echo ""

src/main/java/com/sentinelcore/controller/BenchmarkController.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ public class BenchmarkController {
2121

2222
@PostMapping
2323
public ResponseEntity<BenchmarkCreateResponse> createBenchmark(@Valid @RequestBody BenchmarkCreateRequest request) {
24-
Benchmark benchmark = benchmarkService.createBenchmark(request.model(), request.strategyTypes());
24+
Benchmark benchmark = benchmarkService.createBenchmark(
25+
request.model(), request.strategyTypes(), request.repetitionsOrDefault());
2526
return ResponseEntity
2627
.status(HttpStatus.CREATED)
2728
.body(new BenchmarkCreateResponse(

src/main/java/com/sentinelcore/domain/entity/Benchmark.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,12 @@ public class Benchmark {
4242
name = "benchmark_runs",
4343
joinColumns = @JoinColumn(name = "benchmark_id")
4444
)
45+
@OrderBy("repetitionIndex ASC")
4546
private List<BenchmarkRun> runs = new ArrayList<>();
4647

48+
@Column(name = "repetitions", nullable = false)
49+
private int repetitions = 1;
50+
4751
@Column(name = "created_at", nullable = false)
4852
private Instant createdAt;
4953

src/main/java/com/sentinelcore/domain/entity/BenchmarkRun.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,7 @@ public class BenchmarkRun {
2323

2424
@Column(name = "run_id", nullable = false)
2525
private String runId;
26+
27+
@Column(name = "repetition_index", nullable = false)
28+
private int repetitionIndex;
2629
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package com.sentinelcore.dto;
2+
3+
public record AggregatedStrategyMetrics(
4+
int repetitions,
5+
double attackSuccessRateMean,
6+
Double attackSuccessRateStddev,
7+
double falsePositiveRateMean,
8+
Double falsePositiveRateStddev,
9+
double refusalRateMean,
10+
Double refusalRateStddev,
11+
double avgLatencyMsMean,
12+
Double avgLatencyMsStddev
13+
) {}
Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package com.sentinelcore.dto;
22

33
import com.sentinelcore.domain.enums.StrategyType;
4+
import jakarta.validation.constraints.Max;
5+
import jakarta.validation.constraints.Min;
46
import jakarta.validation.constraints.NotBlank;
57
import jakarta.validation.constraints.NotEmpty;
68
import jakarta.validation.constraints.NotNull;
@@ -9,5 +11,10 @@
911

1012
public record BenchmarkCreateRequest(
1113
@NotBlank String model,
12-
@NotEmpty List<@NotNull StrategyType> strategyTypes
13-
) {}
14+
@NotEmpty List<@NotNull StrategyType> strategyTypes,
15+
@Min(1) @Max(10) Integer repetitions
16+
) {
17+
public int repetitionsOrDefault() {
18+
return repetitions != null ? repetitions : 1;
19+
}
20+
}

src/main/java/com/sentinelcore/dto/BenchmarkReportResponse.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@ public record BenchmarkReportResponse(
66
String benchmarkId,
77
String model,
88
String status,
9+
int repetitions,
910
List<RunComparisonEntry> runs
1011
) {}

src/main/java/com/sentinelcore/dto/RunComparisonEntry.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@
22

33
import com.sentinelcore.domain.enums.StrategyType;
44

5+
import java.util.List;
6+
57
public record RunComparisonEntry(
6-
String runId,
8+
List<String> runIds,
79
StrategyType strategyType,
810
RunMetricsResponse metrics,
11+
AggregatedStrategyMetrics aggregated,
912
DeltaMetrics deltaToBaseline
1013
) {}

src/main/java/com/sentinelcore/service/BenchmarkService.java

Lines changed: 94 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import com.sentinelcore.domain.enums.BenchmarkStatus;
77
import com.sentinelcore.domain.enums.RunMode;
88
import com.sentinelcore.domain.enums.StrategyType;
9+
import com.sentinelcore.dto.AggregatedStrategyMetrics;
910
import com.sentinelcore.dto.BenchmarkExecutionResponse;
1011
import com.sentinelcore.dto.BenchmarkReportResponse;
1112
import com.sentinelcore.dto.DeltaMetrics;
@@ -20,8 +21,10 @@
2021

2122
import java.time.Instant;
2223
import java.util.ArrayList;
24+
import java.util.LinkedHashMap;
2325
import java.util.LinkedHashSet;
2426
import java.util.List;
27+
import java.util.Map;
2528
import java.util.UUID;
2629

2730
@Slf4j
@@ -34,7 +37,7 @@ public class BenchmarkService {
3437
private final ReportingService reportingService;
3538

3639
@Transactional
37-
public Benchmark createBenchmark(String model, List<StrategyType> strategyTypes) {
40+
public Benchmark createBenchmark(String model, List<StrategyType> strategyTypes, int repetitions) {
3841
LinkedHashSet<StrategyType> deduped = new LinkedHashSet<>();
3942
deduped.add(StrategyType.NONE);
4043
deduped.addAll(strategyTypes);
@@ -44,6 +47,7 @@ public Benchmark createBenchmark(String model, List<StrategyType> strategyTypes)
4447
benchmark.setModel(model);
4548
benchmark.setStrategyTypes(new ArrayList<>(deduped));
4649
benchmark.setRuns(new ArrayList<>());
50+
benchmark.setRepetitions(repetitions);
4751
benchmark.setStatus(BenchmarkStatus.CREATED);
4852
benchmark.setCreatedAt(Instant.now());
4953
return benchmarkRepository.save(benchmark);
@@ -64,15 +68,18 @@ public BenchmarkExecutionResponse executeBenchmark(String benchmarkId) {
6468
benchmarkRepository.saveAndFlush(benchmark);
6569

6670
List<BenchmarkRun> completedRuns = new ArrayList<>();
71+
int repetitions = benchmark.getRepetitions();
6772

6873
try {
6974
for (StrategyType strategyType : benchmark.getStrategyTypes()) {
7075
RunMode mode = (strategyType == StrategyType.NONE) ? RunMode.BASELINE : RunMode.DEFENDED;
71-
EvaluationRun run = runService.createRun(mode, benchmark.getModel(), strategyType);
72-
runService.executeRun(run.getId());
73-
completedRuns.add(new BenchmarkRun(strategyType, run.getId()));
74-
log.info("Benchmark {}: completed run {} with strategy {}",
75-
benchmarkId, run.getId(), strategyType);
76+
for (int rep = 0; rep < repetitions; rep++) {
77+
EvaluationRun run = runService.createRun(mode, benchmark.getModel(), strategyType);
78+
runService.executeRun(run.getId());
79+
completedRuns.add(new BenchmarkRun(strategyType, run.getId(), rep));
80+
log.info("Benchmark {}: completed run {} (strategy={}, rep={}/{})",
81+
benchmarkId, run.getId(), strategyType, rep + 1, repetitions);
82+
}
7683
}
7784
benchmark.setStatus(BenchmarkStatus.COMPLETED);
7885
} catch (RuntimeException ex) {
@@ -98,39 +105,94 @@ public BenchmarkReportResponse getReport(String benchmarkId) {
98105
Benchmark benchmark = benchmarkRepository.findById(benchmarkId)
99106
.orElseThrow(() -> new EntityNotFoundException("Benchmark not found: " + benchmarkId));
100107

101-
List<RunWithMetrics> runMetrics = benchmark.getRuns().stream()
102-
.map(br -> new RunWithMetrics(
103-
br.getRunId(),
104-
br.getStrategyType(),
105-
reportingService.getMetrics(br.getRunId())
106-
))
107-
.toList();
108-
109-
RunMetricsResponse baseline = runMetrics.stream()
110-
.filter(r -> r.strategyType() == StrategyType.NONE)
111-
.map(RunWithMetrics::metrics)
112-
.findFirst()
113-
.orElse(null);
114-
115-
List<RunComparisonEntry> entries = runMetrics.stream()
116-
.map(r -> new RunComparisonEntry(
117-
r.runId(),
118-
r.strategyType(),
119-
r.metrics(),
120-
(baseline != null && r.strategyType() != StrategyType.NONE)
121-
? computeDelta(baseline, r.metrics())
122-
: null
123-
))
124-
.toList();
108+
// Group runs by strategy in repetition_index order (guaranteed by @OrderBy on Benchmark.runs)
109+
Map<StrategyType, List<RunMetricsResponse>> metricsByStrategy = new LinkedHashMap<>();
110+
Map<StrategyType, List<String>> runIdsByStrategy = new LinkedHashMap<>();
111+
// Track the rep-0 run per strategy explicitly so representative is stable
112+
Map<StrategyType, RunMetricsResponse> rep0ByStrategy = new LinkedHashMap<>();
113+
for (BenchmarkRun br : benchmark.getRuns()) {
114+
RunMetricsResponse m = reportingService.getMetrics(br.getRunId());
115+
metricsByStrategy.computeIfAbsent(br.getStrategyType(), k -> new ArrayList<>()).add(m);
116+
runIdsByStrategy.computeIfAbsent(br.getStrategyType(), k -> new ArrayList<>())
117+
.add(br.getRunId());
118+
if (br.getRepetitionIndex() == 0) {
119+
rep0ByStrategy.put(br.getStrategyType(), m);
120+
}
121+
}
122+
123+
// Use rep-0 of NONE as baseline for delta computation
124+
RunMetricsResponse baselineSample = rep0ByStrategy.get(StrategyType.NONE);
125+
126+
List<RunComparisonEntry> entries = new ArrayList<>();
127+
for (Map.Entry<StrategyType, List<RunMetricsResponse>> e : metricsByStrategy.entrySet()) {
128+
StrategyType strategy = e.getKey();
129+
List<RunMetricsResponse> runs = e.getValue();
130+
List<String> runIds = runIdsByStrategy.get(strategy);
131+
132+
// rep-0 as the representative for backwards-compatible single-metrics field
133+
RunMetricsResponse representative = rep0ByStrategy.getOrDefault(strategy, runs.get(0));
134+
AggregatedStrategyMetrics aggregated = aggregate(runs);
135+
DeltaMetrics delta = (baselineSample != null && strategy != StrategyType.NONE)
136+
? computeDelta(baselineSample, representative) : null;
137+
138+
entries.add(new RunComparisonEntry(runIds, strategy, representative, aggregated, delta));
139+
}
125140

126141
return new BenchmarkReportResponse(
127142
benchmarkId,
128143
benchmark.getModel(),
129144
benchmark.getStatus().name(),
145+
benchmark.getRepetitions(),
130146
entries
131147
);
132148
}
133149

150+
// --- Statistics ---
151+
152+
static AggregatedStrategyMetrics aggregate(List<RunMetricsResponse> runs) {
153+
int n = runs.size();
154+
double[] asr = extract(runs, r -> r.metrics().attackSuccessRate());
155+
double[] fpr = extract(runs, r -> r.metrics().falsePositiveRate());
156+
double[] rr = extract(runs, r -> r.metrics().refusalRate());
157+
double[] lat = extract(runs, r -> r.metrics().avgLatencyMs());
158+
return new AggregatedStrategyMetrics(
159+
n,
160+
mean(asr), stddev(asr),
161+
mean(fpr), stddev(fpr),
162+
mean(rr), stddev(rr),
163+
mean(lat), stddev(lat)
164+
);
165+
}
166+
167+
private static double[] extract(List<RunMetricsResponse> runs,
168+
java.util.function.ToDoubleFunction<RunMetricsResponse> fn) {
169+
double[] vals = new double[runs.size()];
170+
for (int i = 0; i < runs.size(); i++) {
171+
vals[i] = fn.applyAsDouble(runs.get(i));
172+
}
173+
return vals;
174+
}
175+
176+
static double mean(double[] values) {
177+
return round3(rawMean(values));
178+
}
179+
180+
// Returns null for N=1 (not computable), otherwise population stddev
181+
static Double stddev(double[] values) {
182+
if (values.length <= 1) return null;
183+
double m = rawMean(values);
184+
double sumSq = 0;
185+
for (double v : values) sumSq += (v - m) * (v - m);
186+
return round3(Math.sqrt(sumSq / values.length));
187+
}
188+
189+
private static double rawMean(double[] values) {
190+
if (values.length == 0) return 0.0;
191+
double sum = 0;
192+
for (double v : values) sum += v;
193+
return sum / values.length;
194+
}
195+
134196
private DeltaMetrics computeDelta(RunMetricsResponse baseline, RunMetricsResponse defended) {
135197
RunMetricsResponse.Metrics b = baseline.metrics();
136198
RunMetricsResponse.Metrics d = defended.metrics();
@@ -142,9 +204,7 @@ private DeltaMetrics computeDelta(RunMetricsResponse baseline, RunMetricsRespons
142204
);
143205
}
144206

145-
private double round3(double value) {
207+
private static double round3(double value) {
146208
return Math.round(value * 1000.0) / 1000.0;
147209
}
148-
149-
private record RunWithMetrics(String runId, StrategyType strategyType, RunMetricsResponse metrics) {}
150-
}
210+
}

0 commit comments

Comments
 (0)