Skip to content

Commit cb08091

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Setting up data structures for pause/resume/rewind
PiperOrigin-RevId: 860852195
1 parent a3f1763 commit cb08091

File tree

6 files changed

+240
-17
lines changed

6 files changed

+240
-17
lines changed
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/*
2+
* Copyright 2026 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.adk.agents;
18+
19+
import com.google.adk.JsonBaseModel;
20+
21+
/** Base class for all agent states. */
22+
public class BaseAgentState extends JsonBaseModel {
23+
24+
protected BaseAgentState() {}
25+
26+
/** Returns a new {@link Builder} for creating {@link BaseAgentState} instances. */
27+
public static Builder builder() {
28+
return new Builder();
29+
}
30+
31+
/** Builder for {@link BaseAgentState}. */
32+
public static class Builder {
33+
private Builder() {}
34+
35+
public BaseAgentState build() {
36+
return new BaseAgentState();
37+
}
38+
}
39+
}

core/src/main/java/com/google/adk/agents/InvocationContext.java

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ public class InvocationContext {
5050
private final Session session;
5151
private final Optional<Content> userContent;
5252
private final RunConfig runConfig;
53+
private final Map<String, BaseAgentState> agentStates;
54+
private final Map<String, Boolean> endOfAgents;
5355
private final ResumabilityConfig resumabilityConfig;
5456
private final InvocationCostManager invocationCostManager;
5557

@@ -71,6 +73,8 @@ protected InvocationContext(Builder builder) {
7173
this.userContent = builder.userContent;
7274
this.runConfig = builder.runConfig;
7375
this.endInvocation = builder.endInvocation;
76+
this.agentStates = builder.agentStates;
77+
this.endOfAgents = builder.endOfAgents;
7478
this.resumabilityConfig = builder.resumabilityConfig;
7579
this.invocationCostManager = builder.invocationCostManager;
7680
}
@@ -299,6 +303,16 @@ public RunConfig runConfig() {
299303
return runConfig;
300304
}
301305

306+
/** Returns agent-specific state saved within this invocation. */
307+
public Map<String, BaseAgentState> agentStates() {
308+
return agentStates;
309+
}
310+
311+
/** Returns map of agents that ended during this invocation. */
312+
public Map<String, Boolean> endOfAgents() {
313+
return endOfAgents;
314+
}
315+
302316
/**
303317
* Returns whether this invocation should be ended, e.g., due to reaching a terminal state or
304318
* error.
@@ -410,6 +424,8 @@ private Builder(InvocationContext context) {
410424
this.userContent = context.userContent;
411425
this.runConfig = context.runConfig;
412426
this.endInvocation = context.endInvocation;
427+
this.agentStates = new ConcurrentHashMap<>(context.agentStates);
428+
this.endOfAgents = new ConcurrentHashMap<>(context.endOfAgents);
413429
this.resumabilityConfig = context.resumabilityConfig;
414430
this.invocationCostManager = context.invocationCostManager;
415431
}
@@ -427,6 +443,8 @@ private Builder(InvocationContext context) {
427443
private Optional<Content> userContent = Optional.empty();
428444
private RunConfig runConfig = RunConfig.builder().build();
429445
private boolean endInvocation = false;
446+
private Map<String, BaseAgentState> agentStates = new ConcurrentHashMap<>();
447+
private Map<String, Boolean> endOfAgents = new ConcurrentHashMap<>();
430448
private ResumabilityConfig resumabilityConfig = new ResumabilityConfig();
431449
private InvocationCostManager invocationCostManager = new InvocationCostManager();
432450

@@ -616,6 +634,30 @@ public Builder endInvocation(boolean endInvocation) {
616634
return this;
617635
}
618636

637+
/**
638+
* Sets agent-specific state saved within this invocation.
639+
*
640+
* @param agentStates agent-specific state saved within this invocation.
641+
* @return this builder instance for chaining.
642+
*/
643+
@CanIgnoreReturnValue
644+
public Builder agentStates(Map<String, BaseAgentState> agentStates) {
645+
this.agentStates = agentStates;
646+
return this;
647+
}
648+
649+
/**
650+
* Sets agent end-of-invocation status.
651+
*
652+
* @param endOfAgents agent end-of-invocation status.
653+
* @return this builder instance for chaining.
654+
*/
655+
@CanIgnoreReturnValue
656+
public Builder endOfAgents(Map<String, Boolean> endOfAgents) {
657+
this.endOfAgents = endOfAgents;
658+
return this;
659+
}
660+
619661
/**
620662
* Sets the resumability configuration for the current agent run.
621663
*
@@ -660,6 +702,8 @@ public boolean equals(Object o) {
660702
&& Objects.equals(session, that.session)
661703
&& Objects.equals(userContent, that.userContent)
662704
&& Objects.equals(runConfig, that.runConfig)
705+
&& Objects.equals(agentStates, that.agentStates)
706+
&& Objects.equals(endOfAgents, that.endOfAgents)
663707
&& Objects.equals(resumabilityConfig, that.resumabilityConfig)
664708
&& Objects.equals(invocationCostManager, that.invocationCostManager);
665709
}
@@ -680,6 +724,8 @@ public int hashCode() {
680724
userContent,
681725
runConfig,
682726
endInvocation,
727+
agentStates,
728+
endOfAgents,
683729
resumabilityConfig,
684730
invocationCostManager);
685731
}

core/src/main/java/com/google/adk/agents/LlmAgent.java

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -704,16 +704,7 @@ private static boolean isThought(Part part) {
704704

705705
@Override
706706
protected Flowable<Event> runAsyncImpl(InvocationContext invocationContext) {
707-
return llmFlow
708-
.run(invocationContext)
709-
.concatMap(
710-
event -> {
711-
this.maybeSaveOutputToState(event);
712-
if (invocationContext.shouldPauseInvocation(event)) {
713-
return Flowable.just(event).concatWith(Flowable.empty());
714-
}
715-
return Flowable.just(event);
716-
});
707+
return llmFlow.run(invocationContext).doOnNext(this::maybeSaveOutputToState);
717708
}
718709

719710
@Override

core/src/main/java/com/google/adk/events/EventActions.java

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515
*/
1616
package com.google.adk.events;
1717

18+
import com.fasterxml.jackson.annotation.JsonInclude;
1819
import com.fasterxml.jackson.annotation.JsonProperty;
1920
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
21+
import com.google.adk.agents.BaseAgentState;
2022
import com.google.errorprone.annotations.CanIgnoreReturnValue;
2123
import com.google.genai.types.Part;
2224
import java.util.Objects;
@@ -37,8 +39,11 @@ public class EventActions {
3739
private Optional<Boolean> escalate;
3840
private ConcurrentMap<String, ConcurrentMap<String, Object>> requestedAuthConfigs;
3941
private ConcurrentMap<String, ToolConfirmation> requestedToolConfirmations;
42+
private boolean endOfAgent;
43+
private ConcurrentMap<String, BaseAgentState> agentState;
4044
private Optional<Boolean> endInvocation;
4145
private Optional<EventCompaction> compaction;
46+
private Optional<String> rewindBeforeInvocationId;
4247

4348
/** Default constructor for Jackson. */
4449
public EventActions() {
@@ -49,8 +54,11 @@ public EventActions() {
4954
this.escalate = Optional.empty();
5055
this.requestedAuthConfigs = new ConcurrentHashMap<>();
5156
this.requestedToolConfirmations = new ConcurrentHashMap<>();
57+
this.endOfAgent = false;
5258
this.endInvocation = Optional.empty();
5359
this.compaction = Optional.empty();
60+
this.agentState = new ConcurrentHashMap<>();
61+
this.rewindBeforeInvocationId = Optional.empty();
5462
}
5563

5664
private EventActions(Builder builder) {
@@ -61,8 +69,11 @@ private EventActions(Builder builder) {
6169
this.escalate = builder.escalate;
6270
this.requestedAuthConfigs = builder.requestedAuthConfigs;
6371
this.requestedToolConfirmations = builder.requestedToolConfirmations;
72+
this.endOfAgent = builder.endOfAgent;
6473
this.endInvocation = builder.endInvocation;
6574
this.compaction = builder.compaction;
75+
this.agentState = builder.agentState;
76+
this.rewindBeforeInvocationId = builder.rewindBeforeInvocationId;
6677
}
6778

6879
@JsonProperty("skipSummarization")
@@ -146,6 +157,16 @@ public void setRequestedToolConfirmations(
146157
this.requestedToolConfirmations = requestedToolConfirmations;
147158
}
148159

160+
@JsonProperty("endOfAgent")
161+
@JsonInclude(JsonInclude.Include.NON_DEFAULT)
162+
public boolean endOfAgent() {
163+
return endOfAgent;
164+
}
165+
166+
public void setEndOfAgent(boolean endOfAgent) {
167+
this.endOfAgent = endOfAgent;
168+
}
169+
149170
@JsonProperty("endInvocation")
150171
public Optional<Boolean> endInvocation() {
151172
return endInvocation;
@@ -168,6 +189,25 @@ public void setCompaction(Optional<EventCompaction> compaction) {
168189
this.compaction = compaction;
169190
}
170191

192+
@JsonProperty("agentState")
193+
@JsonInclude(JsonInclude.Include.NON_EMPTY)
194+
public ConcurrentMap<String, BaseAgentState> agentState() {
195+
return agentState;
196+
}
197+
198+
public void setAgentState(ConcurrentMap<String, BaseAgentState> agentState) {
199+
this.agentState = agentState;
200+
}
201+
202+
@JsonProperty("rewindBeforeInvocationId")
203+
public Optional<String> rewindBeforeInvocationId() {
204+
return rewindBeforeInvocationId;
205+
}
206+
207+
public void setRewindBeforeInvocationId(@Nullable String rewindBeforeInvocationId) {
208+
this.rewindBeforeInvocationId = Optional.ofNullable(rewindBeforeInvocationId);
209+
}
210+
171211
public static Builder builder() {
172212
return new Builder();
173213
}
@@ -191,8 +231,11 @@ public boolean equals(Object o) {
191231
&& Objects.equals(escalate, that.escalate)
192232
&& Objects.equals(requestedAuthConfigs, that.requestedAuthConfigs)
193233
&& Objects.equals(requestedToolConfirmations, that.requestedToolConfirmations)
234+
&& (endOfAgent == that.endOfAgent)
194235
&& Objects.equals(endInvocation, that.endInvocation)
195-
&& Objects.equals(compaction, that.compaction);
236+
&& Objects.equals(compaction, that.compaction)
237+
&& Objects.equals(agentState, that.agentState)
238+
&& Objects.equals(rewindBeforeInvocationId, that.rewindBeforeInvocationId);
196239
}
197240

198241
@Override
@@ -205,8 +248,11 @@ public int hashCode() {
205248
escalate,
206249
requestedAuthConfigs,
207250
requestedToolConfirmations,
251+
endOfAgent,
208252
endInvocation,
209-
compaction);
253+
compaction,
254+
agentState,
255+
rewindBeforeInvocationId);
210256
}
211257

212258
/** Builder for {@link EventActions}. */
@@ -218,8 +264,11 @@ public static class Builder {
218264
private Optional<Boolean> escalate;
219265
private ConcurrentMap<String, ConcurrentMap<String, Object>> requestedAuthConfigs;
220266
private ConcurrentMap<String, ToolConfirmation> requestedToolConfirmations;
267+
private boolean endOfAgent = false;
221268
private Optional<Boolean> endInvocation;
222269
private Optional<EventCompaction> compaction;
270+
private ConcurrentMap<String, BaseAgentState> agentState;
271+
private Optional<String> rewindBeforeInvocationId;
223272

224273
public Builder() {
225274
this.skipSummarization = Optional.empty();
@@ -231,6 +280,8 @@ public Builder() {
231280
this.requestedToolConfirmations = new ConcurrentHashMap<>();
232281
this.endInvocation = Optional.empty();
233282
this.compaction = Optional.empty();
283+
this.agentState = new ConcurrentHashMap<>();
284+
this.rewindBeforeInvocationId = Optional.empty();
234285
}
235286

236287
private Builder(EventActions eventActions) {
@@ -242,8 +293,11 @@ private Builder(EventActions eventActions) {
242293
this.requestedAuthConfigs = new ConcurrentHashMap<>(eventActions.requestedAuthConfigs());
243294
this.requestedToolConfirmations =
244295
new ConcurrentHashMap<>(eventActions.requestedToolConfirmations());
296+
this.endOfAgent = eventActions.endOfAgent();
245297
this.endInvocation = eventActions.endInvocation();
246298
this.compaction = eventActions.compaction();
299+
this.agentState = new ConcurrentHashMap<>(eventActions.agentState());
300+
this.rewindBeforeInvocationId = eventActions.rewindBeforeInvocationId();
247301
}
248302

249303
@CanIgnoreReturnValue
@@ -296,6 +350,13 @@ public Builder requestedToolConfirmations(ConcurrentMap<String, ToolConfirmation
296350
return this;
297351
}
298352

353+
@CanIgnoreReturnValue
354+
@JsonProperty("endOfAgent")
355+
public Builder endOfAgent(boolean endOfAgent) {
356+
this.endOfAgent = endOfAgent;
357+
return this;
358+
}
359+
299360
@CanIgnoreReturnValue
300361
@JsonProperty("endInvocation")
301362
public Builder endInvocation(boolean endInvocation) {
@@ -310,6 +371,20 @@ public Builder compaction(EventCompaction value) {
310371
return this;
311372
}
312373

374+
@CanIgnoreReturnValue
375+
@JsonProperty("agentState")
376+
public Builder agentState(ConcurrentMap<String, BaseAgentState> agentState) {
377+
this.agentState = agentState;
378+
return this;
379+
}
380+
381+
@CanIgnoreReturnValue
382+
@JsonProperty("rewindBeforeInvocationId")
383+
public Builder rewindBeforeInvocationId(String rewindBeforeInvocationId) {
384+
this.rewindBeforeInvocationId = Optional.ofNullable(rewindBeforeInvocationId);
385+
return this;
386+
}
387+
313388
@CanIgnoreReturnValue
314389
public Builder merge(EventActions other) {
315390
other.skipSummarization().ifPresent(this::skipSummarization);
@@ -319,8 +394,11 @@ public Builder merge(EventActions other) {
319394
other.escalate().ifPresent(this::escalate);
320395
this.requestedAuthConfigs.putAll(other.requestedAuthConfigs());
321396
this.requestedToolConfirmations.putAll(other.requestedToolConfirmations());
397+
this.endOfAgent = other.endOfAgent();
322398
other.endInvocation().ifPresent(this::endInvocation);
323399
other.compaction().ifPresent(this::compaction);
400+
this.agentState.putAll(other.agentState());
401+
other.rewindBeforeInvocationId().ifPresent(this::rewindBeforeInvocationId);
324402
return this;
325403
}
326404

0 commit comments

Comments
 (0)