Skip to content

Commit 75cb321

Browse files
committed
Add MK2.1 lossless share support
1 parent baa2810 commit 75cb321

10 files changed

Lines changed: 738 additions & 69 deletions

File tree

cli.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from calendar_gen_v2 import _load_yearly_budget
1515
from engines.base import ScheduleInput
1616
from engines.engine_mk1 import EngineMK1
17-
from engines.engine_mk2 import EngineMK2
17+
from engines.engine_mk2 import EngineMK2, EngineMK21
1818
from rigs.simple_rig import SimpleRig
1919
from rigs.workforce_rig import WorkforceRig
2020

@@ -24,6 +24,8 @@ def _build_engine(engine_name: str) -> object:
2424
return EngineMK1()
2525
if engine_name == "mk2":
2626
return EngineMK2()
27+
if engine_name == "mk2_1":
28+
return EngineMK21()
2729
raise ValueError(f"Unknown engine '{engine_name}'")
2830

2931

@@ -96,7 +98,7 @@ def build_parser() -> argparse.ArgumentParser:
9698
parser = argparse.ArgumentParser(description="Unified Wyrd Engine CLI")
9799
parser.add_argument(
98100
"--engine",
99-
choices=["mk1", "mk2"],
101+
choices=["mk1", "mk2", "mk2_1"],
100102
required=True,
101103
help="Select which engine implementation to use",
102104
)
@@ -157,8 +159,10 @@ def main(argv: Optional[Sequence[str]] = None) -> None:
157159
return
158160

159161
if args.rig in {"calendar", "workforce"}:
160-
if args.engine != "mk2":
161-
parser.error(f"The {args.rig} rig requires the mk2 engine")
162+
if args.engine not in {"mk2", "mk2_1"}:
163+
parser.error(
164+
f"The {args.rig} rig requires an MK2-series engine (mk2 or mk2_1)"
165+
)
162166
_run_workforce(args)
163167
return
164168

engines/engine_mk2.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838

3939
__all__ = [
4040
"EngineMK2",
41+
"EngineMK21",
4142
"DayPlan",
4243
"apply_micro_jitter",
4344
"apply_seasonal_modifiers",
@@ -89,6 +90,7 @@ def __init__(
8990
Callable[[PersonProfile, date, "UniqueDay"], Optional[List[Activity]]]
9091
] = None,
9192
validator: Optional[Callable[[Dict[str, List[Activity]]], List[object]]] = None,
93+
engine_version: str = "mk2",
9294
) -> None:
9395
self._profile_factory = {
9496
"office": (create_office_worker, DEFAULT_TEMPLATES),
@@ -103,12 +105,17 @@ def __init__(
103105
[PersonProfile, date, "UniqueDay"], Optional[List[Activity]]
104106
]
105107
self._validator: Callable[[Dict[str, List[Activity]]], List[object]]
108+
self._engine_version = engine_version or "mk2"
106109
self.set_friction_generator(friction_generator or generate_daily_friction)
107110
self.set_unique_schedule_generator(
108111
unique_schedule_generator or generate_unique_day_schedule
109112
)
110113
self.set_validator(validator or validate_week)
111114

115+
@property
116+
def engine_version(self) -> str:
117+
return self._engine_version
118+
112119
def set_calendar_provider(self, provider: CalendarProvider) -> None:
113120
"""Replace the calendar provider used by the engine."""
114121

@@ -689,18 +696,21 @@ def generate_complete_week(
689696
if debug_trace is not None:
690697
debug_trace["weekly_totals_from_events"] = weekly_totals_minutes
691698

699+
metadata = {
700+
"total_events": len(events_payload),
701+
"issue_count": len(issues),
702+
"summary_hours": summary_hours,
703+
"compression": compression_metadata,
704+
"day_types": {plan.date.isoformat(): plan.day_type for plan in week_plans},
705+
}
706+
metadata["engine_version"] = self._engine_version
707+
692708
result: Dict[str, Any] = {
693709
"person": profile.name,
694710
"week_start": start_date.isoformat(),
695711
"events": events_payload,
696712
"issues": [asdict(issue) for issue in issues],
697-
"metadata": {
698-
"total_events": len(events_payload),
699-
"issue_count": len(issues),
700-
"summary_hours": summary_hours,
701-
"compression": compression_metadata,
702-
"day_types": {plan.date.isoformat(): plan.day_type for plan in week_plans},
703-
},
713+
"metadata": metadata,
704714
}
705715

706716
if debug and debug_trace is not None:
@@ -716,6 +726,28 @@ def select_profile(self, archetype: str) -> Tuple[PersonProfile, Dict[str, Activ
716726
return factory(), templates
717727

718728

729+
class EngineMK21(EngineMK2):
730+
"""MK2.1 variant that enables lossless share aggregation."""
731+
732+
def __init__(
733+
self,
734+
calendar_provider: Optional[CalendarProvider] = None,
735+
*,
736+
friction_generator: Optional[Callable[[int, float, float], float]] = None,
737+
unique_schedule_generator: Optional[
738+
Callable[[PersonProfile, date, "UniqueDay"], Optional[List[Activity]]]
739+
] = None,
740+
validator: Optional[Callable[[Dict[str, List[Activity]]], List[object]]] = None,
741+
) -> None:
742+
super().__init__(
743+
calendar_provider=calendar_provider,
744+
friction_generator=friction_generator,
745+
unique_schedule_generator=unique_schedule_generator,
746+
validator=validator,
747+
engine_version="mk2_1",
748+
)
749+
750+
719751
def normalize_mk2_events(
720752
events: Iterable[Mapping[str, object]], *, week_start: Optional[date] = None
721753
) -> List[Dict[str, Any]]:

engines/web_adapter.py

Lines changed: 102 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from engines.base import ScheduleInput
1010
from engines.engine_mk1 import EngineMK1
11-
from engines.engine_mk2 import EngineMK2
11+
from engines.engine_mk2 import EngineMK2, EngineMK21
1212
from modules.unique_events import UniqueDay
1313
from rigs.simple_rig import SimpleRig
1414
from rigs.workforce_rig import WorkforceRig
@@ -175,12 +175,16 @@ def _coerce_start_date(value: Any) -> Optional[date]:
175175
return None
176176

177177

178-
def _ensure_schema(payload: MutableMapping[str, Any], *, rig: str, seed: int, archetype: str) -> SchemaPayload:
178+
def _ensure_schema(
179+
payload: MutableMapping[str, Any], *, rig: str, seed: int, archetype: str, engine_version: Optional[str] = None
180+
) -> SchemaPayload:
179181
metadata = dict(payload.get("metadata", {}))
180182
metadata.setdefault("summary_hours", {})
183+
resolved_engine = engine_version or ("mk2" if rig in {"calendar", "workforce"} else "mk1")
181184
metadata.update(
182185
{
183-
"engine": "mk2" if rig in {"calendar", "workforce"} else "mk1",
186+
"engine": resolved_engine,
187+
"engine_version": resolved_engine,
184188
"rig": rig,
185189
"seed": seed,
186190
"archetype": archetype,
@@ -243,6 +247,8 @@ def _convert_events(events: Iterable[Mapping[str, Any]]) -> Iterable[Dict[str, A
243247

244248
_MK2_ENGINE = EngineMK2()
245249
_MK2_RIG = WorkforceRig(engine=_MK2_ENGINE)
250+
_MK2_1_ENGINE = EngineMK21()
251+
_MK2_1_RIG = WorkforceRig(engine=_MK2_1_ENGINE)
246252

247253

248254
def mk1_run_web(archetype: str, week_start: Optional[str], seed: Any) -> SchemaPayload:
@@ -278,49 +284,34 @@ def mk1_run_web(archetype: str, week_start: Optional[str], seed: Any) -> SchemaP
278284
},
279285
}
280286

281-
return _ensure_schema(payload, rig="default", seed=seed_value, archetype=archetype_key or "office")
282-
283-
284-
def mk2_run_calendar_web(
285-
archetype: str, week_start: Optional[str], seed: Any, debug: bool = False
286-
) -> SchemaPayload:
287-
archetype_key = str(archetype or "office").strip().lower()
288-
seed_value = _coerce_seed(seed)
289-
start_date = _coerce_start_date(week_start) or date.today()
290-
291-
profile, templates = _MK2_RIG.select_profile(archetype_key)
292-
result = _MK2_RIG.generate_complete_week(
293-
profile, start_date, seed_value, templates, None, debug=debug
287+
return _ensure_schema(
288+
payload,
289+
rig="default",
290+
seed=seed_value,
291+
archetype=archetype_key or "office",
292+
engine_version="mk1",
294293
)
295294

296-
payload: MutableMapping[str, Any] = dict(result)
297-
payload.setdefault("issues", [])
298-
payload.setdefault("events", [])
299-
payload.setdefault("week_start", start_date.isoformat())
300-
payload["person"] = profile.name
301-
payload["metadata"] = {
302-
**payload.get("metadata", {}),
303-
"profile": profile.name,
304-
}
305-
306-
return _ensure_schema(payload, rig="calendar", seed=seed_value, archetype=archetype_key)
307295

308-
309-
def mk2_run_workforce_web(
296+
def _run_mk2_variant(
297+
rig_instance: WorkforceRig,
310298
archetype: str,
311299
week_start: Optional[str],
312300
seed: Any,
313-
yearly_budget: Optional[Mapping[str, Any]],
301+
*,
302+
engine_version: str,
303+
rig_label: str,
304+
yearly_budget: Optional[Mapping[str, Any]] = None,
314305
debug: bool = False,
315306
) -> SchemaPayload:
316307
archetype_key = str(archetype or "office").strip().lower()
317308
seed_value = _coerce_seed(seed)
318309
start_date = _coerce_start_date(week_start) or date.today()
319310

320-
profile, templates = _MK2_RIG.select_profile(archetype_key)
321-
budget = _build_yearly_budget(yearly_budget)
311+
profile, templates = rig_instance.select_profile(archetype_key)
312+
budget = _build_yearly_budget(yearly_budget) if yearly_budget is not None else None
322313

323-
result = _MK2_RIG.generate_complete_week(
314+
result = rig_instance.generate_complete_week(
324315
profile, start_date, seed_value, templates, budget, debug=debug
325316
)
326317

@@ -331,6 +322,7 @@ def mk2_run_workforce_web(
331322
payload["person"] = profile.name
332323
metadata = dict(payload.get("metadata", {}))
333324
metadata["profile"] = profile.name
325+
metadata["engine_version"] = engine_version
334326
if budget is not None:
335327
metadata["yearly_budget"] = {
336328
"person_id": budget.person_id,
@@ -341,11 +333,87 @@ def mk2_run_workforce_web(
341333
}
342334
payload["metadata"] = metadata
343335

344-
return _ensure_schema(payload, rig="workforce", seed=seed_value, archetype=archetype_key)
336+
return _ensure_schema(
337+
payload,
338+
rig=rig_label,
339+
seed=seed_value,
340+
archetype=archetype_key,
341+
engine_version=engine_version,
342+
)
343+
344+
345+
def mk2_run_calendar_web(
346+
archetype: str, week_start: Optional[str], seed: Any, debug: bool = False
347+
) -> SchemaPayload:
348+
return _run_mk2_variant(
349+
_MK2_RIG,
350+
archetype,
351+
week_start,
352+
seed,
353+
engine_version="mk2",
354+
rig_label="calendar",
355+
yearly_budget=None,
356+
debug=debug,
357+
)
358+
359+
360+
def mk2_run_workforce_web(
361+
archetype: str,
362+
week_start: Optional[str],
363+
seed: Any,
364+
yearly_budget: Optional[Mapping[str, Any]],
365+
debug: bool = False,
366+
) -> SchemaPayload:
367+
return _run_mk2_variant(
368+
_MK2_RIG,
369+
archetype,
370+
week_start,
371+
seed,
372+
engine_version="mk2",
373+
rig_label="workforce",
374+
yearly_budget=yearly_budget,
375+
debug=debug,
376+
)
377+
378+
379+
def mk2_1_run_calendar_web(
380+
archetype: str, week_start: Optional[str], seed: Any, debug: bool = False
381+
) -> SchemaPayload:
382+
return _run_mk2_variant(
383+
_MK2_1_RIG,
384+
archetype,
385+
week_start,
386+
seed,
387+
engine_version="mk2_1",
388+
rig_label="calendar",
389+
yearly_budget=None,
390+
debug=debug,
391+
)
392+
393+
394+
def mk2_1_run_workforce_web(
395+
archetype: str,
396+
week_start: Optional[str],
397+
seed: Any,
398+
yearly_budget: Optional[Mapping[str, Any]],
399+
debug: bool = False,
400+
) -> SchemaPayload:
401+
return _run_mk2_variant(
402+
_MK2_1_RIG,
403+
archetype,
404+
week_start,
405+
seed,
406+
engine_version="mk2_1",
407+
rig_label="workforce",
408+
yearly_budget=yearly_budget,
409+
debug=debug,
410+
)
345411

346412

347413
__all__ = [
348414
"mk1_run_web",
349415
"mk2_run_calendar_web",
350416
"mk2_run_workforce_web",
417+
"mk2_1_run_calendar_web",
418+
"mk2_1_run_workforce_web",
351419
]

tests/test_cli_parity.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,16 @@
1919
ROOT = Path(__file__).resolve().parents[1]
2020

2121

22+
def _override_metadata_engine(payload: dict, engine_value: str) -> dict:
23+
"""Return a deep copy of payload with metadata engine fields normalised."""
24+
25+
clone = json.loads(json.dumps(payload))
26+
metadata = clone.setdefault("metadata", {})
27+
if engine_value:
28+
metadata["engine_version"] = engine_value
29+
return clone
30+
31+
2232
@pytest.mark.parametrize(
2333
"cli_args, legacy_runner",
2434
[
@@ -48,6 +58,21 @@
4858
],
4959
"run_calendar_gen_v2",
5060
),
61+
(
62+
[
63+
"--engine",
64+
"mk2_1",
65+
"--rig",
66+
"workforce",
67+
"--archetype",
68+
"office",
69+
"--seed",
70+
"7",
71+
"--start-date",
72+
"2025-01-06",
73+
],
74+
"run_calendar_gen_v2",
75+
),
5176
],
5277
)
5378
def test_cli_matches_legacy(tmp_path: Path, cli_args, legacy_runner) -> None:
@@ -98,7 +123,15 @@ def run_calendar_gen_v2_legacy() -> str:
98123

99124
cli_data = json.loads(cli_output.read_text())
100125
legacy_data = json.loads(legacy_output.read_text())
101-
assert cli_data == legacy_data
126+
engine_choice = None
127+
if "--engine" in cli_args:
128+
engine_index = cli_args.index("--engine")
129+
if engine_index + 1 < len(cli_args):
130+
engine_choice = cli_args[engine_index + 1]
131+
comparison_target = legacy_data
132+
if engine_choice == "mk2_1":
133+
comparison_target = _override_metadata_engine(legacy_data, engine_choice)
134+
assert cli_data == comparison_target
102135

103136
def normalise_output(text: str, output_path: Path) -> str:
104137
return text.replace(str(output_path), "<OUTPUT>")

0 commit comments

Comments
 (0)