-
Notifications
You must be signed in to change notification settings - Fork 212
Expand file tree
/
Copy pathchat.py
More file actions
198 lines (166 loc) · 6.86 KB
/
chat.py
File metadata and controls
198 lines (166 loc) · 6.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A simple turn-by-turn CLI chat.
Reads a line from the user, sends it to the model, streams the response. While
the chat itself is text-only, one can type a URL and it will be replaced with
its content. This allows sending images and other modalities to the model.
"""
import asyncio
import contextlib
import os
from typing import AsyncIterable, Sequence
from absl import app
from absl import flags
from genai_processors import content_api
from genai_processors import processor
from genai_processors.core import function_calling
from genai_processors.core import pdf
from genai_processors.core import realtime
from genai_processors.core import text
from genai_processors.dev import trace_file
from genai_processors.examples import mcp as mcp_examples
from genai_processors.examples import models
from google.genai import types as genai_types
import httpx
# Selects which MCP server (if any) provides tools to the model; the accepted
# forms are resolved by _get_mcp_session().
_MCP_SERVER = flags.DEFINE_string(
    name='mcp_server',
    default=None,
    help=(
        'Address of the MCP server to use. Use "demo" to use a demo server. Use an'
        ' https address starting with "https://" to use a remote server. Use'
        ' "local:<command>" to use a local server, e.g. local:npx -y'
        ' @modelcontextprotocol/server-filesystem .'
    ),
)

# Environment variable holding the API key sent to remote (https://) MCP
# servers. An empty value disables sending a key.
_API_KEY_ENV = flags.DEFINE_string(
    name='api_key_env',
    default='API_KEY',
    help=(
        'Name of the environment variable containing the API key for remote MCP'
        ' servers. Defined by `export API_KEY=...`.'
    ),
)

# HTTP header name under which the API key is sent to remote MCP servers.
_API_KEY_HEADER = flags.DEFINE_string(
    name='api_key_header',
    default='X-Goog-Api-Key',
    help='Name of the header containing the API key for remote MCP servers.',
)

# Optional directory for trace files; tracing is off when unset.
_TRACE_DIR = flags.DEFINE_string(
    name='trace_dir',
    default=None,
    help='If set, enable tracing and write traces to this directory.',
)
# System prompt for the turn-based model: a one-element list holding the
# persona instructions (one sentence per source line below).
SYSTEM_INSTRUCTIONS = [
    'You are an agent that interacts with the user in a conversation.'
    ' Make the conversation lively and interesting for the user.'
    ' You can make jokes, explain interesting facts, predict what could'
    ' happen, etc.'
    ' Respond to the user in a few sentences maximum: keep it short and'
    ' engaging.'
]

# Prompt printed before each user turn (a blank line, then "> ").
USER_PROMPT = '\n> '
def _get_mcp_session():
  """Returns an async context manager for the MCP session chosen by flags.

  `--mcp_server` selects the server:
    * "demo": `mcp_examples.get_demo_mcp_session()`.
    * "https://...": `mcp_examples.get_remote_mcp_session()`. When
      `--api_key_env` is non-empty, the key is read from that environment
      variable and passed as a `{--api_key_header: key}` dict; when it is
      empty, `None` is passed instead.
    * "local:<command>": `mcp_examples.get_local_mcp_session(<command>)`.

  Returns:
    The object returned by the matching `mcp_examples` helper; `run_chat`
    enters it with `AsyncExitStack.enter_async_context`.

  Raises:
    ValueError: If `--mcp_server` is empty or has an unsupported form, if a
      "local:" address carries no command, or if the API key environment
      variable is unset or empty.
  """
  remote_prefix = 'https://'
  local_prefix = 'local:'

  server_val = _MCP_SERVER.value
  if not server_val:
    raise ValueError('MCP server address must be provided')

  if server_val == 'demo':
    return mcp_examples.get_demo_mcp_session()

  if server_val.startswith(remote_prefix):
    api_key_header = None
    if _API_KEY_ENV.value:
      api_key = os.environ.get(_API_KEY_ENV.value)
      if not api_key:
        # `--api_key_env` defaults to a non-empty name, so servers that need
        # no key would otherwise fail here with no hint on how to opt out.
        raise ValueError(
            f'API key not found in environment variable: {_API_KEY_ENV.value}'
            '. Export it, or pass an empty --api_key_env if the server does'
            ' not need a key.'
        )
      api_key_header = {_API_KEY_HEADER.value: api_key}
    return mcp_examples.get_remote_mcp_session(server_val, api_key_header)

  if server_val.startswith(local_prefix):
    # Slice by the prefix length instead of a hard-coded 6 so the two can
    # never drift apart.
    command = server_val[len(local_prefix):]
    if not command.strip():
      raise ValueError(
          f'Missing command in MCP server address: {server_val!r}. Expected'
          ' local:<command>, e.g. local:npx -y'
          ' @modelcontextprotocol/server-filesystem .'
      )
    return mcp_examples.get_local_mcp_session(command)

  raise ValueError(
      f'Unsupported MCP server: {server_val}. Use one of the'
      ' following:\n- demo\n- https://<address>\n- local:<command>'
  )
class _FetchUrl(processor.PartProcessor):
  """A PartProcessor that fetches the content for a given URL.

  DO NOT USE OUTSIDE OF THIS EXAMPLE: NOT PRODUCTION QUALITY.

  This is an oversimplified version of FetchUrl to allow testing multimodal
  content handling (images, PDFs). It will be replaced with a proper version
  from core.web once it is available.
  """

  def match(self, part: content_api.ProcessorPart) -> bool:
    """Returns True if the part's mimetype marks a `text.FetchRequest`.

    This processor matches on `text.FetchRequest` parts (as emitted upstream
    in this file's pipeline by `text.UrlExtractor`).
    """
    return content_api.is_dataclass(part.mimetype, text.FetchRequest)

  @processor.yield_exceptions_as_parts
  async def call(
      self, part: content_api.ProcessorPart
  ) -> AsyncIterable[content_api.ProcessorPartTypes]:
    """Downloads the URL named by the part's `text.FetchRequest`.

    Redirects are followed. A non-success HTTP status raises from
    `raise_for_status()`; the `yield_exceptions_as_parts` decorator is
    presumably what turns such errors into output parts instead of ending
    the stream — confirm against `processor` docs.

    Args:
      part: A part for which `match` returned True.

    Yields:
      A single part holding the raw response body (bytes), with the
      response's Content-Type header used as its mimetype.
    """
    webrequest = part.get_dataclass(text.FetchRequest)
    async with httpx.AsyncClient(follow_redirects=True) as client:
      response = await client.get(webrequest.url)
      response.raise_for_status()
      # NOTE(review): `headers.get` returns None when the server sends no
      # Content-Type; confirm how ProcessorPart handles bytes with no
      # mimetype before relying on this for arbitrary URLs.
      yield content_api.ProcessorPart(
          response.content, mimetype=response.headers.get('content-type')
      )
def _print_welcome() -> None:
  """Prints the usage banner shown once at chat start-up."""
  print('Welcome to the GenAI Processor Chat! Ask me anything.')
  print('You can also ask questions about images or PDFs by providing a URL.')
  print('For example:')
  print(
      ' - Describe the main points from the '
      ' https://storage.googleapis.com/gweb-developer-goog-blog-assets/images/gemini_2-5_ga_family_1-1__dark.original.png'
      ' diagram.'
  )
  print(' - Summarize https://arxiv.org/pdf/2312.11805')
  print('Press Ctrl + D to exit.')


async def run_chat() -> None:
  """Runs a simple turn-by-turn chat in the terminal.

  User turns are read from `text.terminal_input()`. URLs in a turn are
  extracted and fetched (`text.UrlExtractor` + `_FetchUrl`), the result passes
  through `pdf.PDFExtract`, and the model's replies are written by
  `text.terminal_output`.

  Tools: when `--mcp_server` is set, the MCP session is declared to the model
  and also used by `function_calling.FunctionCalling` to execute the calls
  the model makes; otherwise the model is given the Google Search tool and
  there are no local functions. When `--trace_dir` is set, a file trace named
  'chat' is recorded. Everything runs inside one AsyncExitStack, so the MCP
  session and trace stay open for the whole chat.
  """
  # The easiest way to track context between turns is to use Gemini Live API
  # genai_processors.core.live_model.LiveProcessor. We then can send user
  # turns in and it will yield model responses.
  #
  # Here we take a more flexible but slightly more complex approach and use
  # genai_processors.core.realtime.LiveProcessor - a client-side version of the
  # Live API. It wraps any turn-based model and provides a bidirectional
  # interface. It also supports customizable context compression.
  # See models.py for the list of supported models and flags used to select one.
  async with contextlib.AsyncExitStack() as stack:
    # Truthiness (not `is not None`) to agree with _get_mcp_session(), which
    # rejects an empty address: `--mcp_server=` now means "no MCP server"
    # instead of crashing.
    if _MCP_SERVER.value:
      mcp_session = await stack.enter_async_context(_get_mcp_session())
      # The same session is declared to the model (tools) and used to execute
      # the model's function calls (fns).
      tools = [mcp_session]
      fns = [mcp_session]
    else:
      tools = [genai_types.Tool(google_search=genai_types.GoogleSearch())]
      fns = []

    if _TRACE_DIR.value:
      await stack.enter_async_context(
          trace_file.SyncFileTrace(trace_dir=_TRACE_DIR.value, name='chat')
      )

    # `tools` must reach the model: previously it was computed and dropped,
    # so the model never saw the MCP tools or Google Search and FunctionCalling
    # had nothing to execute.
    # NOTE(review): FunctionCalling executes calls client-side; if the selected
    # model backend also performs automatic function calling for MCP sessions,
    # it may need to be disabled via turn_based_model's options — confirm.
    model = models.turn_based_model(
        system_instruction=SYSTEM_INSTRUCTIONS,
        tools=tools,
    )
    model = function_calling.FunctionCalling(
        model=realtime.LiveModelProcessor(model),
        fns=fns,
        is_bidi_model=True,
    )

    # Give the agent the ability to download multimodal content.
    chat_agent = text.UrlExtractor() + _FetchUrl() + pdf.PDFExtract() + model

    _print_welcome()
    # Show the prompt for the first turn; terminal_output is also handed
    # USER_PROMPT below (presumably for subsequent turns — confirm).
    print(USER_PROMPT, end='', flush=True)
    await text.terminal_output(
        chat_agent(text.terminal_input()), prompt=USER_PROMPT
    )
def main(argv: Sequence[str]):
  """absl entry point: runs the chat coroutine to completion.

  Args:
    argv: Leftover command-line arguments from `app.run`; ignored.
  """
  del argv  # Unused.
  chat = run_chat()
  asyncio.run(chat)


if __name__ == '__main__':
  app.run(main)