Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "sleap_client"
version = "0.0.2a52"
version = "0.0.4a6"
authors = [
{ name="Amick Licup", email="[email protected]" }
]
Expand All @@ -20,8 +20,10 @@ classifiers = [
]
dependencies = [
"aiortc",
"asyncio",
"websockets",
"websockets",
"jsonpickle",
"pyzmq",
"qtpy",
]

[project.urls]
Expand Down
84 changes: 81 additions & 3 deletions sleap_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@
import os

from aiortc import RTCPeerConnection, RTCSessionDescription, RTCDataChannel
from functools import partial
from qtpy import QtCore
from sleap.gui.widgets.monitor import LossViewer
from sleap.gui.widgets.imagedir import QtImageDirectoryWidget
from sleap.gui.learning.configs import ConfigFileInfo
from sleap.nn.config.training_job import TrainingJobConfig
from websockets.client import ClientConnection

# Setup logging.
Expand All @@ -22,6 +28,7 @@
reconnecting = False
reconnect_attempts = 0
output_dir = ""
win = None
Copy link

Copilot AI Jun 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Using a global variable 'win' to manage the LossViewer instance can make the code harder to maintain; consider encapsulating this state within a class or passing it as a parameter to improve modularity.

Copilot uses AI. Check for mistakes.


async def clean_exit(pc: RTCPeerConnection, websocket: ClientConnection):
Expand Down Expand Up @@ -155,7 +162,10 @@ async def run_client(
port_number: str,
file_path: str = None,
CLI: bool = True,
output_dir: str = ""
output_dir: str = "",
config_filename: TrainingJobConfig = None,
cfg_head_name: str = None,
loss_viewer: LossViewer = None # MUST INITIALIZE LOSS VIEWER FOR WINDOW
):
"""Sends initial SDP offer to worker peer and establishes both connection & datachannel to be used by both parties.

Expand All @@ -172,6 +182,7 @@ async def run_client(

# Initialize global variables
global reconnect_attempts
global win
output_dir = output_dir

# Initalize peer connection and data channel.
Expand All @@ -180,6 +191,11 @@ async def run_client(
channel = pc.createDataChannel("my-data-channel")
logging.info("channel(%s) %s" % (channel.label, "created by local party."))

# Initialize LossViewer RTC data channel event handlers .
logging.info("Setting up RTC data channel for LossViewer...")
loss_viewer.set_rtc_channel(channel)
win = loss_viewer

async def keep_ice_alive(channel: RTCDataChannel):
"""Sends periodic keep-alive messages to the worker peer to maintain the connection.

Expand Down Expand Up @@ -321,7 +337,12 @@ async def on_channel_open():
# Initiate keep-alive task.
asyncio.create_task(keep_ice_alive(channel))
logging.info(f"{channel.label} is open")


# Setup monitor window for progress reports.
# zmq_ports = dict()
# zmq_ports["controller_port"] = 9000
# zmq_ports["publish_port"] = 9001

# Prompt for messages or file upload.
if CLI:
await send_client_messages()
Expand All @@ -343,6 +364,7 @@ async def on_message(message):
logging.info(f"Client received: {message}")
global received_files
global output_dir
global win

# Handle string and bytes messages differently.
if isinstance(message, str):
Expand All @@ -353,7 +375,7 @@ async def on_message(message):
if message == "END_OF_FILE":
# File transfer complete, save to disk.
file_name, file_data = list(received_files.items())[0]

try:
os.makedirs(output_dir, exist_ok=True)
file_path = os.path.join(output_dir, file_name)
Expand All @@ -368,13 +390,48 @@ async def on_message(message):

received_files.clear()

# Update monitor window with file transfer and training completion.
win.close()

if CLI:
# Prompt for next message
logging.info("File transfer complete. Enter next message:")
await send_client_messages()
else:
await clean_exit(pc, websocket)

elif "PROGRESS_REPORT::" in message:
# Progress report received from worker.
logging.info(message)
_, progress = message.split("PROGRESS_REPORT::", 1)

# Update LossViewer window with received progress report.
if win:
QtCore.QTimer.singleShot(0, partial(win._check_messages, rtc_msg=progress))
# win._check_messages(
# # Progress should be result from jsonpickle.decode(msg_str)
# rtc_msg=progress
# )
else:
logging.info(f"No monitor window available! win is {win}")

# print("Resetting monitor window.")
# plateau_patience = config_info.optimization.early_stopping.plateau_patience
# plateau_min_delta = config_info.optimization.early_stopping.plateau_min_delta
# win.reset(
# what=str(model_type),
# plateau_patience=plateau_patience,
# plateau_min_delta=plateau_min_delta,
# )
# win.setWindowTitle(f"Training Model - {str(model_type)}")
# win.set_message(f"Preparing to run training...")
# if save_viz:
# viz_window = QtImageDirectoryWidget.make_training_vizualizer(
# job.outputs.run_path
# )
# viz_window.move(win.x() + win.width() + 20, win.y())
# win.on_epoch.connect(viz_window.poll)

elif "FILE_META::" in message:
# Metadata received (file name & size)
_, meta = message.split("FILE_META::", 1)
Expand All @@ -384,13 +441,34 @@ async def on_message(message):
received_files[file_name] = bytearray() # Initialize as bytearray
logging.info(f"File name received: {file_name}, of size {file_size}, saving to {output_dir}")

elif "ZMQ_CTRL::" in message:
# ZMQ control message received.
_, zmq_ctrl = message.split("ZMQ_CTRL::", 1)

if zmq_ctrl == "STOP":
win.reset()

else:
logging.info(f"Worker sent: {message}")

elif isinstance(message, bytes):
if message == b"KEEP_ALIVE":
logging.info("Keep alive message received.")
return

elif b"PROGRESS_REPORT::" in message:
# Progress report received from worker.
logging.info(message.decode())
_, progress = message.decode().split("PROGRESS_REPORT::", 1)

# Update LossViewer window with received progress report.
if win:
win._check_messages(
# Progress should be result from jsonpickle.decode(msg_str)
rtc_msg=progress
)
else:
logging.info(f"No monitor window available! win is {win}")

file_name = list(received_files.keys())[0]
if file_name not in received_files:
Expand Down
Loading