Skip to content

Commit 7501c9c

Browse files
committed
merged
2 parents 7f0c040 + 5f50822 commit 7501c9c

File tree

6 files changed

+59
-72
lines changed

6 files changed

+59
-72
lines changed

README.md

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,6 @@ Created by Stream, uses [Stream's edge network](https://getstream.io/video/) for
2424

2525
### Sports Coaching
2626

27-
<a href="https://x.com/nash0x7e2/status/1950341779745599769">
28-
<img src="assets/golf_example_tweet.png" alt="Golf Example" style="max-width: 500px; width: 40%">
29-
</a>
30-
3127
This example shows you how to build golf coaching AI with YOLO and OpenAI realtime.
3228
Combining a fast object detection model (like YOLO) with a full realtime AI is useful for many different video AI use cases.
3329
For example: drone fire detection, sports/video game coaching, physical therapy, workout coaching, Just Dance-style games, etc.
@@ -48,7 +44,9 @@ This example shows you how to build golf coaching AI with YOLO and OpenAI realti
4844
Combining a fast object detection model (like YOLO) with a full realtime AI is useful for many different video AI use cases.
4945
For example: drone fire detection, sports/video game coaching, physical therapy, workout coaching, Just Dance-style games, etc.
5046

51-
[![Golf Example](assets/golf_example_tweet.png)](https://x.com/nash0x7e2/status/1950341779745599769)
47+
<a href="https://x.com/nash0x7e2/status/1950341779745599769">
48+
<img src="assets/golf_example_tweet.png" alt="Golf Example" style="max-width: 500px; width: 40%">
49+
</a>
5250

5351
### Cluely style Invisible Assistant (coming soon)
5452

agents-core/vision_agents/core/agents/agents.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@
1111
from opentelemetry import trace
1212
from opentelemetry.trace import Tracer
1313

14+
from getstream.video.rtc.pb.stream.video.sfu.models.models_pb2 import TrackType
1415
from ..edge import sfu_events
1516
from ..edge.events import AudioReceivedEvent, TrackAddedEvent, CallEndedEvent
16-
from ..edge.types import Connection, Participant, PcmData, TrackType, User
17+
from ..edge.types import Connection, Participant, PcmData, User
1718
from ..events.manager import EventManager
1819
from ..llm.events import (
1920
LLMResponseChunkEvent,
@@ -581,7 +582,7 @@ async def _reply_to_audio(
581582
self.logger.debug(f"🎵 Processing audio from {participant}")
582583
await self.stt.process_audio(pcm_data, participant)
583584

584-
async def _process_track(self, track_id: str, track_type: str, participant):
585+
async def _process_track(self, track_id: str, track_type: int, participant):
585586
# TODO: handle CancelledError
586587
# we only process video tracks
587588
if track_type != TrackType.TRACK_TYPE_VIDEO:

agents-core/vision_agents/core/edge/events.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class TrackAddedEvent(PluginBaseEvent):
1818
"""Event emitted when a track is added to the call."""
1919
type: str = field(default='plugin.edge.track_added', init=False)
2020
track_id: Optional[str] = None
21-
track_type: Optional[str] = None
21+
track_type: Optional[int] = None
2222
user: Optional[Any] = None
2323

2424

@@ -27,7 +27,7 @@ class TrackRemovedEvent(PluginBaseEvent):
2727
"""Event emitted when a track is removed from the call."""
2828
type: str = field(default='plugin.edge.track_removed', init=False)
2929
track_id: Optional[str] = None
30-
track_type: Optional[str] = None
30+
track_type: Optional[int] = None
3131
user: Optional[Any] = None
3232

3333

agents-core/vision_agents/core/edge/types.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#from __future__ import annotations
22
from dataclasses import dataclass
3-
from enum import StrEnum
43
from typing import Any, Optional, NamedTuple
54
import logging
65

@@ -25,20 +24,6 @@ class Participant:
2524
user_id: str
2625

2726

28-
class TrackType(StrEnum):
29-
TRACK_TYPE_UNSPECIFIED = "unspecified"
30-
TRACK_TYPE_AUDIO = "audio"
31-
TRACK_TYPE_VIDEO = "video"
32-
TRACK_TYPE_SCREEN_SHARE = "screen_share" # TODO: Verify its correct
33-
TRACK_TYPE_SCREEN_SHARE_AUDIO = "screen_share_audio"
34-
35-
TRACK_TYPE_UNSPECIFIED = TrackType.TRACK_TYPE_UNSPECIFIED
36-
TRACK_TYPE_AUDIO = TrackType.TRACK_TYPE_AUDIO
37-
TRACK_TYPE_VIDEO = TrackType.TRACK_TYPE_VIDEO
38-
TRACK_TYPE_SCREEN_SHARE = TrackType.TRACK_TYPE_SCREEN_SHARE
39-
TRACK_TYPE_SCREEN_SHARE_AUDIO = TrackType.TRACK_TYPE_SCREEN_SHARE_AUDIO
40-
41-
4227
class Connection(AsyncIOEventEmitter):
4328
"""
4429
To standardize we need to have a method to close
@@ -195,4 +180,4 @@ def resample(self, target_sample_rate: int) -> "PcmData":
195180
)
196181
else:
197182
# If resampling failed, return original data
198-
return self
183+
return self

agents-core/vision_agents/core/events/manager.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,10 +180,8 @@ def register(self, event_class, ignore_not_compatible=False):
180180
ValueError: If event_class doesn't meet requirements and ignore_not_compatible is False
181181
"""
182182
if event_class.__name__.endswith('Event') and hasattr(event_class, 'type'):
183-
#if event_class.type in self._events:
184-
# raise KeyError(f"{event_class.type} is already registered.")
185183
self._events[event_class.type] = event_class
186-
logger.info(f"Registered new event {event_class} - {event_class.type}")
184+
logger.debug(f"Registered new event {event_class} - {event_class.type}")
187185
elif event_class.__name__.endswith('BaseEvent'):
188186
return
189187
elif not ignore_not_compatible:

plugins/getstream/vision_agents/plugins/getstream/stream_edge_transport.py

Lines changed: 49 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import webbrowser
55
from typing import Optional, TYPE_CHECKING
66
from urllib.parse import urlencode
7-
from uuid import uuid4
87

98
import aiortc
109
from getstream import AsyncStream
@@ -57,13 +56,14 @@ def __init__(self, **kwargs):
5756
self.channel: Optional[Channel] = None
5857
self.conversation: Optional[StreamConversation] = None
5958
self.channel_type = "videocall"
59+
self.agent_user_id: str | None = None
6060
# Track mapping: (user_id, session_id, track_type_int) -> {"track_id": str, "published": bool}
6161
# track_type_int is from TrackType enum (e.g., TrackType.TRACK_TYPE_AUDIO)
6262
self._track_map: dict = {}
6363
# Temporary storage for tracks before SFU confirms their type
6464
# track_id -> (user_id, session_id, webrtc_type_string)
6565
self._pending_tracks: dict = {}
66-
66+
6767
# Register event handlers
6868
self.events.subscribe(self._on_track_published)
6969
self.events.subscribe(self._on_track_removed)
@@ -81,30 +81,36 @@ def _get_webrtc_kind(self, track_type_int: int) -> str:
8181

8282
async def _on_track_published(self, event: sfu_events.TrackPublishedEvent):
8383
"""Handle track published events from SFU - spawn TrackAddedEvent with correct type."""
84-
if not event.participant or not event.payload:
84+
if not event.payload:
8585
return
86-
87-
user_id = event.user_id
88-
session_id = event.payload.session_id
86+
87+
if event.participant.user_id:
88+
session_id = event.participant.session_id
89+
user_id = event.participant.user_id
90+
else:
91+
user_id = event.payload.user_id
92+
session_id = event.payload.session_id
93+
94+
if user_id == self.agent_user_id:
95+
return
96+
8997
track_type_int = event.payload.type # TrackType enum int from SFU
90-
track_type_name = TrackType.Name(track_type_int)
91-
# Get expected WebRTC kind for this track type
9298
expected_kind = self._get_webrtc_kind(track_type_int)
9399
track_key = (user_id, session_id, track_type_int)
94100

95101
# First check if track already exists in map (e.g., from previous unpublish/republish)
96102
if track_key in self._track_map:
97103
self._track_map[track_key]["published"] = True
98-
self.logger.debug(f"Track marked as published (already existed): {track_key}")
104+
self.logger.info(f"Track marked as published (already existed): {track_key}")
99105
return
100106

101107
# Wait for pending track to be populated (with 10 second timeout)
102108
# SFU might send TrackPublishedEvent before WebRTC processes track_added
103109
track_id = None
104110
timeout = 10.0
105-
poll_interval = 0.01 # 10ms
111+
poll_interval = 0.01
106112
elapsed = 0.0
107-
113+
108114
while elapsed < timeout:
109115
# Find pending track for this user/session with matching kind
110116
for tid, (pending_user, pending_session, pending_kind) in list(self._pending_tracks.items()):
@@ -125,19 +131,19 @@ async def _on_track_published(self, event: sfu_events.TrackPublishedEvent):
125131
if track_id:
126132
# Store with correct type from SFU
127133
self._track_map[track_key] = {"track_id": track_id, "published": True}
128-
self.logger.info(f"Trackmap published: {track_type_name} from {user_id}, track_id: {track_id} (waited {elapsed:.2f}s)")
134+
self.logger.info(f"Trackmap published: {track_type_int} from {user_id}, track_id: {track_id} (waited {elapsed:.2f}s)")
129135

130136
# NOW spawn TrackAddedEvent with correct type
131137
self.events.send(events.TrackAddedEvent(
132138
plugin_name="getstream",
133139
track_id=track_id,
134-
track_type=track_type_name,
140+
track_type=track_type_int,
135141
user=event.participant,
136142
user_metadata=event.participant
137143
))
138144
else:
139145
raise TimeoutError(
140-
f"Timeout waiting for pending track: {track_type_name} ({expected_kind}) from user {user_id}, "
146+
f"Timeout waiting for pending track: {track_type_int} ({expected_kind}) from user {user_id}, "
141147
f"session {session_id}. Waited {timeout}s but WebRTC track_added with matching kind was never received."
142148
f"Pending tracks: {self._pending_tracks}\n"
143149
f"Key: {track_key}\n"
@@ -146,39 +152,41 @@ async def _on_track_published(self, event: sfu_events.TrackPublishedEvent):
146152

147153
async def _on_track_removed(self, event: sfu_events.ParticipantLeftEvent | sfu_events.TrackUnpublishedEvent):
148154
"""Handle track unpublished and participant left events."""
149-
if not event.participant:
155+
if not event.payload: # NOTE: mypy typecheck
150156
return
151-
152-
# Extract fields based on event type
157+
153158
participant = event.participant
154-
user_id = participant.user_id
155-
session_id = participant.session_id
156-
159+
if participant and participant.user_id:
160+
user_id = participant.user_id
161+
session_id = participant.session_id
162+
else:
163+
user_id = event.payload.user_id
164+
session_id = event.payload.session_id
165+
157166
# Determine which tracks to remove
158167
if hasattr(event.payload, 'type') and event.payload is not None:
159168
# TrackUnpublishedEvent - single track
160169
tracks_to_remove = [event.payload.type]
161170
event_desc = "Track unpublished"
162171
else:
163172
# ParticipantLeftEvent - all published tracks
164-
tracks_to_remove = participant.published_tracks or []
173+
tracks_to_remove = event.participant.published_tracks or []
165174
event_desc = "Participant left"
166175

167176
track_names = [TrackType.Name(t) for t in tracks_to_remove]
168177
self.logger.info(f"{event_desc}: {user_id}, tracks: {track_names}")
169178

170179
# Mark each track as unpublished and send TrackRemovedEvent
171180
for track_type_int in tracks_to_remove:
172-
track_type_name = TrackType.Name(track_type_int)
173181
track_key = (user_id, session_id, track_type_int)
174182
track_info = self._track_map.get(track_key)
175-
183+
176184
if track_info:
177185
track_id = track_info["track_id"]
178186
self.events.send(events.TrackRemovedEvent(
179187
plugin_name="getstream",
180188
track_id=track_id,
181-
track_type=track_type_name,
189+
track_type=track_type_int,
182190
user=participant,
183191
user_metadata=participant
184192
))
@@ -200,6 +208,7 @@ async def create_conversation(self, call: Call, user, instructions):
200208
return self.conversation
201209

202210
async def create_user(self, user: User):
211+
self.agent_user_id = user.id
203212
return await self.client.create_user(name=user.name, id=user.id)
204213

205214
async def join(self, agent: "Agent", call: Call) -> StreamConnection:
@@ -222,18 +231,19 @@ async def join(self, agent: "Agent", call: Call) -> StreamConnection:
222231
default=self._get_subscription_config()
223232
)
224233

225-
try:
226-
# Open RTC connection and keep it alive for the duration of the returned context manager
227-
connection = await rtc.join(
228-
call, agent.agent_user.id, subscription_config=subscription_config
229-
)
230-
await connection.__aenter__() # TODO: weird API? there should be a manual version
231-
except Exception:
232-
raise
234+
# Open RTC connection and keep it alive for the duration of the returned context manager
235+
connection = await rtc.join(
236+
call, agent.agent_user.id, subscription_config=subscription_config
237+
)
233238

234-
self._connection = connection
239+
@connection.on("track_added")
240+
async def on_track(track_id, track_type, user):
241+
# Store track in pending map - wait for SFU to confirm type before spawning TrackAddedEvent
242+
self._pending_tracks[track_id] = (user.user_id, user.session_id, track_type)
243+
self.logger.info(f"Track received from WebRTC (pending SFU confirmation): {track_id}, type: {track_type}, user: {user.user_id}")
235244

236-
@self._connection.on("audio")
245+
self.events.silent(events.AudioReceivedEvent)
246+
@connection.on("audio")
237247
async def on_audio_received(pcm: PcmData, participant: Participant):
238248
self.events.send(events.AudioReceivedEvent(
239249
plugin_name="getstream",
@@ -242,17 +252,10 @@ async def on_audio_received(pcm: PcmData, participant: Participant):
242252
user_metadata=participant
243253
))
244254

245-
self.events.silent(events.AudioReceivedEvent)
246-
247-
@self._connection.on("track_added")
248-
async def on_track(track_id, track_type, user):
249-
# Store track in pending map - wait for SFU to confirm type before spawning TrackAddedEvent
250-
self._pending_tracks[track_id] = (user.user_id, user.session_id, track_type)
251-
self.logger.info(f"Track received from WebRTC (pending SFU confirmation): {track_id}, type: {track_type}, user: {user.user_id}")
252-
255+
await connection.__aenter__() # TODO: weird API? there should be a manual version
256+
self._connection = connection
253257

254258
standardize_connection = StreamConnection(connection)
255-
256259
return standardize_connection
257260

258261
def create_audio_track(self, framerate: int = 48000, stereo: bool = True):
@@ -293,9 +296,11 @@ async def open_demo(self, call: Call) -> str:
293296
client = call.client.stream
294297

295298
# Create a human user for testing
296-
human_id = f"user-{uuid4()}"
299+
human_id = "user-demo-agent"
297300
name = "Human User"
298301

302+
# Create the user in the GetStream system
303+
await client.create_user(name=name, id=human_id)
299304
# Create user token for browser access
300305
token = client.create_token(human_id, expiration=3600)
301306

0 commit comments

Comments
 (0)