Skip to content

Commit 06c7093

Browse files
authored
Always use environment variables from add time + minor changes (#140)
* store environment variables at add time; made a bunch of things faster; better logging * do not call scancel if the job list is empty * optimize hash query, use multiprocessing for uploading src files * remove redundant typing * proper config updates in src reload * use multithreading instead of multiprocessing * cache reading file configs * make sure that args are hashable * fix setting experiments * do not cache files from the venv
1 parent 89e1eba commit 06c7093

File tree

15 files changed

+341
-139
lines changed

15 files changed

+341
-139
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ repos:
6060
- setuptools>=69.2.0
6161
- importlib_resources>=5.7.0
6262
- typing_extensions>=4.10
63+
- deepdiff>=7.0.1
6364
- ruff>=0.6.1
6465
- pytest>=8.3.2
6566
- pre-commit>=3.8.0

examples/example_config.yaml

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,18 +60,25 @@ seml:
6060
description: An example configuration.
6161

6262
slurm:
63+
- experiments_per_job: 1
64+
sbatch_options:
65+
gres: gpu:0
66+
mem: 1G
67+
cpus-per-task: 1
68+
time: 0-08:00
69+
partition: cpu_all,cpu_large
6370
- experiments_per_job: 4
6471
sbatch_options:
6572
gres: gpu:1 # num GPUs
66-
mem: 16G # memory
67-
cpus-per-task: 2 # num cores
73+
mem: 1G # memory
74+
cpus-per-task: 1 # num cores
6875
time: 0-08:00 # max time, D-HH:MM
6976
partition: gpu_gtx1080
7077
- experiments_per_job: 16
7178
sbatch_options:
7279
gres: gpu:1 # num GPUs
73-
mem: 16G # memory
74-
cpus-per-task: 2 # num cores
80+
mem: 1G # memory
81+
cpus-per-task: 1 # num cores
7582
time: 0-08:00 # max time, D-HH:MM
7683
partition: gpu_a100
7784

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ dependencies = [
3535
"setuptools>=69.2.0",
3636
"importlib_resources>=5.7.0",
3737
"typing_extensions>=4.10",
38+
"deepdiff>=7.0.1",
3839
]
3940

4041
[project.optional-dependencies]

src/seml/commands/add.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import copy
4-
import datetime
54
import logging
65
import os
76
from typing import TYPE_CHECKING, Any, Dict, cast
@@ -16,6 +15,7 @@
1615
read_config,
1716
remove_duplicates,
1817
remove_prepended_dashes,
18+
requires_interpolation,
1919
resolve_configs,
2020
resolve_interpolations,
2121
)
@@ -29,6 +29,7 @@
2929
remove_keys_from_nested,
3030
to_typeddict,
3131
unflatten,
32+
utcnow,
3233
)
3334
from seml.utils.errors import ConfigError
3435

@@ -125,7 +126,7 @@ def add_configs(
125126
**{
126127
'_id': start_id + idx,
127128
'status': States.STAGED[0],
128-
'add_time': datetime.datetime.utcnow(),
129+
'add_time': utcnow(),
129130
},
130131
},
131132
)
@@ -134,11 +135,13 @@ def add_configs(
134135
if description is not None:
135136
for db_dict in documents:
136137
db_dict['seml']['description'] = description
137-
if resolve_descriptions:
138-
for db_dict in documents:
139-
if 'description' in db_dict['seml']:
138+
# If description is not supplied via CLI, it will already be resolved.
139+
if resolve_descriptions and requires_interpolation(
140+
{'description': description}, ['description']
141+
):
142+
for db_dict in documents:
140143
db_dict['seml']['description'] = resolve_description(
141-
db_dict['seml']['description'], db_dict
144+
db_dict['seml'].get('description', ''), db_dict
142145
)
143146

144147
collection.insert_many(documents)
@@ -324,7 +327,6 @@ def add_config_file(
324327
}
325328
),
326329
'config_unresolved': config_unresolved,
327-
'seml': seml_config,
328330
},
329331
)
330332
for config, config_unresolved in zip(configs, configs_unresolved)
@@ -342,7 +344,7 @@ def add_config_file(
342344
for document in documents:
343345
document['config_hash'] = make_hash(
344346
document['config'],
345-
config_get_exclude_keys(document['config'], document['config_unresolved']),
347+
config_get_exclude_keys(document['config_unresolved']),
346348
)
347349

348350
if not force_duplicates:

src/seml/commands/manage.py

Lines changed: 62 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -402,13 +402,16 @@ def delete_experiments(
402402
"""
403403
from seml.console import prompt
404404

405+
collection = get_collection(db_collection_name)
405406
# Before deleting, we should first cancel the experiments that are still running.
406407
if cancel:
407408
cancel_states = set(States.PENDING + States.RUNNING)
408409
if filter_states is not None and len(filter_states) > 0:
409410
cancel_states = cancel_states.intersection(filter_states)
410411

411-
if len(cancel_states) > 0:
412+
if len(cancel_states) > 0 and collection.find_one(
413+
build_filter_dict(cancel_states, batch_id, filter_dict, sacred_id)
414+
):
412415
cancel_experiments(
413416
db_collection_name,
414417
sacred_id,
@@ -420,12 +423,9 @@ def delete_experiments(
420423
wait=True,
421424
)
422425

423-
collection = get_collection(db_collection_name)
424426
experiment_files_to_delete = []
425427

426-
filter_dict = build_filter_dict(
427-
filter_states, batch_id, filter_dict, sacred_id=sacred_id
428-
)
428+
filter_dict = build_filter_dict(filter_states, batch_id, filter_dict, sacred_id)
429429
ndelete = collection.count_documents(filter_dict)
430430
if sacred_id is not None and ndelete == 0:
431431
raise MongoDBError(f'No experiment found with ID {sacred_id}.')
@@ -560,7 +560,12 @@ def get_experiment_reset_op(exp: ExperimentDoc):
560560
]
561561

562562
# Clean up SEML dictionary
563-
keep_seml = {'source_files', 'working_dir', SETTINGS.SEML_CONFIG_VALUE_VERSION}
563+
keep_seml = {
564+
'source_files',
565+
'working_dir',
566+
'env',
567+
SETTINGS.SEML_CONFIG_VALUE_VERSION,
568+
}
564569
keep_seml.update(SETTINGS.VALID_SEML_CONFIG_VALUES)
565570
seml_keys = set(exp['seml'].keys())
566571
for key in seml_keys - keep_seml:
@@ -627,6 +632,9 @@ def reset_experiments(
627632
exps = collection.find(filter_dict)
628633
if sacred_id is not None and nreset == 0:
629634
raise MongoDBError(f'No experiment found with ID {sacred_id}.')
635+
if nreset == 0:
636+
logging.info('No experiments to reset.')
637+
return
630638

631639
logging.info(f'Resetting the state of {nreset} experiment{s_if(nreset)}.')
632640
if nreset >= SETTINGS.CONFIRM_THRESHOLD.RESET:
@@ -761,7 +769,8 @@ def reload_sources(
761769
"""
762770
from importlib.metadata import version
763771

764-
import gridfs
772+
from bson import ObjectId
773+
from deepdiff import DeepDiff
765774
from pymongo import UpdateOne
766775

767776
from seml.console import prompt
@@ -774,7 +783,15 @@ def reload_sources(
774783
filter_dict = {}
775784
db_results = list(
776785
collection.find(
777-
filter_dict, {'batch_id', 'seml', 'config', 'status', 'config_unresolved'}
786+
filter_dict,
787+
{
788+
'batch_id',
789+
'seml',
790+
'config',
791+
'status',
792+
'config_unresolved',
793+
'config_hash',
794+
},
778795
)
779796
)
780797
id_to_document: dict[int, list[ExperimentDoc]] = {}
@@ -845,29 +862,40 @@ def reload_sources(
845862
)
846863
]
847864

848-
result = collection.bulk_write(
849-
[
850-
UpdateOne(
851-
{'_id': document['_id']},
852-
{
853-
'$set': {
854-
'config': document['config'],
855-
'config_unresolved': document['config_unresolved'],
856-
'config_hash': make_hash(
857-
document['config'],
858-
config_get_exclude_keys(
859-
document['config'], document['config_unresolved']
860-
),
861-
),
862-
}
863-
},
865+
# determine which documents to update
866+
updates = []
867+
for old_doc, new_doc in zip(documents, new_documents):
868+
use_hash = 'config_hash' in old_doc
869+
# these config fields are populated if the experiment ran
870+
runtime_fields = {
871+
k: old_doc['config'][k]
872+
for k in ['db_collection', 'overwrite', 'seed']
873+
if k in old_doc['config']
874+
}
875+
new = dict(
876+
config=new_doc['config'] | runtime_fields,
877+
config_unresolved=new_doc['config_unresolved'],
878+
)
879+
# compare new to old config
880+
if use_hash:
881+
new['config_hash'] = make_hash(
882+
new_doc['config'],
883+
config_get_exclude_keys(new_doc['config_unresolved']),
864884
)
865-
for document in new_documents
866-
]
867-
)
868-
logging.info(
869-
f'Batch {batch_id}: Resolved configurations of {result.matched_count} experiments against new source files ({result.modified_count} changed).'
870-
)
885+
update = new['config_hash'] != old_doc['config_hash']
886+
else:
887+
diff = DeepDiff(new['config'], old_doc['config'], ignore_order=True)
888+
update = bool(diff)
889+
# Create mongodb update
890+
if update:
891+
updates.append(UpdateOne({'_id': old_doc['_id']}, {'$set': new}))
892+
if len(updates) > 0:
893+
result = collection.bulk_write(updates)
894+
logging.info(
895+
f'Batch {batch_id}: Resolved configurations of {result.matched_count} experiments against new source files ({result.modified_count} changed).'
896+
)
897+
else:
898+
logging.info(f'Batch {batch_id}: No experiment configurations changed.')
871899

872900
# Check whether the configurations align with the current source code
873901
check_config(
@@ -879,7 +907,6 @@ def reload_sources(
879907

880908
# Find the currently used source files
881909
db = collection.database
882-
fs = gridfs.GridFS(db)
883910
fs_filter_dict = {
884911
'metadata.batch_id': batch_id,
885912
'metadata.collection_name': f'{collection.name}',
@@ -915,8 +942,7 @@ def reload_sources(
915942
except Exception as e:
916943
logging.error(f'Batch {batch_id}: Failed to set new source files.')
917944
# Delete new source files from DB
918-
for to_delete in source_files:
919-
fs.delete(to_delete[1])
945+
delete_files(db, [x[1] for x in source_files])
920946
raise e
921947

922948
# Delete the old source files
@@ -927,10 +953,10 @@ def reload_sources(
927953
'metadata.deprecated': True,
928954
}
929955
source_files_old = [
930-
x['_id'] for x in db['fs.files'].find(fs_filter_dict, {'_id'})
956+
cast(ObjectId, x['_id'])
957+
for x in db['fs.files'].find(fs_filter_dict, {'_id'})
931958
]
932-
for to_delete in source_files_old:
933-
fs.delete(to_delete)
959+
delete_files(db, source_files_old)
934960

935961

936962
def detect_duplicates(

0 commit comments

Comments
 (0)