Skip to content

Commit 4e2f03f

Browse files
refactor(status_display): enhance job status update mechanism
- Updated the JobStatusDisplay class to store the last response, allowing for automatic refresh of job status when no new data is provided. - Modified the update method to handle optional parameters, improving flexibility in status updates. - Streamlined the RemoteBackend class to utilize the new update mechanism for refreshing status displays during timeout events.
1 parent eb58297 commit 4e2f03f

File tree

1 file changed

+18
-14
lines changed

1 file changed

+18
-14
lines changed

src/nnsight/intervention/backends/remote.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,10 @@ def __init__(self, enabled: bool = True, verbose: bool = False):
7272
self.enabled = enabled
7373
self.verbose = verbose
7474
self.status_start_time: Optional[float] = None
75-
self.job_id: Optional[str] = None
7675
self.spinner_idx = 0
77-
self.last_status = None
76+
self.last_response: Optional[Tuple[str, str, str]] = (
77+
None # (job_id, status, description)
78+
)
7879
self._line_written = False
7980
self._notebook_display_id: Optional[str] = None
8081

@@ -115,17 +116,27 @@ def _get_spinner(self) -> str:
115116
self.spinner_idx += 1
116117
return spinner
117118

118-
def update(self, job_id: str, status_name: str, description: str = ""):
119+
def update(self, job_id: str = "", status_name: str = "", description: str = ""):
119120
"""Update the status display on a single line."""
120121
if not self.enabled:
121122
return
122123

123-
status_changed = status_name != self.last_status
124+
# Use last response values if not provided (for refresh calls)
125+
if not job_id and self.last_response:
126+
job_id, status_name, description = self.last_response
127+
128+
if not job_id:
129+
return
130+
131+
last_status = self.last_response[1] if self.last_response else None
132+
status_changed = status_name != last_status
124133

125134
# Reset timer when status changes
126135
if status_changed:
127136
self.status_start_time = time.time()
128-
self.job_id = job_id
137+
138+
# Store the response
139+
self.last_response = (job_id, status_name, description)
129140

130141
icon, color = self._get_status_style(status_name)
131142
elapsed = self._format_elapsed()
@@ -157,7 +168,6 @@ def update(self, job_id: str, status_name: str, description: str = ""):
157168
self._display(status_text, status_changed, is_terminal)
158169

159170
self._line_written = True
160-
self.last_status = status_name
161171

162172
def _display(self, text: str, status_changed: bool, is_terminal: bool):
163173
"""Display text, handling terminal vs notebook environments."""
@@ -565,10 +575,7 @@ def blocking_request(self, tracer: Tracer) -> Optional[RESULT]:
565575
response = sio.receive(timeout=timeout)[1]
566576
except socketio.exceptions.TimeoutError:
567577
# Refresh the status display to update spinner and elapsed time
568-
if self.job_id and self.job_status:
569-
self.status_display.update(
570-
self.job_id, self.job_status.name
571-
)
578+
self.status_display.update()
572579
continue
573580

574581
# Convert to pydantic object.
@@ -630,10 +637,7 @@ async def async_request(self, tracer: Tracer) -> Optional[RESULT]:
630637
response = (await sio.receive(timeout=timeout))[1]
631638
except socketio.exceptions.TimeoutError:
632639
# Refresh the status display to update spinner and elapsed time
633-
if self.job_id and self.job_status:
634-
self.status_display.update(
635-
self.job_id, self.job_status.name
636-
)
640+
self.status_display.update()
637641
continue
638642

639643
# Convert to pydantic object.

0 commit comments

Comments
 (0)