Skip to content

Commit c6c52c4

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Setting up data structures for pause/resume/rewind
PiperOrigin-RevId: 861740939
1 parent 1c1698d commit c6c52c4

File tree

8 files changed

+285
-77
lines changed

8 files changed

+285
-77
lines changed
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/*
2+
* Copyright 2026 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.adk.agents;
18+
19+
import com.google.adk.JsonBaseModel;
20+
21+
/** Base class for all agent states. */
22+
public class BaseAgentState extends JsonBaseModel {
23+
24+
protected BaseAgentState() {}
25+
26+
/** Returns a new {@link Builder} for creating {@link BaseAgentState} instances. */
27+
public static Builder builder() {
28+
return new Builder();
29+
}
30+
31+
/** Builder for {@link BaseAgentState}. */
32+
public static class Builder {
33+
private Builder() {}
34+
35+
public BaseAgentState build() {
36+
return new BaseAgentState();
37+
}
38+
}
39+
}

core/src/main/java/com/google/adk/agents/InvocationContext.java

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ public class InvocationContext {
5050
private final Session session;
5151
private final Optional<Content> userContent;
5252
private final RunConfig runConfig;
53+
private final Map<String, BaseAgentState> agentStates;
54+
private final Map<String, Boolean> endOfAgents;
5355
private final ResumabilityConfig resumabilityConfig;
5456
private final InvocationCostManager invocationCostManager;
5557

@@ -71,6 +73,8 @@ protected InvocationContext(Builder builder) {
7173
this.userContent = builder.userContent;
7274
this.runConfig = builder.runConfig;
7375
this.endInvocation = builder.endInvocation;
76+
this.agentStates = builder.agentStates;
77+
this.endOfAgents = builder.endOfAgents;
7478
this.resumabilityConfig = builder.resumabilityConfig;
7579
this.invocationCostManager = builder.invocationCostManager;
7680
}
@@ -299,6 +303,16 @@ public RunConfig runConfig() {
299303
return runConfig;
300304
}
301305

306+
/** Returns agent-specific state saved within this invocation. */
307+
public Map<String, BaseAgentState> agentStates() {
308+
return agentStates;
309+
}
310+
311+
/** Returns map of agents that ended during this invocation. */
312+
public Map<String, Boolean> endOfAgents() {
313+
return endOfAgents;
314+
}
315+
302316
/**
303317
* Returns whether this invocation should be ended, e.g., due to reaching a terminal state or
304318
* error.
@@ -410,6 +424,8 @@ private Builder(InvocationContext context) {
410424
this.userContent = context.userContent;
411425
this.runConfig = context.runConfig;
412426
this.endInvocation = context.endInvocation;
427+
this.agentStates = new ConcurrentHashMap<>(context.agentStates);
428+
this.endOfAgents = new ConcurrentHashMap<>(context.endOfAgents);
413429
this.resumabilityConfig = context.resumabilityConfig;
414430
this.invocationCostManager = context.invocationCostManager;
415431
}
@@ -427,6 +443,8 @@ private Builder(InvocationContext context) {
427443
private Optional<Content> userContent = Optional.empty();
428444
private RunConfig runConfig = RunConfig.builder().build();
429445
private boolean endInvocation = false;
446+
private Map<String, BaseAgentState> agentStates = new ConcurrentHashMap<>();
447+
private Map<String, Boolean> endOfAgents = new ConcurrentHashMap<>();
430448
private ResumabilityConfig resumabilityConfig = new ResumabilityConfig();
431449
private InvocationCostManager invocationCostManager = new InvocationCostManager();
432450

@@ -616,6 +634,30 @@ public Builder endInvocation(boolean endInvocation) {
616634
return this;
617635
}
618636

637+
/**
638+
* Sets agent-specific state saved within this invocation.
639+
*
640+
* @param agentStates agent-specific state saved within this invocation.
641+
* @return this builder instance for chaining.
642+
*/
643+
@CanIgnoreReturnValue
644+
public Builder agentStates(Map<String, BaseAgentState> agentStates) {
645+
this.agentStates = agentStates;
646+
return this;
647+
}
648+
649+
/**
650+
* Sets agent end-of-invocation status.
651+
*
652+
* @param endOfAgents agent end-of-invocation status.
653+
* @return this builder instance for chaining.
654+
*/
655+
@CanIgnoreReturnValue
656+
public Builder endOfAgents(Map<String, Boolean> endOfAgents) {
657+
this.endOfAgents = endOfAgents;
658+
return this;
659+
}
660+
619661
/**
620662
* Sets the resumability configuration for the current agent run.
621663
*
@@ -660,6 +702,8 @@ public boolean equals(Object o) {
660702
&& Objects.equals(session, that.session)
661703
&& Objects.equals(userContent, that.userContent)
662704
&& Objects.equals(runConfig, that.runConfig)
705+
&& Objects.equals(agentStates, that.agentStates)
706+
&& Objects.equals(endOfAgents, that.endOfAgents)
663707
&& Objects.equals(resumabilityConfig, that.resumabilityConfig)
664708
&& Objects.equals(invocationCostManager, that.invocationCostManager);
665709
}
@@ -680,6 +724,8 @@ public int hashCode() {
680724
userContent,
681725
runConfig,
682726
endInvocation,
727+
agentStates,
728+
endOfAgents,
683729
resumabilityConfig,
684730
invocationCostManager);
685731
}

core/src/main/java/com/google/adk/agents/LlmAgent.java

Lines changed: 31 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
package com.google.adk.agents;
1818

1919
import static com.google.common.collect.ImmutableList.toImmutableList;
20-
import static java.util.Objects.requireNonNullElse;
2120
import static java.util.stream.Collectors.joining;
2221

2322
import com.fasterxml.jackson.core.JsonProcessingException;
@@ -104,12 +103,12 @@ public enum IncludeContents {
104103
private final Optional<Integer> maxSteps;
105104
private final boolean disallowTransferToParent;
106105
private final boolean disallowTransferToPeers;
107-
private final ImmutableList<? extends BeforeModelCallback> beforeModelCallback;
108-
private final ImmutableList<? extends AfterModelCallback> afterModelCallback;
109-
private final ImmutableList<? extends OnModelErrorCallback> onModelErrorCallback;
110-
private final ImmutableList<? extends BeforeToolCallback> beforeToolCallback;
111-
private final ImmutableList<? extends AfterToolCallback> afterToolCallback;
112-
private final ImmutableList<? extends OnToolErrorCallback> onToolErrorCallback;
106+
private final Optional<List<? extends BeforeModelCallback>> beforeModelCallback;
107+
private final Optional<List<? extends AfterModelCallback>> afterModelCallback;
108+
private final Optional<List<? extends OnModelErrorCallback>> onModelErrorCallback;
109+
private final Optional<List<? extends BeforeToolCallback>> beforeToolCallback;
110+
private final Optional<List<? extends AfterToolCallback>> afterToolCallback;
111+
private final Optional<List<? extends OnToolErrorCallback>> onToolErrorCallback;
113112
private final Optional<Schema> inputSchema;
114113
private final Optional<Schema> outputSchema;
115114
private final Optional<Executor> executor;
@@ -127,28 +126,29 @@ protected LlmAgent(Builder builder) {
127126
builder.beforeAgentCallback,
128127
builder.afterAgentCallback);
129128
this.model = Optional.ofNullable(builder.model);
130-
this.instruction = requireNonNullElse(builder.instruction, new Instruction.Static(""));
129+
this.instruction =
130+
builder.instruction == null ? new Instruction.Static("") : builder.instruction;
131131
this.globalInstruction =
132-
requireNonNullElse(builder.globalInstruction, new Instruction.Static(""));
132+
builder.globalInstruction == null ? new Instruction.Static("") : builder.globalInstruction;
133133
this.generateContentConfig = Optional.ofNullable(builder.generateContentConfig);
134134
this.exampleProvider = Optional.ofNullable(builder.exampleProvider);
135-
this.includeContents = requireNonNullElse(builder.includeContents, IncludeContents.DEFAULT);
135+
this.includeContents =
136+
builder.includeContents != null ? builder.includeContents : IncludeContents.DEFAULT;
136137
this.planning = builder.planning != null && builder.planning;
137138
this.maxSteps = Optional.ofNullable(builder.maxSteps);
138139
this.disallowTransferToParent = builder.disallowTransferToParent;
139140
this.disallowTransferToPeers = builder.disallowTransferToPeers;
140-
this.beforeModelCallback = requireNonNullElse(builder.beforeModelCallback, ImmutableList.of());
141-
this.afterModelCallback = requireNonNullElse(builder.afterModelCallback, ImmutableList.of());
142-
this.onModelErrorCallback =
143-
requireNonNullElse(builder.onModelErrorCallback, ImmutableList.of());
144-
this.beforeToolCallback = requireNonNullElse(builder.beforeToolCallback, ImmutableList.of());
145-
this.afterToolCallback = requireNonNullElse(builder.afterToolCallback, ImmutableList.of());
146-
this.onToolErrorCallback = requireNonNullElse(builder.onToolErrorCallback, ImmutableList.of());
141+
this.beforeModelCallback = Optional.ofNullable(builder.beforeModelCallback);
142+
this.afterModelCallback = Optional.ofNullable(builder.afterModelCallback);
143+
this.onModelErrorCallback = Optional.ofNullable(builder.onModelErrorCallback);
144+
this.beforeToolCallback = Optional.ofNullable(builder.beforeToolCallback);
145+
this.afterToolCallback = Optional.ofNullable(builder.afterToolCallback);
146+
this.onToolErrorCallback = Optional.ofNullable(builder.onToolErrorCallback);
147147
this.inputSchema = Optional.ofNullable(builder.inputSchema);
148148
this.outputSchema = Optional.ofNullable(builder.outputSchema);
149149
this.executor = Optional.ofNullable(builder.executor);
150150
this.outputKey = Optional.ofNullable(builder.outputKey);
151-
this.toolsUnion = requireNonNullElse(builder.toolsUnion, ImmutableList.of());
151+
this.toolsUnion = builder.toolsUnion != null ? builder.toolsUnion : ImmutableList.of();
152152
this.toolsets = extractToolsets(this.toolsUnion);
153153
this.codeExecutor = Optional.ofNullable(builder.codeExecutor);
154154

@@ -704,16 +704,7 @@ private static boolean isThought(Part part) {
704704

705705
@Override
706706
protected Flowable<Event> runAsyncImpl(InvocationContext invocationContext) {
707-
return llmFlow
708-
.run(invocationContext)
709-
.concatMap(
710-
event -> {
711-
this.maybeSaveOutputToState(event);
712-
if (invocationContext.shouldPauseInvocation(event)) {
713-
return Flowable.just(event).concatWith(Flowable.empty());
714-
}
715-
return Flowable.just(event);
716-
});
707+
return llmFlow.run(invocationContext).doOnNext(this::maybeSaveOutputToState);
717708
}
718709

719710
@Override
@@ -850,27 +841,27 @@ public boolean disallowTransferToPeers() {
850841
return disallowTransferToPeers;
851842
}
852843

853-
public List<? extends BeforeModelCallback> beforeModelCallback() {
844+
public Optional<List<? extends BeforeModelCallback>> beforeModelCallback() {
854845
return beforeModelCallback;
855846
}
856847

857-
public List<? extends AfterModelCallback> afterModelCallback() {
848+
public Optional<List<? extends AfterModelCallback>> afterModelCallback() {
858849
return afterModelCallback;
859850
}
860851

861-
public List<? extends BeforeToolCallback> beforeToolCallback() {
852+
public Optional<List<? extends BeforeToolCallback>> beforeToolCallback() {
862853
return beforeToolCallback;
863854
}
864855

865-
public List<? extends AfterToolCallback> afterToolCallback() {
856+
public Optional<List<? extends AfterToolCallback>> afterToolCallback() {
866857
return afterToolCallback;
867858
}
868859

869-
public List<? extends OnModelErrorCallback> onModelErrorCallback() {
860+
public Optional<List<? extends OnModelErrorCallback>> onModelErrorCallback() {
870861
return onModelErrorCallback;
871862
}
872863

873-
public List<? extends OnToolErrorCallback> onToolErrorCallback() {
864+
public Optional<List<? extends OnToolErrorCallback>> onToolErrorCallback() {
874865
return onToolErrorCallback;
875866
}
876867

@@ -880,7 +871,7 @@ public List<? extends OnToolErrorCallback> onToolErrorCallback() {
880871
* <p>This method is only for use by Agent Development Kit.
881872
*/
882873
public List<? extends BeforeModelCallback> canonicalBeforeModelCallbacks() {
883-
return beforeModelCallback;
874+
return beforeModelCallback.orElse(ImmutableList.of());
884875
}
885876

886877
/**
@@ -889,7 +880,7 @@ public List<? extends BeforeModelCallback> canonicalBeforeModelCallbacks() {
889880
* <p>This method is only for use by Agent Development Kit.
890881
*/
891882
public List<? extends AfterModelCallback> canonicalAfterModelCallbacks() {
892-
return afterModelCallback;
883+
return afterModelCallback.orElse(ImmutableList.of());
893884
}
894885

895886
/**
@@ -898,7 +889,7 @@ public List<? extends AfterModelCallback> canonicalAfterModelCallbacks() {
898889
* <p>This method is only for use by Agent Development Kit.
899890
*/
900891
public List<? extends OnModelErrorCallback> canonicalOnModelErrorCallbacks() {
901-
return onModelErrorCallback;
892+
return onModelErrorCallback.orElse(ImmutableList.of());
902893
}
903894

904895
/**
@@ -907,7 +898,7 @@ public List<? extends OnModelErrorCallback> canonicalOnModelErrorCallbacks() {
907898
* <p>This method is only for use by Agent Development Kit.
908899
*/
909900
public List<? extends BeforeToolCallback> canonicalBeforeToolCallbacks() {
910-
return beforeToolCallback;
901+
return beforeToolCallback.orElse(ImmutableList.of());
911902
}
912903

913904
/**
@@ -916,7 +907,7 @@ public List<? extends BeforeToolCallback> canonicalBeforeToolCallbacks() {
916907
* <p>This method is only for use by Agent Development Kit.
917908
*/
918909
public List<? extends AfterToolCallback> canonicalAfterToolCallbacks() {
919-
return afterToolCallback;
910+
return afterToolCallback.orElse(ImmutableList.of());
920911
}
921912

922913
/**
@@ -925,7 +916,7 @@ public List<? extends AfterToolCallback> canonicalAfterToolCallbacks() {
925916
* <p>This method is only for use by Agent Development Kit.
926917
*/
927918
public List<? extends OnToolErrorCallback> canonicalOnToolErrorCallbacks() {
928-
return onToolErrorCallback;
919+
return onToolErrorCallback.orElse(ImmutableList.of());
929920
}
930921

931922
public Optional<Schema> inputSchema() {

0 commit comments

Comments
 (0)