22import re
33import time
44import types
5+ from asyncio import get_event_loop_policy
56from functools import partial
6- from typing import Dict
7+ from typing import TYPE_CHECKING , Dict , Optional
78
89import traitlets
910from dask .distributed import Client as DaskClient
1213from jupyter_ai_magics .utils import get_em_providers , get_lm_providers
1314from jupyter_events import EventLogger
1415from jupyter_server .extension .application import ExtensionApp
15- from jupyter_server .utils import url_path_join
1616from jupyterlab_chat .models import Message
1717from jupyterlab_chat .ychat import YChat
1818from pycrdt import ArrayEvent
2222from .chat_handlers .base import BaseChatHandler
2323from .completions .handlers import DefaultInlineCompletionHandler
2424from .config_manager import ConfigManager
25- from .constants import BOT
2625from .context_providers import BaseCommandContextProvider , FileContextProvider
2726from .handlers import (
2827 ApiKeysHandler ,
3332 SlashCommandsInfoHandler ,
3433)
3534from .history import YChatHistory
35+ from .personas import PersonaManager
36+
37+ if TYPE_CHECKING :
38+ from asyncio import AbstractEventLoop
3639
3740from jupyter_collaboration import ( # type:ignore[import-untyped] # isort:skip
3841 __version__ as jupyter_collaboration_version ,
@@ -244,6 +247,13 @@ def initialize(self):
244247 schema_id = JUPYTER_COLLABORATION_EVENTS_URI , listener = self .connect_chat
245248 )
246249
250+ @property
251+ def event_loop (self ) -> "AbstractEventLoop" :
252+ """
253+ Returns a reference to the asyncio event loop.
254+ """
255+ return get_event_loop_policy ().get_event_loop ()
256+
247257 async def connect_chat (
248258 self , logger : EventLogger , schema_id : str , data : dict
249259 ) -> None :
@@ -264,17 +274,19 @@ async def connect_chat(
264274 if ychat is None :
265275 return
266276
267- # Add the bot user to the chat document awareness.
268- BOT ["avatar_url" ] = url_path_join (
269- self .settings .get ("base_url" , "/" ), "api/ai/static/jupyternaut.svg"
270- )
271- if ychat .awareness is not None :
272- ychat .awareness .set_local_state_field ("user" , BOT )
273-
274277 # initialize chat handlers for new chat
275278 self .chat_handlers_by_room [room_id ] = self ._init_chat_handlers (ychat )
276279
277- callback = partial (self .on_change , room_id )
280+ # initialize persona manager
281+ persona_manager = self ._init_persona_manager (ychat )
282+ if not persona_manager :
283+ self .log .error (
284+ "Jupyter AI was unable to initialize its AI personas. They are not available for use in chat until this error is resolved. "
285+ + "Please verify your configuration and open a new issue on GitHub if this error persists."
286+ )
287+ return
288+
289+ callback = partial (self .on_change , room_id , persona_manager )
278290 ychat .ymessages .observe (callback )
279291
280292 async def get_chat (self , room_id : str ) -> YChat :
@@ -301,21 +313,26 @@ async def get_chat(self, room_id: str) -> YChat:
301313 self .ychats_by_room [room_id ] = document
302314 return document
303315
304- def on_change (self , room_id : str , events : ArrayEvent ) -> None :
316+ def on_change (
317+ self , room_id : str , persona_manager : PersonaManager , events : ArrayEvent
318+ ) -> None :
305319 assert self .serverapp
306320
307321 for change in events .delta : # type:ignore[attr-defined]
308322 if not "insert" in change .keys ():
309323 continue
310- messages = change ["insert" ]
311- for message_dict in messages :
312- message = Message (** message_dict )
313- if message .sender == BOT ["username" ] or message .raw_time :
314- continue
315324
316- self .serverapp .io_loop .asyncio_loop .create_task ( # type:ignore[attr-defined]
317- self .route_human_message (room_id , message )
318- )
325+ # the "if not m['raw_time']" clause is necessary because every new
326+ # message triggers 2 events, one with `raw_time` set to `True` and
327+ # another with `raw_time` set to `False` milliseconds later.
328+ # we should explore fixing this quirk in Jupyter Chat.
329+ #
330+ # Ref: https://github.com/jupyterlab/jupyter-chat/issues/212
331+ new_messages = [
332+ Message (** m ) for m in change ["insert" ] if not m .get ("raw_time" , False )
333+ ]
334+ for new_message in new_messages :
335+ persona_manager .route_message (new_message )
319336
320337 async def route_human_message (self , room_id : str , message : Message ):
321338 """
@@ -400,18 +417,15 @@ def initialize_settings(self):
400417
401418 self .log .info (f"Registered { self .name } server extension" )
402419
403- # get reference to event loop
404- # `asyncio.get_event_loop()` is deprecated in Python 3.11+, in favor of
405- # the more readable `asyncio.get_event_loop_policy().get_event_loop()`.
406- # it's easier to just reference the loop directly.
407- loop = self .serverapp .io_loop .asyncio_loop
408- self .settings ["jai_event_loop" ] = loop
420+ self .settings ["jai_event_loop" ] = self .event_loop
409421
410422 # We cannot instantiate the Dask client directly here because it
411423 # requires the event loop to be running on init. So instead we schedule
412424 # this as a task that is run as soon as the loop starts, and pass
413425 # consumers a Future that resolves to the Dask client when awaited.
414- self .settings ["dask_client_future" ] = loop .create_task (self ._get_dask_client ())
426+ self .settings ["dask_client_future" ] = self .event_loop .create_task (
427+ self ._get_dask_client ()
428+ )
415429
416430 # Create empty context providers dict to be filled later.
417431 # This is created early to use as kwargs for chat handlers.
@@ -456,10 +470,7 @@ async def _stop_extension(self):
456470
457471 def _init_chat_handlers (self , ychat : YChat ) -> Dict [str , BaseChatHandler ]:
458472 """
459- Initializes a set of chat handlers. May accept a YChat instance for
460- collaborative chats.
461-
462- TODO: Make `ychat` required once Jupyter Chat migration is complete.
473+ Initializes a set of chat handlers for a given `YChat` instance.
463474 """
464475 assert self .serverapp
465476
@@ -606,3 +617,32 @@ def _init_context_providers(self):
606617 ** context_providers_kwargs
607618 )
608619 self .log .info (f"Registered context provider `{ context_provider .id } `." )
620+
621+ def _init_persona_manager (self , ychat : YChat ) -> Optional [PersonaManager ]:
622+ """
623+ Initializes a `PersonaManager` instance scoped to a `YChat`.
624+
625+ This method should not raise an exception. Upon encountering an
626+ exception, this method will catch it, log it, and return `None`.
627+ """
628+ persona_manager : Optional [PersonaManager ]
629+
630+ try :
631+ config_manager = self .settings .get ("jai_config_manager" , None )
632+ assert config_manager and isinstance (config_manager , ConfigManager )
633+
634+ persona_manager = PersonaManager (
635+ ychat = ychat ,
636+ config_manager = config_manager ,
637+ event_loop = self .event_loop ,
638+ log = self .log ,
639+ )
640+ except Exception as e :
641+ # TODO: how to stop the extension when this fails
642+ # also why do uncaught exceptions produce an empty error log in Jupyter Server?
643+ self .log .error (
644+ f"Unable to initialize PersonaManager in YChat with ID '{ ychat .get_id ()} ' due to an exception printed below."
645+ )
646+ self .log .exception (e )
647+ finally :
648+ return persona_manager
0 commit comments