Skip to content

Commit 0fda56b

Browse files
google-genai-botcopybara-github
authored andcommitted
refactor: BaseAgent: Unwrap List from Optional, improve test coverage
PiperOrigin-RevId: 860210309
1 parent 2f5769d commit 0fda56b

File tree

4 files changed

+74
-34
lines changed

4 files changed

+74
-34
lines changed

core/src/main/java/com/google/adk/agents/BaseAgent.java

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ public abstract class BaseAgent {
5959

6060
private final List<? extends BaseAgent> subAgents;
6161

62-
private final Optional<List<? extends BeforeAgentCallback>> beforeAgentCallback;
63-
private final Optional<List<? extends AfterAgentCallback>> afterAgentCallback;
62+
private final List<? extends BeforeAgentCallback> beforeAgentCallback;
63+
private final List<? extends AfterAgentCallback> afterAgentCallback;
6464

6565
/**
6666
* Creates a new BaseAgent.
@@ -83,8 +83,9 @@ public BaseAgent(
8383
this.description = description;
8484
this.parentAgent = null;
8585
this.subAgents = subAgents != null ? subAgents : ImmutableList.of();
86-
this.beforeAgentCallback = Optional.ofNullable(beforeAgentCallback);
87-
this.afterAgentCallback = Optional.ofNullable(afterAgentCallback);
86+
this.beforeAgentCallback =
87+
beforeAgentCallback != null ? beforeAgentCallback : ImmutableList.of();
88+
this.afterAgentCallback = afterAgentCallback != null ? afterAgentCallback : ImmutableList.of();
8889

8990
// Establish parent relationships for all sub-agents if needed.
9091
for (BaseAgent subAgent : this.subAgents) {
@@ -171,11 +172,11 @@ public List<? extends BaseAgent> subAgents() {
171172
return subAgents;
172173
}
173174

174-
public Optional<List<? extends BeforeAgentCallback>> beforeAgentCallback() {
175+
public List<? extends BeforeAgentCallback> beforeAgentCallback() {
175176
return beforeAgentCallback;
176177
}
177178

178-
public Optional<List<? extends AfterAgentCallback>> afterAgentCallback() {
179+
public List<? extends AfterAgentCallback> afterAgentCallback() {
179180
return afterAgentCallback;
180181
}
181182

@@ -185,7 +186,7 @@ public Optional<List<? extends AfterAgentCallback>> afterAgentCallback() {
185186
* <p>This method is only for use by Agent Development Kit.
186187
*/
187188
public List<? extends BeforeAgentCallback> canonicalBeforeAgentCallbacks() {
188-
return beforeAgentCallback.orElse(ImmutableList.of());
189+
return beforeAgentCallback;
189190
}
190191

191192
/**
@@ -194,7 +195,7 @@ public List<? extends BeforeAgentCallback> canonicalBeforeAgentCallbacks() {
194195
* <p>This method is only for use by Agent Development Kit.
195196
*/
196197
public List<? extends AfterAgentCallback> canonicalAfterAgentCallbacks() {
197-
return afterAgentCallback.orElse(ImmutableList.of());
198+
return afterAgentCallback;
198199
}
199200

200201
/**
@@ -239,8 +240,7 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
239240
() ->
240241
callCallback(
241242
beforeCallbacksToFunctions(
242-
invocationContext.pluginManager(),
243-
beforeAgentCallback.orElse(ImmutableList.of())),
243+
invocationContext.pluginManager(), beforeAgentCallback),
244244
invocationContext)
245245
.flatMapPublisher(
246246
beforeEventOpt -> {
@@ -257,7 +257,7 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
257257
callCallback(
258258
afterCallbacksToFunctions(
259259
invocationContext.pluginManager(),
260-
afterAgentCallback.orElse(ImmutableList.of())),
260+
afterAgentCallback),
261261
invocationContext)
262262
.flatMapPublisher(Flowable::fromOptional));
263263

core/src/test/java/com/google/adk/agents/BaseAgentTest.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,4 +316,24 @@ public void canonicalCallbacks_returnsListWhenPresent() {
316316
assertThat(agent.canonicalBeforeAgentCallbacks()).containsExactly(bc);
317317
assertThat(agent.canonicalAfterAgentCallbacks()).containsExactly(ac);
318318
}
319+
320+
@Test
321+
public void runLive_invokesRunLiveImpl() {
322+
var runLiveCallback = TestCallback.<Void>returningEmpty();
323+
Content runLiveImplContent = Content.fromParts(Part.fromText("live_output"));
324+
TestBaseAgent agent =
325+
new TestBaseAgent(
326+
TEST_AGENT_NAME,
327+
TEST_AGENT_DESCRIPTION,
328+
/* beforeAgentCallbacks= */ ImmutableList.of(),
329+
/* afterAgentCallbacks= */ ImmutableList.of(),
330+
runLiveCallback.asRunLiveImplSupplier(runLiveImplContent));
331+
InvocationContext invocationContext = TestUtils.createInvocationContext(agent);
332+
333+
List<Event> results = agent.runLive(invocationContext).toList().blockingGet();
334+
335+
assertThat(results).hasSize(1);
336+
assertThat(results.get(0).content()).hasValue(runLiveImplContent);
337+
assertThat(runLiveCallback.wasCalled()).isTrue();
338+
}
319339
}

core/src/test/java/com/google/adk/agents/ConfigAgentUtilsTest.java

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1161,20 +1161,25 @@ public void fromConfig_withConfiguredCallbacks_resolvesCallbacks()
11611161

11621162
String pfx = "test.callbacks.";
11631163
registry.register(
1164-
pfx + "before_agent_1", (Callbacks.BeforeAgentCallback) (ctx) -> Maybe.empty());
1164+
pfx + "before_agent_1", (Callbacks.BeforeAgentCallback) (unusedCtx) -> Maybe.empty());
11651165
registry.register(
1166-
pfx + "before_agent_2", (Callbacks.BeforeAgentCallback) (ctx) -> Maybe.empty());
1167-
registry.register(pfx + "after_agent_1", (Callbacks.AfterAgentCallback) (ctx) -> Maybe.empty());
1166+
pfx + "before_agent_2", (Callbacks.BeforeAgentCallback) (unusedCtx) -> Maybe.empty());
11681167
registry.register(
1169-
pfx + "before_model_1", (Callbacks.BeforeModelCallback) (ctx, req) -> Maybe.empty());
1168+
pfx + "after_agent_1", (Callbacks.AfterAgentCallback) (unusedCtx) -> Maybe.empty());
11701169
registry.register(
1171-
pfx + "after_model_1", (Callbacks.AfterModelCallback) (ctx, resp) -> Maybe.empty());
1170+
pfx + "before_model_1",
1171+
(Callbacks.BeforeModelCallback) (unusedCtx, unusedReq) -> Maybe.empty());
1172+
registry.register(
1173+
pfx + "after_model_1",
1174+
(Callbacks.AfterModelCallback) (unusedCtx, unusedResp) -> Maybe.empty());
11721175
registry.register(
11731176
pfx + "before_tool_1",
1174-
(Callbacks.BeforeToolCallback) (inv, tool, args, toolCtx) -> Maybe.empty());
1177+
(Callbacks.BeforeToolCallback)
1178+
(unusedInv, unusedTool, unusedArgs, unusedToolCtx) -> Maybe.empty());
11751179
registry.register(
11761180
pfx + "after_tool_1",
1177-
(Callbacks.AfterToolCallback) (inv, tool, args, toolCtx, resp) -> Maybe.empty());
1181+
(Callbacks.AfterToolCallback)
1182+
(unusedInv, unusedTool, unusedArgs, unusedToolCtx, unusedResp) -> Maybe.empty());
11781183

11791184
File configFile = tempFolder.newFile("with_callbacks.yaml");
11801185
Files.writeString(
@@ -1204,10 +1209,8 @@ public void fromConfig_withConfiguredCallbacks_resolvesCallbacks()
12041209
assertThat(agent).isInstanceOf(LlmAgent.class);
12051210
LlmAgent llm = (LlmAgent) agent;
12061211

1207-
assertThat(agent.beforeAgentCallback()).isPresent();
1208-
assertThat(agent.beforeAgentCallback().get()).hasSize(2);
1209-
assertThat(agent.afterAgentCallback()).isPresent();
1210-
assertThat(agent.afterAgentCallback().get()).hasSize(1);
1212+
assertThat(agent.beforeAgentCallback()).hasSize(2);
1213+
assertThat(agent.afterAgentCallback()).hasSize(1);
12111214

12121215
assertThat(llm.beforeModelCallback()).isPresent();
12131216
assertThat(llm.beforeModelCallback().get()).hasSize(1);

core/src/test/java/com/google/adk/testing/TestCallback.java

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -102,63 +102,80 @@ public Supplier<Flowable<Event>> asRunAsyncImplSupplier(String contentText) {
102102
return asRunAsyncImplSupplier(Content.fromParts(Part.fromText(contentText)));
103103
}
104104

105+
/**
106+
* Returns a {@link Supplier} that marks this callback as called and returns a {@link Flowable}
107+
* with an event containing the given content.
108+
*/
109+
public Supplier<Flowable<Event>> asRunLiveImplSupplier(Content content) {
110+
return () ->
111+
Flowable.defer(
112+
() -> {
113+
markAsCalled();
114+
return Flowable.just(Event.builder().content(content).build());
115+
});
116+
}
117+
105118
@SuppressWarnings("unchecked") // This cast is safe if T is Content.
106119
public BeforeAgentCallback asBeforeAgentCallback() {
107-
return ctx -> (Maybe<Content>) callMaybe();
120+
return (unusedCtx) -> (Maybe<Content>) callMaybe();
108121
}
109122

110123
@SuppressWarnings("unchecked") // This cast is safe if T is Content.
111124
public BeforeAgentCallbackSync asBeforeAgentCallbackSync() {
112-
return ctx -> (Optional<Content>) callOptional();
125+
return (unusedCtx) -> (Optional<Content>) callOptional();
113126
}
114127

115128
@SuppressWarnings("unchecked") // This cast is safe if T is Content.
116129
public AfterAgentCallback asAfterAgentCallback() {
117-
return ctx -> (Maybe<Content>) callMaybe();
130+
return (unusedCtx) -> (Maybe<Content>) callMaybe();
118131
}
119132

120133
@SuppressWarnings("unchecked") // This cast is safe if T is Content.
121134
public AfterAgentCallbackSync asAfterAgentCallbackSync() {
122-
return ctx -> (Optional<Content>) callOptional();
135+
return (unusedCtx) -> (Optional<Content>) callOptional();
123136
}
124137

125138
@SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse.
126139
public BeforeModelCallback asBeforeModelCallback() {
127-
return (ctx, req) -> (Maybe<LlmResponse>) callMaybe();
140+
return (unusedCtx, unusedReq) -> (Maybe<LlmResponse>) callMaybe();
128141
}
129142

130143
@SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse.
131144
public BeforeModelCallbackSync asBeforeModelCallbackSync() {
132-
return (ctx, req) -> (Optional<LlmResponse>) callOptional();
145+
return (unusedCtx, unusedReq) -> (Optional<LlmResponse>) callOptional();
133146
}
134147

135148
@SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse.
136149
public AfterModelCallback asAfterModelCallback() {
137-
return (ctx, res) -> (Maybe<LlmResponse>) callMaybe();
150+
return (unusedCtx, unusedRes) -> (Maybe<LlmResponse>) callMaybe();
138151
}
139152

140153
@SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse.
141154
public AfterModelCallbackSync asAfterModelCallbackSync() {
142-
return (ctx, res) -> (Optional<LlmResponse>) callOptional();
155+
return (unusedCtx, unusedRes) -> (Optional<LlmResponse>) callOptional();
143156
}
144157

145158
@SuppressWarnings("unchecked") // This cast is safe if T is Map<String, Object>.
146159
public BeforeToolCallback asBeforeToolCallback() {
147-
return (invCtx, tool, toolArgs, toolCtx) -> (Maybe<Map<String, Object>>) callMaybe();
160+
return (unusedCtx, unusedTool, unusedToolArgs, unusedToolCtx) ->
161+
(Maybe<Map<String, Object>>) callMaybe();
148162
}
149163

150164
@SuppressWarnings("unchecked") // This cast is safe if T is Map<String, Object>.
151165
public BeforeToolCallbackSync asBeforeToolCallbackSync() {
152-
return (invCtx, tool, toolArgs, toolCtx) -> (Optional<Map<String, Object>>) callOptional();
166+
return (unusedCtx, unusedTool, unusedToolArgs, unusedToolCtx) ->
167+
(Optional<Map<String, Object>>) callOptional();
153168
}
154169

155170
@SuppressWarnings("unchecked") // This cast is safe if T is Map<String, Object>.
156171
public AfterToolCallback asAfterToolCallback() {
157-
return (invCtx, tool, toolArgs, toolCtx, res) -> (Maybe<Map<String, Object>>) callMaybe();
172+
return (unusedCtx, unusedTool, unusedToolArgs, unusedToolCtx, unusedRes) ->
173+
(Maybe<Map<String, Object>>) callMaybe();
158174
}
159175

160176
@SuppressWarnings("unchecked") // This cast is safe if T is Map<String, Object>.
161177
public AfterToolCallbackSync asAfterToolCallbackSync() {
162-
return (invCtx, tool, toolArgs, toolCtx, res) -> (Optional<Map<String, Object>>) callOptional();
178+
return (unusedCtx, unusedTool, unusedToolArgs, unusedToolCtx, unusedRes) ->
179+
(Optional<Map<String, Object>>) callOptional();
163180
}
164181
}

0 commit comments

Comments
 (0)