Skip to content

Commit 723fe10

Browse files
committed
support restarting bidi
1 parent e905ef0 commit 723fe10

File tree

6 files changed

+204
-14
lines changed

6 files changed

+204
-14
lines changed

src/app/components/chat-panel/chat-panel.component.scss

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,7 @@
381381
.video-rec-btn {
382382
background-color: var(--chat-card-background-color);
383383
&.recording {
384-
background-color: var(--chat-panel-eval-fail-color);
384+
background-color: var(--chat-panel-eval-fail-color) !important;
385+
color: white !important;
385386
}
386387
}

src/app/components/chat/chat.component.spec.ts

Lines changed: 85 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -563,15 +563,93 @@ describe('ChatComponent', () => {
563563

564564
describe('when bidi streaming is restarted', () => {
565565
beforeEach(() => {
566-
component.sessionHasUsedBidi.add(component.sessionId);
566+
component.startAudioRecording();
567+
component.stopAudioRecording();
567568
component.startAudioRecording();
568569
});
569-
it('should show snackbar', () => {
570-
expect(mockSnackBar.open)
571-
.toHaveBeenCalledWith(
572-
'Restarting bidirectional streaming is not currently supported. Please refresh the page or start a new session.',
573-
OK_BUTTON_TEXT,
574-
);
570+
it('should allow restart without error', () => {
571+
expect(component.isAudioRecording).toBe(true);
572+
expect(mockStreamChatService.startAudioChat).toHaveBeenCalledTimes(2);
573+
});
574+
});
575+
576+
describe('when audio recording is stopped and restarted', () => {
577+
beforeEach(() => {
578+
component.startAudioRecording();
579+
expect(component.sessionHasUsedBidi.has(component.sessionId)).toBe(true);
580+
component.stopAudioRecording();
581+
});
582+
583+
it('should remove session from sessionHasUsedBidi set', () => {
584+
expect(component.sessionHasUsedBidi.has(component.sessionId)).toBe(false);
585+
});
586+
587+
it('should allow restarting audio recording', () => {
588+
component.startAudioRecording();
589+
expect(mockSnackBar.open).not.toHaveBeenCalled();
590+
expect(component.isAudioRecording).toBe(true);
591+
});
592+
});
593+
594+
describe('when video recording is stopped and restarted', () => {
595+
beforeEach(() => {
596+
component.startVideoRecording();
597+
expect(component.sessionHasUsedBidi.has(component.sessionId)).toBe(true);
598+
component.stopVideoRecording();
599+
});
600+
601+
it('should remove session from sessionHasUsedBidi set', () => {
602+
expect(component.sessionHasUsedBidi.has(component.sessionId)).toBe(false);
603+
});
604+
605+
it('should allow restarting video recording', () => {
606+
component.startVideoRecording();
607+
expect(mockSnackBar.open).not.toHaveBeenCalled();
608+
expect(component.isVideoRecording).toBe(true);
609+
});
610+
});
611+
612+
describe('when trying to start concurrent bidi streams', () => {
613+
it('should prevent starting audio while already recording', () => {
614+
component.startAudioRecording();
615+
expect(component.isAudioRecording).toBe(true);
616+
617+
component.startAudioRecording();
618+
619+
expect(mockSnackBar.open).toHaveBeenCalledWith(
620+
'Another streaming request is already in progress. Please stop it before starting a new one.',
621+
'OK'
622+
);
623+
expect(mockStreamChatService.startAudioChat).toHaveBeenCalledTimes(1);
624+
});
625+
626+
it('should prevent starting video while already recording', () => {
627+
component.startVideoRecording();
628+
expect(component.isVideoRecording).toBe(true);
629+
630+
component.startVideoRecording();
631+
632+
expect(mockSnackBar.open).toHaveBeenCalledWith(
633+
'Another streaming request is already in progress. Please stop it before starting a new one.',
634+
'OK'
635+
);
636+
expect(mockStreamChatService.startVideoChat).toHaveBeenCalledTimes(1);
637+
});
638+
});
639+
640+
describe('when stopping video recording without videoContainer', () => {
641+
it('should still cleanup sessionHasUsedBidi', () => {
642+
component.startVideoRecording();
643+
expect(component.sessionHasUsedBidi.has(component.sessionId)).toBe(true);
644+
645+
spyOn(component, 'chatPanel').and.returnValue({
646+
videoContainer: undefined
647+
} as any);
648+
649+
component.stopVideoRecording();
650+
651+
expect(component.sessionHasUsedBidi.has(component.sessionId)).toBe(false);
652+
expect(component.isVideoRecording).toBe(false);
575653
});
576654
});
577655
});

src/app/components/chat/chat.component.ts

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ class CustomPaginatorIntl extends MatPaginatorIntl {
107107
}
108108

109109
const BIDI_STREAMING_RESTART_WARNING =
110-
'Restarting bidirectional streaming is not currently supported. Please refresh the page or start a new session.';
110+
'Another streaming request is already in progress. Please stop it before starting a new one.';
111111

112112
@Component({
113113
selector: 'app-chat',
@@ -182,8 +182,7 @@ export class ChatComponent implements OnInit, AfterViewInit, OnDestroy {
182182
private readonly isModelThinkingSubject = new BehaviorSubject(false);
183183
private readonly scrollInterruptedSubject = new BehaviorSubject(false);
184184

185-
// TODO: Remove this once backend supports restarting bidi streaming.
186-
sessionHasUsedBidi = new Set<string>();
185+
public sessionHasUsedBidi = new Set<string>();
187186

188187
eventData = new Map<string, any>();
189188
traceData: any[] = [];
@@ -987,11 +986,14 @@ export class ChatComponent implements OnInit, AfterViewInit, OnDestroy {
987986
{role: 'bot', text: 'Speaking...'},
988987
]);
989988
this.sessionHasUsedBidi.add(this.sessionId);
989+
this.changeDetectorRef.detectChanges();
990990
}
991991

992992
stopAudioRecording() {
993993
this.streamChatService.stopAudioChat();
994994
this.isAudioRecording = false;
995+
this.sessionHasUsedBidi.delete(this.sessionId);
996+
this.changeDetectorRef.detectChanges();
995997
}
996998

997999
toggleVideoRecording() {
@@ -1018,15 +1020,17 @@ export class ChatComponent implements OnInit, AfterViewInit, OnDestroy {
10181020
this.messages.update(
10191021
messages => [...messages, {role: 'user', text: 'Speaking...'}]);
10201022
this.sessionHasUsedBidi.add(this.sessionId);
1023+
this.changeDetectorRef.detectChanges();
10211024
}
10221025

10231026
stopVideoRecording() {
10241027
const videoContainer = this.chatPanel()?.videoContainer;
1025-
if (!videoContainer) {
1026-
return;
1028+
if (videoContainer) {
1029+
this.streamChatService.stopVideoChat(videoContainer);
10271030
}
1028-
this.streamChatService.stopVideoChat(videoContainer);
10291031
this.isVideoRecording = false;
1032+
this.sessionHasUsedBidi.delete(this.sessionId);
1033+
this.changeDetectorRef.detectChanges();
10301034
}
10311035

10321036
private getAsyncFunctionsFromParts(pendingIds: any[], parts: any[], invocationId: string) {

src/app/core/services/stream-chat.service.spec.ts

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,4 +260,64 @@ describe('StreamChatService', () => {
260260
expect(mockWebSocketService.sendMessage).toHaveBeenCalledTimes(2);
261261
}));
262262
});
263+
264+
describe('restart audio chat', () => {
265+
it('should allow restarting audio chat after stopping', async () => {
266+
mockAudioService.getCombinedAudioBuffer.and.returnValue(
267+
Uint8Array.of());
268+
269+
await service.startAudioChat({
270+
appName: 'fake-app-name',
271+
userId: 'fake-user-id',
272+
sessionId: 'fake-session-id'
273+
});
274+
expect(mockWebSocketService.connect).toHaveBeenCalledTimes(1);
275+
expect(mockAudioService.startRecording).toHaveBeenCalledTimes(1);
276+
277+
service.stopAudioChat();
278+
expect(mockAudioService.stopRecording).toHaveBeenCalledTimes(1);
279+
expect(mockWebSocketService.closeConnection).toHaveBeenCalledTimes(1);
280+
281+
await service.startAudioChat({
282+
appName: 'fake-app-name',
283+
userId: 'fake-user-id',
284+
sessionId: 'fake-session-id'
285+
});
286+
expect(mockWebSocketService.connect).toHaveBeenCalledTimes(2);
287+
expect(mockAudioService.startRecording).toHaveBeenCalledTimes(2);
288+
});
289+
});
290+
291+
describe('restart video chat', () => {
292+
it('should allow restarting video chat after stopping', async () => {
293+
mockAudioService.getCombinedAudioBuffer.and.returnValue(
294+
Uint8Array.of());
295+
mockVideoService.getCapturedFrame.and.resolveTo(Uint8Array.of());
296+
297+
await service.startVideoChat({
298+
appName: 'fake-app-name',
299+
userId: 'fake-user-id',
300+
sessionId: 'fake-session-id',
301+
videoContainer
302+
});
303+
expect(mockWebSocketService.connect).toHaveBeenCalledTimes(1);
304+
expect(mockAudioService.startRecording).toHaveBeenCalledTimes(1);
305+
expect(mockVideoService.startRecording).toHaveBeenCalledTimes(1);
306+
307+
service.stopVideoChat(videoContainer);
308+
expect(mockAudioService.stopRecording).toHaveBeenCalledTimes(1);
309+
expect(mockVideoService.stopRecording).toHaveBeenCalledTimes(1);
310+
expect(mockWebSocketService.closeConnection).toHaveBeenCalledTimes(1);
311+
312+
await service.startVideoChat({
313+
appName: 'fake-app-name',
314+
userId: 'fake-user-id',
315+
sessionId: 'fake-session-id',
316+
videoContainer
317+
});
318+
expect(mockWebSocketService.connect).toHaveBeenCalledTimes(2);
319+
expect(mockAudioService.startRecording).toHaveBeenCalledTimes(2);
320+
expect(mockVideoService.startRecording).toHaveBeenCalledTimes(2);
321+
});
322+
});
263323
});

src/app/core/services/websocket.service.spec.ts

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ describe('WebSocketService', () => {
3737
buffer: null,
3838
}),
3939
currentTime: 0,
40+
state: 'running',
41+
close: jasmine.createSpy('close').and.callFake(function(this: any) {
42+
this.state = 'closed';
43+
}),
4044
};
4145
spyOn(window, 'AudioContext').and.returnValue(mockAudioContext);
4246

@@ -60,4 +64,37 @@ describe('WebSocketService', () => {
6064
expect(service.urlSafeBase64ToBase64('abcd')).toEqual('abcd');
6165
});
6266
});
67+
68+
describe('connection restart', () => {
69+
it('should close audio context when closing connection', () => {
70+
service.closeConnection();
71+
expect(mockAudioContext.close).toHaveBeenCalled();
72+
expect(mockAudioContext.state).toBe('closed');
73+
});
74+
75+
it('should create new audio context when reconnecting after close', () => {
76+
service.closeConnection();
77+
expect(mockAudioContext.state).toBe('closed');
78+
79+
const audioContextCallCount = (window.AudioContext as any).calls.count();
80+
service.connect('ws://test');
81+
expect((window.AudioContext as any).calls.count()).toBe(audioContextCallCount + 1);
82+
});
83+
84+
it('should reset audio buffer when reconnecting', () => {
85+
service.connect('ws://test1');
86+
87+
(service as any).audioBuffer = [new Uint8Array([1, 2, 3])];
88+
89+
service.connect('ws://test2');
90+
expect((service as any).audioBuffer).toEqual([]);
91+
});
92+
93+
it('should reset lastAudioTime when reconnecting', () => {
94+
service.connect('ws://test1');
95+
(service as any).lastAudioTime = 5.5;
96+
service.connect('ws://test2');
97+
expect((service as any).lastAudioTime).toBe(0);
98+
});
99+
});
63100
});

src/app/core/services/websocket.service.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,13 @@ export class WebSocketService {
3939
constructor() {}
4040

4141
connect(serverUrl: string) {
42+
this.closeConnection();
43+
if (this.audioContext.state === 'closed') {
44+
this.audioContext = new AudioContext({sampleRate: 22000});
45+
}
46+
this.lastAudioTime = 0;
47+
this.audioBuffer = [];
48+
4249
this.socket$ = new WebSocketSubject({
4350
url: serverUrl,
4451
serializer: (msg) => JSON.stringify(msg),
@@ -76,6 +83,9 @@ export class WebSocketService {
7683
if (this.socket$) {
7784
this.socket$.complete();
7885
}
86+
if (this.audioContext && this.audioContext.state !== 'closed') {
87+
this.audioContext.close();
88+
}
7989
}
8090

8191
getMessages() {

0 commit comments

Comments
 (0)