Skip to content

Commit 4df9609

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: new ContextFilterPlugin
Introduces a new `ContextFilterPlugin` in Java, which is designed to reduce the size of the LLM context by filtering out older turns. Here's a breakdown: * **`ContextFilterPlugin.java`**: This new plugin extends `BasePlugin` and implements logic in its `beforeModelCallback` method. It trims the `LlmRequest` contents to retain only a specified number of the most recent "model" turns and their preceding "user" turns. A key aspect is the `adjustSplitIndexToAvoidOrphanedFunctionResponses` method, which ensures that any `FunctionResponse` included in the filtered context has its corresponding `FunctionCall` also present, preventing invalid inputs to the LLM. * **`ContextFilterPluginTest.java`**: This file contains unit tests for `ContextFilterPlugin`, verifying its behavior for both scenarios where context filtering is not needed and where it successfully reduces the `LlmRequest` contents based on the configured number of invocations to keep. PiperOrigin-RevId: 859380380
1 parent 5ba63f4 commit 4df9609

File tree

2 files changed

+527
-0
lines changed

2 files changed

+527
-0
lines changed
Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
/*
2+
* Copyright 2026 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package com.google.adk.plugins;
17+
18+
import static com.google.common.base.Preconditions.checkArgument;
19+
20+
import com.google.adk.agents.CallbackContext;
21+
import com.google.adk.models.LlmRequest;
22+
import com.google.adk.models.LlmResponse;
23+
import com.google.errorprone.annotations.CanIgnoreReturnValue;
24+
import com.google.genai.types.Content;
25+
import com.google.genai.types.FunctionCall;
26+
import com.google.genai.types.FunctionResponse;
27+
import com.google.genai.types.Part;
28+
import io.reactivex.rxjava3.core.Maybe;
29+
import java.util.ArrayList;
30+
import java.util.HashSet;
31+
import java.util.List;
32+
import java.util.Optional;
33+
import java.util.Set;
34+
import java.util.function.UnaryOperator;
35+
import org.slf4j.Logger;
36+
import org.slf4j.LoggerFactory;
37+
38+
/**
39+
* A plugin that filters the LLM request {@link Content} list to reduce its size, for example to
40+
* adhere to context window limits.
41+
*
42+
* <p>This plugin can be configured to trim the conversation history based on one or both of the
43+
* following criteria:
44+
*
45+
* <ul>
46+
* <li><b>{@code numInvocationsToKeep(N)}:</b> Retains only the last {@code N} model turns and any
47+
* preceding user turns. If multiple user messages appear consecutively before a model
48+
* message, all of them are kept as part of that model invocation window.
49+
* <li><b>{@code customFilter()}:</b> Applies a custom {@link UnaryOperator} to filter the list of
50+
* {@link Content} objects. If {@code numInvocationsToKeep} is also specified, the custom
51+
* filter is applied <i>after</i> the invocation-based trimming occurs.
52+
* </ul>
53+
*
54+
* <p><b>Function Call Handling:</b> The plugin ensures that if a {@link FunctionResponse} is
55+
* included in the filtered list, its corresponding {@link FunctionCall} is also included. If
56+
* filtering would otherwise exclude the {@link FunctionCall}, the window is automatically expanded
57+
* to include it, preventing orphaned function responses.
58+
*
59+
* <p>If no filtering options are provided, this plugin has no effect. If the {@code customFilter}
60+
* throws an exception during execution, filtering is aborted, and the {@link LlmRequest} is not
61+
* modified.
62+
*/
63+
public class ContextFilterPlugin extends BasePlugin {
64+
private static final Logger logger = LoggerFactory.getLogger(ContextFilterPlugin.class);
65+
private static final String MODEL_ROLE = "model";
66+
private static final String USER_ROLE = "user";
67+
68+
private final Optional<Integer> numInvocationsToKeep;
69+
private final Optional<UnaryOperator<List<Content>>> customFilter;
70+
71+
protected ContextFilterPlugin(Builder builder) {
72+
super(builder.name);
73+
this.numInvocationsToKeep = builder.numInvocationsToKeep;
74+
this.customFilter = builder.customFilter;
75+
}
76+
77+
public static Builder builder() {
78+
return new Builder();
79+
}
80+
81+
/** Builder for {@link ContextFilterPlugin}. */
82+
public static class Builder {
83+
private Optional<Integer> numInvocationsToKeep = Optional.empty();
84+
private Optional<UnaryOperator<List<Content>>> customFilter = Optional.empty();
85+
private String name = "context_filter_plugin";
86+
87+
@CanIgnoreReturnValue
88+
public Builder numInvocationsToKeep(int numInvocationsToKeep) {
89+
checkArgument(numInvocationsToKeep > 0, "numInvocationsToKeep must be positive");
90+
this.numInvocationsToKeep = Optional.of(numInvocationsToKeep);
91+
return this;
92+
}
93+
94+
@CanIgnoreReturnValue
95+
public Builder customFilter(UnaryOperator<List<Content>> customFilter) {
96+
this.customFilter = Optional.of(customFilter);
97+
return this;
98+
}
99+
100+
@CanIgnoreReturnValue
101+
public Builder name(String name) {
102+
this.name = name;
103+
return this;
104+
}
105+
106+
public ContextFilterPlugin build() {
107+
return new ContextFilterPlugin(this);
108+
}
109+
}
110+
111+
/**
112+
* Filters the LLM request context by trimming recent turns and applying any custom filter.
113+
*
114+
* <p>If {@code numInvocationsToKeep} is set, this method retains only the most recent model turns
115+
* and their preceding user turns. It ensures that function calls and responses remain paired. If
116+
* a {@code customFilter} is provided, it is applied to the list after trimming.
117+
*
118+
* @param callbackContext The context of the callback.
119+
* @param llmRequest The request builder whose contents will be updated in place.
120+
* @return {@link Maybe#empty()} as this plugin only modifies the request builder.
121+
*/
122+
@Override
123+
public Maybe<LlmResponse> beforeModelCallback(
124+
CallbackContext callbackContext, LlmRequest.Builder llmRequest) {
125+
try {
126+
List<Content> contents = llmRequest.build().contents();
127+
if (contents == null || contents.isEmpty()) {
128+
return Maybe.empty();
129+
}
130+
131+
List<Content> effectiveContents = new ArrayList<>(contents);
132+
133+
if (numInvocationsToKeep.isPresent()) {
134+
effectiveContents =
135+
trimContentsByInvocations(numInvocationsToKeep.get(), effectiveContents);
136+
}
137+
138+
if (customFilter.isPresent()) {
139+
effectiveContents = customFilter.get().apply(effectiveContents);
140+
}
141+
142+
llmRequest.contents(effectiveContents);
143+
} catch (RuntimeException e) {
144+
logger.error("Failed to reduce context for request", e);
145+
}
146+
147+
return Maybe.empty();
148+
}
149+
150+
private List<Content> trimContentsByInvocations(int numInvocations, List<Content> contents) {
151+
// If the number of model turns is within limits, no trimming is necessary.
152+
long modelTurnCount =
153+
contents.stream().filter(c -> hasRole(c, MODEL_ROLE)).limit(numInvocations + 1).count();
154+
if (modelTurnCount < numInvocations + 1) {
155+
return contents;
156+
}
157+
int candidateSplitIndex = findNthModelTurnStartIndex(numInvocations, contents);
158+
// Ensure that if a function response is kept, its corresponding function call is also kept.
159+
int finalSplitIndex = adjustIndexForToolCalls(candidateSplitIndex, contents);
160+
// The Nth model turn can be preceded by user turns; expand window to include them.
161+
while (finalSplitIndex > 0
162+
&& hasRole(contents.get(finalSplitIndex - 1), USER_ROLE)
163+
&& !isFunctionResponse(contents.get(finalSplitIndex - 1))) {
164+
finalSplitIndex--;
165+
}
166+
return new ArrayList<>(contents.subList(finalSplitIndex, contents.size()));
167+
}
168+
169+
private int findNthModelTurnStartIndex(int numInvocations, List<Content> contents) {
170+
int modelTurnsToFind = numInvocations;
171+
for (int i = contents.size() - 1; i >= 0; i--) {
172+
if (hasRole(contents.get(i), MODEL_ROLE)) {
173+
modelTurnsToFind--;
174+
if (modelTurnsToFind == 0) {
175+
int startIndex = i;
176+
// Include all preceding user messages in the same turn.
177+
while (startIndex > 0 && hasRole(contents.get(startIndex - 1), USER_ROLE)) {
178+
startIndex--;
179+
}
180+
return startIndex;
181+
}
182+
}
183+
}
184+
return 0;
185+
}
186+
187+
/**
188+
* Adjusts the split index to ensure that if a {@link FunctionResponse} is included in the trimmed
189+
* list, its corresponding {@link FunctionCall} is also included.
190+
*
191+
* <p>This prevents orphaning function responses by expanding the conversation window backward
192+
* (i.e., reducing {@code splitIndex}) to include the earliest function call corresponding to any
193+
* function response that would otherwise be included.
194+
*
195+
* @param splitIndex The candidate index before which messages might be trimmed.
196+
* @param contents The full list of content messages.
197+
* @return An adjusted split index, guaranteed to be less than or equal to {@code splitIndex}.
198+
*/
199+
private int adjustIndexForToolCalls(int splitIndex, List<Content> contents) {
200+
Set<String> neededCallIds = new HashSet<>();
201+
int finalSplitIndex = splitIndex;
202+
for (int i = contents.size() - 1; i >= 0; i--) {
203+
Optional<List<Part>> partsOptional = contents.get(i).parts();
204+
if (partsOptional.isPresent()) {
205+
for (Part part : partsOptional.get()) {
206+
part.functionResponse().flatMap(FunctionResponse::id).ifPresent(neededCallIds::add);
207+
part.functionCall().flatMap(FunctionCall::id).ifPresent(neededCallIds::remove);
208+
}
209+
}
210+
if (i <= finalSplitIndex && neededCallIds.isEmpty()) {
211+
finalSplitIndex = i;
212+
break;
213+
} else if (i == 0) {
214+
finalSplitIndex = 0;
215+
}
216+
}
217+
return finalSplitIndex;
218+
}
219+
220+
private boolean isFunctionResponse(Content content) {
221+
return content
222+
.parts()
223+
.map(parts -> parts.stream().anyMatch(p -> p.functionResponse().isPresent()))
224+
.orElse(false);
225+
}
226+
227+
private boolean hasRole(Content content, String role) {
228+
return content.role().map(r -> r.equals(role)).orElse(false);
229+
}
230+
}

0 commit comments

Comments
 (0)