|
6 | 6 |
|
7 | 7 | # pyre-strict |
8 | 8 |
|
| 9 | +import logging |
| 10 | + |
9 | 11 | from math import ceil |
10 | 12 | from typing import Any, cast, Mapping |
11 | 13 |
|
|
53 | 55 | from sqlalchemy.orm import defaultload, joinedload, lazyload, noload |
54 | 56 | from sqlalchemy.orm.exc import DetachedInstanceError |
55 | 57 |
|
| 58 | +logger: logging.Logger = logging.getLogger(__name__) |
| 59 | + |
56 | 60 |
|
57 | 61 | # ---------------------------- Loading `Experiment`. --------------------------- |
58 | 62 |
|
@@ -537,21 +541,32 @@ def _load_generation_strategy_by_id( |
537 | 541 | def get_generation_strategy_id(experiment_name: str, decoder: Decoder) -> int | None: |
538 | 542 | """Get DB ID of the generation strategy, associated with the experiment |
539 | 543 | with the given name if its in DB, return None otherwise. |
| 544 | +
|
| 545 | + If multiple generation strategies are associated with the experiment, |
| 546 | + returns the latest one (highest DB ID). |
540 | 547 | """ |
541 | 548 | exp_sqa_class = decoder.config.class_to_sqa_class[Experiment] |
542 | 549 | gs_sqa_class = decoder.config.class_to_sqa_class[GenerationStrategy] |
543 | 550 | with session_scope() as session: |
544 | | - sqa_gs_id = ( |
| 551 | + sqa_gs_ids = ( |
545 | 552 | session.query(gs_sqa_class.id) # pyre-ignore[16] |
546 | 553 | .join(exp_sqa_class.generation_strategy) # pyre-ignore[16] |
547 | 554 | # pyre-fixme[16]: `SQABase` has no attribute `name`. |
548 | 555 | .filter(exp_sqa_class.name == experiment_name) |
549 | | - .one_or_none() |
| 556 | + .order_by(gs_sqa_class.id.desc()) |
| 557 | + .all() |
550 | 558 | ) |
551 | 559 |
|
552 | | - if sqa_gs_id is None: |
| 560 | + if not sqa_gs_ids: |
553 | 561 | return None |
554 | | - return sqa_gs_id[0] |
| 562 | + |
| 563 | + if len(sqa_gs_ids) > 1: |
| 564 | + logger.warning( |
| 565 | + f"Found {len(sqa_gs_ids)} generation strategies for experiment " |
| 566 | + f"{experiment_name}. Loading the latest one (id={sqa_gs_ids[0][0]})." |
| 567 | + ) |
| 568 | + |
| 569 | + return sqa_gs_ids[0][0] |
555 | 570 |
|
556 | 571 |
|
557 | 572 | def get_generation_strategy_sqa( |
|
0 commit comments