Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,18 @@ public BaseChatModelSetup(
this.tools = descriptor.getArgument("tools");
}

/**
 * Eagerly resolves the chat model connection referenced by this setup.
 *
 * <p>In cross-language invocation scenarios, constructing a resource object inside an async
 * thread may currently run into problems. To work around this, the construction is triggered
 * here, outside of the method that is executed asynchronously, so that it takes place in the
 * main thread.
 */
@Override
public void open() throws Exception {
    getResource.apply(connection, ResourceType.CHAT_MODEL_CONNECTION);
}

public abstract Map<String, Object> getParameters();

public ChatMessage chat(List<ChatMessage> messages) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,18 @@ public BaseEmbeddingModelSetup(
this.model = descriptor.getArgument("model");
}

/**
 * Eagerly resolves the connection used by this embedding model setup.
 *
 * <p>In cross-language invocation scenarios, constructing a resource object inside an async
 * thread may currently run into problems. To work around this, the construction is triggered
 * here, outside of the method that is executed asynchronously, so that it takes place in the
 * main thread.
 */
@Override
public void open() {
    getConnection();
}

public abstract Map<String, Object> getParameters();

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ protected FlinkAgentsMetricGroup getMetricGroup() {
return metricGroup;
}

/**
 * Open the resource.
 *
 * <p>The default implementation does nothing. Subclasses may override it to perform
 * initialization that has to happen before the resource is used.
 *
 * @throws Exception if an overriding implementation fails to open the resource
 */
public void open() throws Exception {}

/**
 * Close the resource.
 *
 * <p>The default implementation does nothing. Subclasses may override it to release whatever
 * they hold.
 *
 * @throws Exception if an overriding implementation fails to close the resource
 */
public void close() throws Exception {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,18 @@ public BaseVectorStore(
this.embeddingModel = descriptor.getArgument("embedding_model");
}

/**
 * Eagerly resolves the embedding model referenced by this vector store.
 *
 * <p>In cross-language invocation scenarios, constructing a resource object inside an async
 * thread may currently run into problems. To work around this, the construction is triggered
 * here, outside of the method that is executed asynchronously, so that it takes place in the
 * main thread.
 */
@Override
public void open() {
    getResource.apply(embeddingModel, ResourceType.EMBEDDING_MODEL);
}

@Override
public ResourceType getResourceType() {
return ResourceType.VECTOR_STORE;
Expand Down
6 changes: 6 additions & 0 deletions plan/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ under the License.
<version>${flink.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-runtime</artifactId>
<version>${flink.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-agents-api</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import java.util.*;

import static org.apache.flink.agents.api.agents.Agent.STRUCTURED_OUTPUT;
import static org.apache.flink.agents.plan.actions.Utils.supportAsync;

/** Built-in action for processing chat request and tool call result. */
public class ChatModelAction {
Expand Down Expand Up @@ -199,9 +200,10 @@ public static void chat(
(BaseChatModelSetup) ctx.getResource(model, ResourceType.CHAT_MODEL);

boolean chatAsync = ctx.getConfig().get(AgentExecutionOptions.CHAT_ASYNC);
// TODO: python chat model doesn't support async execution yet, see
// https://github.com/apache/flink-agents/issues/448 for details.
chatAsync = chatAsync && !(chatModel instanceof PythonChatModelSetup);

if ((chatModel instanceof PythonChatModelSetup) && !supportAsync()) {
chatAsync = false;
}

Agent.ErrorHandlingStrategy strategy =
ctx.getConfig().get(AgentExecutionOptions.ERROR_HANDLING_STRATEGY);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@

import java.util.List;

import static org.apache.flink.agents.plan.actions.Utils.supportAsync;

/** Built-in action for processing context retrieval requests. */
public class ContextRetrievalAction {

Expand Down Expand Up @@ -60,9 +62,9 @@ public static void processContextRetrievalRequest(Event event, RunnerContext ctx
contextRetrievalRequestEvent.getVectorStore(),
ResourceType.VECTOR_STORE);

// TODO: python vector store doesn't support async execution yet, see
// https://github.com/apache/flink-agents/issues/448 for details.
ragAsync = ragAsync && !(vectorStore instanceof PythonVectorStore);
if ((vectorStore instanceof PythonVectorStore) && !supportAsync()) {
ragAsync = false;
}

final VectorStoreQuery vectorStoreQuery =
new VectorStoreQuery(
Expand Down
56 changes: 56 additions & 0 deletions plan/src/main/java/org/apache/flink/agents/plan/actions/Utils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.agents.plan.actions;

import org.apache.flink.runtime.util.EnvironmentInformation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.StringTokenizer;

/** Utilities shared by the built-in actions. */
public class Utils {
    private static final Logger LOG = LoggerFactory.getLogger(Utils.class);

    /**
     * Check whether the current Flink version supports the async execution for cross-language
     * resource.
     *
     * <p>The async execution for java resource is supported only on flink with the pemja 0.6.2
     * dependency. See <a href="https://github.com/apache/flink-agents/pull/571">flink-agents</a>
     * for details.
     *
     * <p>The following (inclusive) version ranges are treated as unsupported: every 1.x release
     * up to 1.20.3, 2.0.0 - 2.0.1, 2.1.0 - 2.1.1 and 2.2.0. A version string that cannot be
     * parsed is also treated as unsupported, because sync execution is the safe fallback.
     *
     * @return {@code true} if async execution can be used, {@code false} if the caller should
     *     fall back to sync execution
     */
    public static boolean supportAsync() {
        final String version = EnvironmentInformation.getVersion();
        final int[] parts = parseVersion(version);
        if (parts == null) {
            LOG.warn(
                    "Cannot parse Flink version '{}', will fallback to sync execution for cross-language resource.",
                    version);
            return false;
        }

        if (isUnsupported(parts[0], parts[1], parts[2])) {
            LOG.debug(
                    "Flink {} doesn't support async execution for java resource, will fallback to sync execution.",
                    version);
            return false;
        }

        return true;
    }

    /** Returns whether the given version lies in one of the ranges without async support. */
    private static boolean isUnsupported(int major, int minor, int micro) {
        if (major == 1) {
            // The micro bound only applies to the 1.20 line; everything older is unsupported.
            return minor < 20 || (minor == 20 && micro <= 3);
        }
        if (major == 2) {
            return (minor == 0 && micro <= 1)
                    || (minor == 1 && micro <= 1)
                    || (minor == 2 && micro <= 0);
        }
        return false;
    }

    /**
     * Parses {@code major.minor[.micro]} out of a version string, tolerating suffixes such as
     * {@code -SNAPSHOT}. A missing micro component is read as 0.
     *
     * @return a three-element array, or {@code null} if major or minor cannot be determined
     */
    private static int[] parseVersion(String version) {
        if (version == null) {
            return null;
        }
        final StringTokenizer st = new StringTokenizer(version.trim(), ".");
        final int[] parts = new int[3];
        for (int i = 0; i < parts.length; i++) {
            if (!st.hasMoreTokens()) {
                if (i < 2) {
                    return null;
                }
                // e.g. "2.2-SNAPSHOT": micro stays 0.
                break;
            }
            final int value = leadingInt(st.nextToken());
            if (value < 0) {
                return null;
            }
            parts[i] = value;
        }
        return parts;
    }

    /** Returns the decimal number a token starts with, or -1 if it does not start with a digit. */
    private static int leadingInt(String token) {
        int end = 0;
        // At most 9 digits are consumed so that Integer.parseInt cannot overflow.
        while (end < token.length()
                && end < 9
                && token.charAt(end) >= '0'
                && token.charAt(end) <= '9') {
            end++;
        }
        return end == 0 ? -1 : Integer.parseInt(token.substring(0, end));
    }
}
37 changes: 23 additions & 14 deletions python/flink_agents/api/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ def metric_group(self) -> "MetricGroup | None":
"""
return self._metric_group

def open(self) -> None:
    """Open the resource.

    The default implementation does nothing; subclasses may override it to
    perform initialization that has to happen before the resource is used.
    """

def close(self) -> None:
    """Close the resource.

    The default implementation does nothing; subclasses may override it to
    release whatever they hold.
    """

Expand Down Expand Up @@ -120,14 +123,14 @@ class ResourceDescriptor(BaseModel):
arguments: Dict[str, Any]

def __init__(
self,
/,
*,
clazz: str | None = None,
target_module: str | None = None,
target_clazz: str | None = None,
arguments: Dict[str, Any] | None = None,
**kwargs: Any,
self,
/,
*,
clazz: str | None = None,
target_module: str | None = None,
target_clazz: str | None = None,
arguments: Dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
"""Initialize ResourceDescriptor.

Expand Down Expand Up @@ -182,9 +185,9 @@ def __eq__(self, other: object) -> bool:
if not isinstance(other, ResourceDescriptor):
return False
return (
self.target_module == other.target_module
and self.target_clazz == other.target_clazz
and self.arguments == other.arguments
self.target_module == other.target_module
and self.target_clazz == other.target_clazz
and self.arguments == other.arguments
)

def __hash__(self) -> int:
Expand Down Expand Up @@ -253,8 +256,12 @@ class ChatModel:
TONGYI_SETUP = "flink_agents.integrations.chat_models.tongyi_chat_model.TongyiChatModelSetup"

# Java Wrapper
JAVA_WRAPPER_CONNECTION = "flink_agents.api.chat_models.java_chat_model.JavaChatModelConnection"
JAVA_WRAPPER_SETUP = "flink_agents.api.chat_models.java_chat_model.JavaChatModelSetup"
JAVA_WRAPPER_CONNECTION = (
"flink_agents.api.chat_models.java_chat_model.JavaChatModelConnection"
)
JAVA_WRAPPER_SETUP = (
"flink_agents.api.chat_models.java_chat_model.JavaChatModelSetup"
)

class Java:
"""Java implementations of ChatModel."""
Expand Down Expand Up @@ -307,7 +314,9 @@ class VectorStore:
CHROMA_VECTOR_STORE = "flink_agents.integrations.vector_stores.chroma.chroma_vector_store.ChromaVectorStore"

# Java Wrapper
JAVA_WRAPPER_VECTOR_STORE = "flink_agents.api.vector_stores.java_vector_store.JavaVectorStore"
JAVA_WRAPPER_VECTOR_STORE = (
"flink_agents.api.vector_stores.java_vector_store.JavaVectorStore"
)
JAVA_WRAPPER_COLLECTION_MANAGEABLE_VECTOR_STORE = "flink_agents.api.vector_stores.java_vector_store.JavaCollectionManageableVectorStore"

class Java:
Expand Down
23 changes: 17 additions & 6 deletions python/flink_agents/plan/actions/chat_model_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from flink_agents.api.resource import ResourceType
from flink_agents.api.runner_context import RunnerContext
from flink_agents.plan.actions.action import Action
from flink_agents.plan.actions.utils import support_async
from flink_agents.plan.function import PythonFunction

if TYPE_CHECKING:
Expand Down Expand Up @@ -179,11 +180,13 @@ async def chat(
)

chat_async = ctx.config.get(AgentExecutionOptions.CHAT_ASYNC)
# java chat model doesn't support async execution,
# see https://github.com/apache/flink-agents/issues/448 for details.
chat_async = chat_async and not isinstance(chat_model, JavaChatModelSetup)

error_handling_strategy = ctx.config.get(AgentExecutionOptions.ERROR_HANDLING_STRATEGY)
if isinstance(chat_model, JavaChatModelSetup) and not support_async():
chat_async = False

error_handling_strategy = ctx.config.get(
AgentExecutionOptions.ERROR_HANDLING_STRATEGY
)
num_retries = 0
if error_handling_strategy == ErrorHandlingStrategy.RETRY:
num_retries = max(0, ctx.config.get(AgentExecutionOptions.MAX_RETRIES))
Expand All @@ -196,8 +199,16 @@ async def chat(
else:
response = ctx.durable_execute(chat_model.chat, messages)

if response.extra_args.get("model_name") and response.extra_args.get("promptTokens") and response.extra_args.get("completionTokens"):
chat_model._record_token_metrics(response.extra_args["model_name"], response.extra_args["promptTokens"], response.extra_args["completionTokens"])
if (
response.extra_args.get("model_name")
and response.extra_args.get("promptTokens")
and response.extra_args.get("completionTokens")
):
chat_model._record_token_metrics(
response.extra_args["model_name"],
response.extra_args["promptTokens"],
response.extra_args["completionTokens"],
)
if output_schema is not None and len(response.tool_calls) == 0:
response = _generate_structured_output(response, output_schema)
break
Expand Down
9 changes: 6 additions & 3 deletions python/flink_agents/plan/actions/context_retrieval_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@
from flink_agents.api.vector_stores.java_vector_store import JavaVectorStore
from flink_agents.api.vector_stores.vector_store import VectorStoreQuery
from flink_agents.plan.actions.action import Action
from flink_agents.plan.actions.utils import support_async
from flink_agents.plan.function import PythonFunction

_logger = logging.getLogger(__name__)


async def process_context_retrieval_request(event: Event, ctx: RunnerContext) -> None:
"""Built-in action for processing context retrieval requests."""
if isinstance(event, ContextRetrievalRequestEvent):
Expand All @@ -40,9 +42,10 @@ async def process_context_retrieval_request(event: Event, ctx: RunnerContext) ->
query = VectorStoreQuery(query_text=event.query, limit=event.max_results)

rag_async = ctx.config.get(AgentExecutionOptions.RAG_ASYNC)
# java vector store doesn't support async execution
# see https://github.com/apache/flink-agents/issues/448 for details.
rag_async = rag_async and not isinstance(vector_store, JavaVectorStore)

if isinstance(vector_store, JavaVectorStore) and not support_async():
rag_async = False

if rag_async:
# To avoid https://github.com/alibaba/pemja/issues/88,
# we log a message here.
Expand Down
50 changes: 50 additions & 0 deletions python/flink_agents/plan/actions/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#################################################################################
import functools
import logging
from importlib.metadata import version
from typing import List, Tuple

from packaging import version as pkg_version

# Flink version ranges, as (lowest, highest) pairs, for which async execution of
# cross-language resources is not supported. Both bounds are inclusive: a
# version v matches a pair when lowest <= v <= highest.
UNSUPPORTED_RANGES: List[Tuple[str, str]] = [
("1.0.0", "1.20.3"),
("2.0.0", "2.0.1"),
("2.1.0", "2.1.1"),
("2.2.0", "2.2.0"),
]


@functools.lru_cache(maxsize=1)
def support_async() -> bool:
    """Check whether the current Flink version supports the async execution for
    cross-language resource.

    The async execution for java resource is supported only on flink with
    the pemja 0.6.2 dependency. See https://github.com/apache/flink-agents/pull/571
    for details.

    The result is computed once and cached. If the installed ``apache-flink``
    version cannot be determined or parsed, async execution is reported as
    unsupported so that callers fall back to sync execution.

    Returns:
        True if async execution can be used, False if the caller should fall
        back to sync execution.
    """
    # Imported locally so this function does not depend on an extra
    # module-level import.
    from importlib.metadata import PackageNotFoundError

    logger = logging.getLogger(__name__)

    try:
        current = pkg_version.parse(version("apache-flink"))
    except (PackageNotFoundError, pkg_version.InvalidVersion) as e:
        logger.warning(
            "Cannot determine the Flink version (%s), will fallback to sync "
            "execution for cross-language resource.",
            e,
        )
        return False

    for min_ver, max_ver in UNSUPPORTED_RANGES:
        if pkg_version.parse(min_ver) <= current <= pkg_version.parse(max_ver):
            # Lazy %-formatting: the message is only rendered when debug
            # logging is enabled.
            logger.debug(
                "Flink %s doesn't support async execution for java resource, will fallback to sync execution.",
                current,
            )
            return False
    return True
1 change: 1 addition & 0 deletions python/flink_agents/plan/agent_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def get_resource(self, name: str, type: ResourceType) -> Resource:
resource = resource_provider.provide(
get_resource=self.get_resource, config=self.config
)
resource.open()
self.__resources[type][name] = resource
return self.__resources[type][name]

Expand Down
Loading
Loading