Skip to content

Commit 2982791

Browse files
committed
Remove classmethod refactoring + make cell tracking optional
1 parent 8c32ef2 commit 2982791

File tree

4 files changed

+51
-32
lines changed

4 files changed

+51
-32
lines changed

jupyter_scheduler/executors.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ class ExecutionManager(ABC):
2929
_model = None
3030
_db_session = None
3131

32-
def __init__(self, job_id: str, root_dir: str, db_url: str, staging_paths: Dict[str, str]):
32+
def __init__(
33+
self, job_id: str, root_dir: str, db_url: str, staging_paths: Dict[str, str]
34+
):
3335
self.job_id = job_id
3436
self.staging_paths = staging_paths
3537
self.root_dir = root_dir
@@ -74,18 +76,16 @@ def execute(self):
7476
"""
7577
pass
7678

77-
@classmethod
7879
@abstractmethod
79-
def supported_features(cls) -> Dict[JobFeature, bool]:
80+
def supported_features(self) -> Dict[JobFeature, bool]:
8081
"""Returns a configuration of supported features
8182
by the execution engine. Implementors are expected
8283
to override this to return a dictionary of supported
8384
job creation features.
8485
"""
8586
pass
8687

87-
@classmethod
88-
def validate(cls, input_path: str) -> bool:
88+
def validate(self, input_path: str) -> bool:
8989
"""Returns True if notebook has valid metadata to execute, False otherwise"""
9090
return True
9191

@@ -134,7 +134,9 @@ def execute(self):
134134
staging_dir = os.path.dirname(self.staging_paths["input"])
135135

136136
ep = ExecutePreprocessor(
137-
kernel_name=nb.metadata.kernelspec["name"], store_widget_state=True, cwd=staging_dir
137+
kernel_name=nb.metadata.kernelspec["name"],
138+
store_widget_state=True,
139+
cwd=staging_dir,
138140
)
139141

140142
if self.supported_features().get(JobFeature.track_cell_execution, False):
@@ -173,10 +175,14 @@ def add_side_effects_files(self, staging_dir: str):
173175
if new_files_set:
174176
with self.db_session() as session:
175177
current_packaged_files_set = set(
176-
session.query(Job.packaged_files).filter(Job.job_id == self.job_id).scalar()
178+
session.query(Job.packaged_files)
179+
.filter(Job.job_id == self.job_id)
180+
.scalar()
177181
or []
178182
)
179-
updated_packaged_files = list(current_packaged_files_set.union(new_files_set))
183+
updated_packaged_files = list(
184+
current_packaged_files_set.union(new_files_set)
185+
)
180186
session.query(Job).filter(Job.job_id == self.job_id).update(
181187
{"packaged_files": updated_packaged_files}
182188
)
@@ -186,11 +192,12 @@ def create_output_files(self, job: DescribeJob, notebook_node):
186192
for output_format in job.output_formats:
187193
cls = nbconvert.get_exporter(output_format)
188194
output, _ = cls().from_notebook_node(notebook_node)
189-
with fsspec.open(self.staging_paths[output_format], "w", encoding="utf-8") as f:
195+
with fsspec.open(
196+
self.staging_paths[output_format], "w", encoding="utf-8"
197+
) as f:
190198
f.write(output)
191199

192-
@classmethod
193-
def supported_features(cls) -> Dict[JobFeature, bool]:
200+
def supported_features(self) -> Dict[JobFeature, bool]:
194201
return {
195202
JobFeature.job_name: True,
196203
JobFeature.output_formats: True,
@@ -205,11 +212,10 @@ def supported_features(cls) -> Dict[JobFeature, bool]:
205212
JobFeature.output_filename_template: False,
206213
JobFeature.stop_job: True,
207214
JobFeature.delete_job: True,
208-
JobFeature.track_cell_execution: True,
215+
JobFeature.track_cell_execution: False,
209216
}
210217

211-
@classmethod
212-
def validate(cls, input_path: str) -> bool:
218+
def validate(self, input_path: str) -> bool:
213219
with open(input_path, encoding="utf-8") as f:
214220
nb = nbformat.read(f, as_version=4)
215221
try:

jupyter_scheduler/scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,7 @@ def create_job(self, model: CreateJob) -> str:
442442
raise InputUriError(model.input_uri)
443443

444444
input_path = os.path.join(self.root_dir, model.input_uri)
445-
if not self.execution_manager_class.validate(input_path):
445+
if not self.execution_manager_class.validate(self.execution_manager_class, input_path):
446446
raise SchedulerError(
447447
"""There is no kernel associated with the notebook. Please open
448448
the notebook, select a kernel, and re-submit the job to execute.

jupyter_scheduler/tests/mocks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def execute(self):
1616
def process(self):
1717
pass
1818

19-
def supported_features(cls) -> Dict[JobFeature, bool]:
19+
def supported_features(self) -> Dict[JobFeature, bool]:
2020
return {
2121
JobFeature.job_name: True,
2222
JobFeature.output_formats: True,

jupyter_scheduler/tests/test_execution_manager.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ def test_add_side_effects_files(
6161
assert side_effect_file_name in job.packaged_files
6262

6363

64-
def test_default_execution_manager_cell_tracking_hook():
65-
"""Test that DefaultExecutionManager sets up on_cell_executed hook when track_cell_execution is supported"""
64+
def test_default_execution_manager_cell_tracking_hook_not_set_by_default():
65+
"""Test that DefaultExecutionManager does NOT set up on_cell_executed hook when track_cell_execution is disabled by default"""
6666
job_id = "test-job-id"
6767

6868
with patch.object(DefaultExecutionManager, "model") as mock_model:
@@ -100,9 +100,8 @@ def test_default_execution_manager_cell_tracking_hook():
100100
# Verify ExecutePreprocessor was created
101101
mock_ep_class.assert_called_once()
102102

103-
# Verify on_cell_executed hook was set
104-
assert hasattr(mock_ep, "on_cell_executed")
105-
assert mock_ep.on_cell_executed is not None
103+
# Verify patching method was never called
104+
mock_model.__update_completed_cells_hook.assert_not_called()
106105

107106

108107
def test_update_completed_cells_hook():
@@ -158,8 +157,8 @@ def test_update_completed_cells_hook_database_error():
158157
# Mock db_session with error
159158
mock_db_session = MagicMock()
160159
mock_session_context = MagicMock()
161-
mock_session_context.query.return_value.filter.return_value.update.side_effect = Exception(
162-
"DB Error"
160+
mock_session_context.query.return_value.filter.return_value.update.side_effect = (
161+
Exception("DB Error")
163162
)
164163
mock_db_session.return_value.__enter__.return_value = mock_session_context
165164
manager._db_session = mock_db_session
@@ -181,12 +180,18 @@ def test_update_completed_cells_hook_database_error():
181180

182181
def test_supported_features_includes_track_cell_execution():
183182
"""Test that DefaultExecutionManager supports track_cell_execution feature"""
184-
features = DefaultExecutionManager.supported_features()
183+
manager = DefaultExecutionManager(
184+
job_id="test-job-id",
185+
root_dir="/test",
186+
db_url="sqlite:///:memory:",
187+
staging_paths={"input": "/test/input.ipynb"},
188+
)
189+
features = manager.supported_features()
185190

186191
from jupyter_scheduler.models import JobFeature
187192

188193
assert JobFeature.track_cell_execution in features
189-
assert features[JobFeature.track_cell_execution] is True
194+
assert features[JobFeature.track_cell_execution] is False
190195

191196

192197
def test_hook_uses_correct_job_id():
@@ -233,8 +238,7 @@ def test_cell_tracking_disabled_when_feature_false():
233238

234239
# Create a custom execution manager class with track_cell_execution = False
235240
class DisabledTrackingExecutionManager(DefaultExecutionManager):
236-
@classmethod
237-
def supported_features(cls):
241+
def supported_features(self):
238242
features = super().supported_features()
239243
from jupyter_scheduler.models import JobFeature
240244

@@ -256,8 +260,12 @@ def supported_features(cls):
256260
with patch.object(DisabledTrackingExecutionManager, "model") as mock_model:
257261
with patch("jupyter_scheduler.executors.open", mock=MagicMock()):
258262
with patch("jupyter_scheduler.executors.nbformat.read") as mock_nb_read:
259-
with patch.object(DisabledTrackingExecutionManager, "add_side_effects_files"):
260-
with patch.object(DisabledTrackingExecutionManager, "create_output_files"):
263+
with patch.object(
264+
DisabledTrackingExecutionManager, "add_side_effects_files"
265+
):
266+
with patch.object(
267+
DisabledTrackingExecutionManager, "create_output_files"
268+
):
261269
with patch(
262270
"jupyter_scheduler.executors.ExecutePreprocessor"
263271
) as mock_ep_class:
@@ -288,15 +296,20 @@ def test_disabled_tracking_feature_support():
288296

289297
# Create a custom execution manager class with track_cell_execution = False
290298
class DisabledTrackingExecutionManager(DefaultExecutionManager):
291-
@classmethod
292-
def supported_features(cls):
299+
def supported_features(self):
293300
features = super().supported_features()
294301
from jupyter_scheduler.models import JobFeature
295302

296303
features[JobFeature.track_cell_execution] = False
297304
return features
298305

299-
features = DisabledTrackingExecutionManager.supported_features()
306+
manager = DisabledTrackingExecutionManager(
307+
job_id="test-job-id",
308+
root_dir="/test",
309+
db_url="sqlite:///:memory:",
310+
staging_paths={"input": "/test/input.ipynb"},
311+
)
312+
features = manager.supported_features()
300313

301314
from jupyter_scheduler.models import JobFeature
302315

0 commit comments

Comments
 (0)