Skip to content

Commit 27c1744

Browse files
[feature] Tome automation (#1526)
* added new fields for tome automation

* Implement Tome automation in ClaimTasks API (#1525)

  * Implement Tome automation for ClaimTasks

    Refactored ClaimTasks to include logic for automatically queuing Quests/Tasks based on Tome triggers:
    - RunOnNewBeaconCallback
    - RunOnFirstHostCallback
    - RunOnSchedule (cron)

    Added Prometheus metrics for automation errors and implemented non-blocking error handling.

  * Refactor Tome automation into helper function

    Moved Tome automation logic from `ClaimTasks` to `handleTomeAutomation` to improve code readability and maintainability.

  * Add comprehensive tests for Tome automation

    Added `tome_automation_test.go` containing `TestHandleTomeAutomation` to verify:
    - Triggers: New Beacon, New Host, Schedule.
    - Constraints: Scheduled host restrictions (allow/deny).
    - Deduplication: Ensuring a tome is only queued once per callback even if multiple triggers match.

    This addresses PR feedback requesting unit/integration tests for the new `handleTomeAutomation` helper.

  * Refactor handleTomeAutomation to use early return

    Simplified the control flow in `handleTomeAutomation` by returning early on error, reducing nesting and improving readability.

  ---------

  Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>

---------

Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
1 parent c20e22d commit 27c1744

28 files changed

+2920
-94
lines changed

go.mod

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

go.sum

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tavern/internal/c2/api_claim_tasks.go

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,17 @@ import (
99
"time"
1010

1111
"github.com/prometheus/client_golang/prometheus"
12+
"github.com/robfig/cron/v3"
1213
"google.golang.org/grpc/codes"
1314
"google.golang.org/grpc/status"
1415
"realm.pub/tavern/internal/c2/c2pb"
1516
"realm.pub/tavern/internal/c2/epb"
17+
"realm.pub/tavern/internal/ent"
1618
"realm.pub/tavern/internal/ent/beacon"
1719
"realm.pub/tavern/internal/ent/host"
1820
"realm.pub/tavern/internal/ent/tag"
1921
"realm.pub/tavern/internal/ent/task"
22+
"realm.pub/tavern/internal/ent/tome"
2023
"realm.pub/tavern/internal/namegen"
2124
)
2225

@@ -28,10 +31,117 @@ var (
2831
},
2932
[]string{"host_identifier", "host_groups", "host_services"},
3033
)
34+
metricTomeAutomationErrors = prometheus.NewCounter(
35+
prometheus.CounterOpts{
36+
Name: "tavern_tome_automation_errors_total",
37+
Help: "The total number of errors encountered during tome automation",
38+
},
39+
)
3140
)
3241

3342
func init() {
3443
prometheus.MustRegister(metricHostCallbacksTotal)
44+
prometheus.MustRegister(metricTomeAutomationErrors)
45+
}
46+
47+
func (srv *Server) handleTomeAutomation(ctx context.Context, beaconID int, hostID int, isNewBeacon bool, isNewHost bool, now time.Time) {
48+
// Tome Automation Logic
49+
candidateTomes, err := srv.graph.Tome.Query().
50+
Where(tome.Or(
51+
tome.RunOnNewBeaconCallback(true),
52+
tome.RunOnFirstHostCallback(true),
53+
tome.RunOnScheduleNEQ(""),
54+
)).
55+
All(ctx)
56+
57+
if err != nil {
58+
slog.ErrorContext(ctx, "failed to query candidate tomes for automation", "err", err)
59+
metricTomeAutomationErrors.Inc()
60+
return
61+
}
62+
63+
selectedTomes := make(map[int]*ent.Tome)
64+
parser := cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)
65+
currentMinute := now.Truncate(time.Minute)
66+
67+
for _, t := range candidateTomes {
68+
shouldRun := false
69+
70+
// Check RunOnNewBeaconCallback
71+
if isNewBeacon && t.RunOnNewBeaconCallback {
72+
shouldRun = true
73+
}
74+
75+
// Check RunOnFirstHostCallback
76+
if !shouldRun && isNewHost && t.RunOnFirstHostCallback {
77+
shouldRun = true
78+
}
79+
80+
// Check RunOnSchedule
81+
if !shouldRun && t.RunOnSchedule != "" {
82+
sched, err := parser.Parse(t.RunOnSchedule)
83+
if err == nil {
84+
// Check if schedule matches current time
85+
// Next(now-1sec) == now?
86+
next := sched.Next(currentMinute.Add(-1 * time.Second))
87+
if next.Equal(currentMinute) {
88+
// Check scheduled_hosts constraint
89+
hostCount, err := t.QueryScheduledHosts().Count(ctx)
90+
if err != nil {
91+
slog.ErrorContext(ctx, "failed to count scheduled hosts for automation", "err", err, "tome_id", t.ID)
92+
metricTomeAutomationErrors.Inc()
93+
continue
94+
}
95+
if hostCount == 0 {
96+
shouldRun = true
97+
} else {
98+
hostExists, err := t.QueryScheduledHosts().
99+
Where(host.ID(hostID)).
100+
Exist(ctx)
101+
if err != nil {
102+
slog.ErrorContext(ctx, "failed to check host existence for automation", "err", err, "tome_id", t.ID)
103+
metricTomeAutomationErrors.Inc()
104+
continue
105+
}
106+
if hostExists {
107+
shouldRun = true
108+
}
109+
}
110+
}
111+
} else {
112+
// Don't log cron parse errors for now, as it might be spammy if stored in DB
113+
// metricTomeAutomationErrors.Inc()
114+
}
115+
}
116+
117+
if shouldRun {
118+
selectedTomes[t.ID] = t
119+
}
120+
}
121+
122+
// Create Quest and Task for each selected Tome
123+
for _, t := range selectedTomes {
124+
q, err := srv.graph.Quest.Create().
125+
SetName(fmt.Sprintf("Automated: %s", t.Name)).
126+
SetTome(t).
127+
SetParamDefsAtCreation(t.ParamDefs).
128+
SetEldritchAtCreation(t.Eldritch).
129+
Save(ctx)
130+
if err != nil {
131+
slog.ErrorContext(ctx, "failed to create automated quest", "err", err, "tome_id", t.ID)
132+
metricTomeAutomationErrors.Inc()
133+
continue
134+
}
135+
136+
_, err = srv.graph.Task.Create().
137+
SetQuest(q).
138+
SetBeaconID(beaconID).
139+
Save(ctx)
140+
if err != nil {
141+
slog.ErrorContext(ctx, "failed to create automated task", "err", err, "quest_id", q.ID)
142+
metricTomeAutomationErrors.Inc()
143+
}
144+
}
35145
}
36146

37147
func (srv *Server) ClaimTasks(ctx context.Context, req *c2pb.ClaimTasksRequest) (*c2pb.ClaimTasksResponse, error) {
@@ -61,6 +171,15 @@ func (srv *Server) ClaimTasks(ctx context.Context, req *c2pb.ClaimTasksRequest)
61171
return nil, status.Errorf(codes.InvalidArgument, "must provide agent identifier")
62172
}
63173

174+
// Check if host is new (before upsert)
175+
hostExists, err := srv.graph.Host.Query().
176+
Where(host.IdentifierEQ(req.Beacon.Host.Identifier)).
177+
Exist(ctx)
178+
if err != nil {
179+
return nil, status.Errorf(codes.Internal, "failed to query host existence: %v", err)
180+
}
181+
isNewHost := !hostExists
182+
64183
// Upsert the host
65184
hostID, err := srv.graph.Host.Create().
66185
SetIdentifier(req.Beacon.Host.Identifier).
@@ -118,6 +237,8 @@ func (srv *Server) ClaimTasks(ctx context.Context, req *c2pb.ClaimTasksRequest)
118237
if err != nil {
119238
return nil, status.Errorf(codes.Internal, "failed to query beacon entity: %v", err)
120239
}
240+
isNewBeacon := !beaconExists
241+
121242
var beaconNameAddr *string = nil
122243
if !beaconExists {
123244
candidateNames := []string{
@@ -172,6 +293,9 @@ func (srv *Server) ClaimTasks(ctx context.Context, req *c2pb.ClaimTasksRequest)
172293
return nil, status.Errorf(codes.Internal, "failed to upsert beacon entity: %v", err)
173294
}
174295

296+
// Run Tome Automation (non-blocking, best effort)
297+
srv.handleTomeAutomation(ctx, beaconID, hostID, isNewBeacon, isNewHost, now)
298+
175299
// Load Tasks
176300
tasks, err := srv.graph.Task.Query().
177301
Where(task.And(
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
package c2
2+
3+
import (
4+
"context"
5+
"testing"
6+
"time"
7+
8+
"github.com/stretchr/testify/assert"
9+
"realm.pub/tavern/internal/c2/c2pb"
10+
"realm.pub/tavern/internal/ent"
11+
"realm.pub/tavern/internal/ent/enttest"
12+
)
13+
14+
func TestHandleTomeAutomation(t *testing.T) {
15+
ctx := context.Background()
16+
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
17+
defer client.Close()
18+
19+
srv := &Server{graph: client}
20+
now := time.Date(2023, 10, 27, 10, 0, 0, 0, time.UTC)
21+
22+
// Create a dummy host and beacon for testing
23+
h := client.Host.Create().
24+
SetIdentifier("test-host").
25+
SetName("Test Host").
26+
SetPlatform(c2pb.Host_PLATFORM_LINUX).
27+
SaveX(ctx)
28+
b := client.Beacon.Create().
29+
SetIdentifier("test-beacon").
30+
SetHost(h).
31+
SetTransport(c2pb.ActiveTransport_TRANSPORT_HTTP1).
32+
SaveX(ctx)
33+
34+
// 1. Setup Tomes
35+
// T1: New Beacon Only
36+
client.Tome.Create().
37+
SetName("Tome New Beacon").
38+
SetDescription("Test").
39+
SetAuthor("Test Author").
40+
SetEldritch("print('new beacon')").
41+
SetRunOnNewBeaconCallback(true).
42+
SaveX(ctx)
43+
44+
// T2: New Host Only
45+
client.Tome.Create().
46+
SetName("Tome New Host").
47+
SetDescription("Test").
48+
SetAuthor("Test Author").
49+
SetEldritch("print('new host')").
50+
SetRunOnFirstHostCallback(true).
51+
SaveX(ctx)
52+
53+
// T3: Schedule Matching (Every minute)
54+
client.Tome.Create().
55+
SetName("Tome Schedule Match").
56+
SetDescription("Test").
57+
SetAuthor("Test Author").
58+
SetEldritch("print('schedule')").
59+
SetRunOnSchedule("* * * * *").
60+
SaveX(ctx)
61+
62+
// T4: Schedule Matching with Host Restriction (Allowed)
63+
client.Tome.Create().
64+
SetName("Tome Schedule Restricted Allowed").
65+
SetDescription("Test").
66+
SetAuthor("Test Author").
67+
SetEldritch("print('schedule restricted')").
68+
SetRunOnSchedule("* * * * *").
69+
AddScheduledHosts(h).
70+
SaveX(ctx)
71+
72+
// T5: Schedule Matching with Host Restriction (Denied - different host)
73+
otherHost := client.Host.Create().
74+
SetIdentifier("other").
75+
SetPlatform(c2pb.Host_PLATFORM_LINUX).
76+
SaveX(ctx)
77+
78+
client.Tome.Create().
79+
SetName("Tome Schedule Restricted Denied").
80+
SetDescription("Test").
81+
SetAuthor("Test Author").
82+
SetEldritch("print('schedule denied')").
83+
SetRunOnSchedule("* * * * *").
84+
AddScheduledHosts(otherHost).
85+
SaveX(ctx)
86+
87+
tests := []struct {
88+
name string
89+
isNewBeacon bool
90+
isNewHost bool
91+
expectedTomes []string
92+
}{
93+
{
94+
name: "New Beacon Only",
95+
isNewBeacon: true,
96+
isNewHost: false,
97+
expectedTomes: []string{
98+
"Tome New Beacon",
99+
"Tome Schedule Match",
100+
"Tome Schedule Restricted Allowed",
101+
},
102+
},
103+
{
104+
name: "New Host Only",
105+
isNewBeacon: false,
106+
isNewHost: true,
107+
expectedTomes: []string{
108+
"Tome New Host",
109+
"Tome Schedule Match",
110+
"Tome Schedule Restricted Allowed",
111+
},
112+
},
113+
{
114+
name: "Both New",
115+
isNewBeacon: true,
116+
isNewHost: true,
117+
expectedTomes: []string{
118+
"Tome New Beacon",
119+
"Tome New Host",
120+
"Tome Schedule Match",
121+
"Tome Schedule Restricted Allowed",
122+
},
123+
},
124+
{
125+
name: "Neither New",
126+
isNewBeacon: false,
127+
isNewHost: false,
128+
expectedTomes: []string{
129+
"Tome Schedule Match",
130+
"Tome Schedule Restricted Allowed",
131+
},
132+
},
133+
}
134+
135+
for _, tt := range tests {
136+
t.Run(tt.name, func(t *testing.T) {
137+
// Clear existing quests/tasks to ensure clean slate
138+
client.Task.Delete().ExecX(ctx)
139+
client.Quest.Delete().ExecX(ctx)
140+
141+
srv.handleTomeAutomation(ctx, b.ID, h.ID, tt.isNewBeacon, tt.isNewHost, now)
142+
143+
// Verify Tasks
144+
tasks := client.Task.Query().WithQuest(func(q *ent.QuestQuery) {
145+
q.WithTome()
146+
}).AllX(ctx)
147+
148+
var createdTomes []string
149+
for _, t := range tasks {
150+
createdTomes = append(createdTomes, t.Edges.Quest.Edges.Tome.Name)
151+
}
152+
153+
assert.ElementsMatch(t, tt.expectedTomes, createdTomes)
154+
})
155+
}
156+
}
157+
158+
func TestHandleTomeAutomation_Deduplication(t *testing.T) {
159+
ctx := context.Background()
160+
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
161+
defer client.Close()
162+
163+
srv := &Server{graph: client}
164+
now := time.Now()
165+
166+
h := client.Host.Create().
167+
SetIdentifier("test").
168+
SetPlatform(c2pb.Host_PLATFORM_LINUX).
169+
SaveX(ctx)
170+
b := client.Beacon.Create().
171+
SetIdentifier("test").
172+
SetHost(h).
173+
SetTransport(c2pb.ActiveTransport_TRANSPORT_HTTP1).
174+
SaveX(ctx)
175+
176+
// Tome with ALL triggers enabled
177+
client.Tome.Create().
178+
SetName("Super Tome").
179+
SetDescription("Test").
180+
SetAuthor("Test Author").
181+
SetEldritch("print('super')").
182+
SetRunOnNewBeaconCallback(true).
183+
SetRunOnFirstHostCallback(true).
184+
SetRunOnSchedule("* * * * *").
185+
SaveX(ctx)
186+
187+
// Trigger all conditions
188+
srv.handleTomeAutomation(ctx, b.ID, h.ID, true, true, now)
189+
190+
// Should only have 1 task
191+
count := client.Task.Query().CountX(ctx)
192+
assert.Equal(t, 1, count, "Should only create one task despite multiple triggers matching")
193+
}

0 commit comments

Comments
 (0)