diff --git a/backend/app.py b/backend/app.py index dbc5928..3a1549f 100644 --- a/backend/app.py +++ b/backend/app.py @@ -17,7 +17,12 @@ from backend.mujoco_utils import compute_body_transforms, extract_model_geometry from backend.acm_processing import load_acm_trials, load_single_matfile, apply_retargeting from backend.alignment import align_acm_to_mujoco -from backend.config_io import load_stac_yaml, dump_stac_yaml, load_stac_output_h5 +from backend.config_io import ( + load_stac_yaml, + dump_stac_yaml, + dump_stac_ui_sidecar, + load_stac_output_h5, +) from backend.frame_selector import suggest_frames from backend.stac_runner import run_quick_stac @@ -194,6 +199,22 @@ async def export_config(data: dict): return PlainTextResponse(body, media_type="application/x-yaml") +@app.post("/api/export-ui-sidecar") +async def export_ui_sidecar(data: dict): + """Serialize UI-only state (skeleton editor, ...) as a sidecar YAML. + + Returns 204 when there's nothing to save, so the frontend can skip the + download. + """ + try: + body = dump_stac_ui_sidecar(data["config"]) + except Exception as e: + return JSONResponse({"error": str(e)}, status_code=500) + if body is None: + return PlainTextResponse("", status_code=204) + return PlainTextResponse(body, media_type="application/x-yaml") + + @app.post("/api/load-stac-output") async def load_stac_output(file: UploadFile = File(None), path: str = Query(None)): """Load STAC output H5 from a server-side path or an uploaded file.""" diff --git a/backend/config_io.py b/backend/config_io.py index ff5829b..3540a2a 100644 --- a/backend/config_io.py +++ b/backend/config_io.py @@ -1,5 +1,6 @@ """STAC YAML config and H5 import/export.""" from __future__ import annotations +import copy from pathlib import Path import yaml import numpy as np @@ -10,6 +11,21 @@ # (e.g. configs/model/rodent.yaml). Used to detect flat vs. wrapped shapes. _MODEL_FIELD_MARKERS = ("KEYPOINT_MODEL_PAIRS", "KP_NAMES", "MJCF_PATH") +# Fields that the UI owns and will overwrite on export. +_UI_MANAGED_FIELDS = ( + "MJCF_PATH", + "SCALE_FACTOR", + "MOCAP_SCALE_FACTOR", + "KP_NAMES", + "KEYPOINT_MODEL_PAIRS", + "KEYPOINT_INITIAL_OFFSETS", +) + + +def _is_flat(raw: dict) -> bool: + """True if `raw` looks like a flat stac-mjx model config.""" + return any(k in raw for k in _MODEL_FIELD_MARKERS) + def _extract_model_section(raw: dict) -> dict: """Return the dict containing model-level fields from a loaded YAML. @@ -20,13 +36,40 @@ def _extract_model_section(raw: dict) -> dict: the file into the `model` namespace during composition. - Wrapped: the UI's own export, where everything is nested under `model:`. """ - if any(k in raw for k in _MODEL_FIELD_MARKERS): + if _is_flat(raw): return raw return raw.get("model", {}) +def _offsets_to_yaml(offsets: dict) -> dict: + """Convert [x, y, z] offsets to space-separated strings (stac-mjx format).""" + return {kp: f"{v[0]} {v[1]} {v[2]}" for kp, v in offsets.items()} + + +def _ui_managed_fields(config: dict) -> dict: + """Build the model-level dict of fields the UI owns, in canonical order.""" + return { + "MJCF_PATH": config.get("xmlPath", ""), + "SCALE_FACTOR": config.get("scaleFactor", 0.9), + "MOCAP_SCALE_FACTOR": config.get("mocapScaleFactor", 0.01), + "KP_NAMES": config.get( + "kpNames", list(config.get("keypointModelPairs", {}).keys()) + ), + "KEYPOINT_MODEL_PAIRS": config.get("keypointModelPairs", {}), + "KEYPOINT_INITIAL_OFFSETS": _offsets_to_yaml( + config.get("keypointInitialOffsets", {}) + ), + } + + def load_stac_yaml(path: str) -> dict: - """Load STAC config YAML and return normalized dict for the UI.""" + """Load STAC config YAML and return normalized dict for the UI. + + Returns: + Dict with UI-normalized fields (keypointModelPairs, keypointInitialOffsets, + scaleFactor, mocapScaleFactor, kpNames, xmlPath) plus `_rawTemplate`: + the full parsed YAML, for template-overlay export. + """ with open(path) as f: raw = yaml.safe_load(f) or {} model = _extract_model_section(raw) @@ -46,33 +89,82 @@ def load_stac_yaml(path: str) -> dict: "mocapScaleFactor": float(model.get("MOCAP_SCALE_FACTOR", 0.01)), "kpNames": list(model.get("KP_NAMES", [])), "xmlPath": model.get("MJCF_PATH", ""), + "_rawTemplate": raw, } +def _is_empty(v) -> bool: + """Treat None and empty containers/strings as 'no UI data to contribute'.""" + if v is None: + return True + if isinstance(v, (list, dict, str)): + return len(v) == 0 + return False + + +def _overlay_onto_template(template: dict, ui_fields: dict) -> dict: + """Overlay UI-managed fields onto a template, preserving its shape. + + - Flat template → overlay at top level, preserving key order (UI fields + replace existing keys in place; new keys appended). + - Wrapped template → overlay under raw["model"]. + - UI-only sections like `skeleton_editor` are stripped. + - Empty UI values (e.g. KP_NAMES=[] when no keypoints were loaded) do not + clobber a populated template field — otherwise exporting without + loading mocap would wipe the template's keypoint list. + """ + out = copy.deepcopy(template) + out.pop("skeleton_editor", None) + + target = out if _is_flat(out) else out.setdefault("model", {}) + + for field in _UI_MANAGED_FIELDS: + value = ui_fields[field] + if _is_empty(value) and not _is_empty(target.get(field)): + continue + target[field] = value + return out + + def dump_stac_yaml(config: dict) -> str: - """Serialize UI state to STAC-compatible YAML and return it as a string.""" - offsets_str = {} - for kp, vals in config.get("keypointInitialOffsets", {}).items(): - offsets_str[kp] = f"{vals[0]} {vals[1]} {vals[2]}" - yaml_dict = { - "model": { - "MJCF_PATH": config.get("xmlPath", ""), - "SCALE_FACTOR": config.get("scaleFactor", 0.9), - "MOCAP_SCALE_FACTOR": config.get("mocapScaleFactor", 0.01), - "KP_NAMES": config.get("kpNames", list(config.get("keypointModelPairs", {}).keys())), - "KEYPOINT_MODEL_PAIRS": config.get("keypointModelPairs", {}), - "KEYPOINT_INITIAL_OFFSETS": offsets_str, - }, - } - # Include segment scales if any are non-default - segment_scales = config.get("segmentScales", {}) - if segment_scales: - non_default = {k: v for k, v in segment_scales.items() if abs(v - 1.0) > 0.001} - if non_default: - yaml_dict["skeleton_editor"] = {"segment_scales": non_default} + """Serialize UI state to STAC-compatible YAML and return it as a string. + + If `config` carries `_rawTemplate` (from a prior `load_stac_yaml`), overlay + the UI's edits onto it so fields the UI doesn't manage (N_ITERS, + ROOT_OPTIMIZATION_KEYPOINT, SITES_TO_REGULARIZE, ...) are preserved. + + Without a template, emit a UI-wrapped shape (nested under `model:`). That + shape is the UI's internal round-trip format and is NOT a drop-in + stac-mjx config — use template-overlay for that. + """ + ui_fields = _ui_managed_fields(config) + template = config.get("_rawTemplate") + if template: + yaml_dict = _overlay_onto_template(template, ui_fields) + else: + yaml_dict = {"model": dict(ui_fields)} return yaml.dump(yaml_dict, default_flow_style=False, sort_keys=False) +def dump_stac_ui_sidecar(config: dict) -> str | None: + """Serialize UI-only state (skeleton editor) to its own YAML. + + Returns None when there's nothing to save — the caller should skip the + sidecar download in that case rather than emitting an empty file. + """ + segment_scales = config.get("segmentScales", {}) + non_default = { + k: v for k, v in segment_scales.items() if abs(v - 1.0) > 0.001 + } + if not non_default: + return None + return yaml.dump( + {"skeleton_editor": {"segment_scales": non_default}}, + default_flow_style=False, + sort_keys=False, + ) + + def export_stac_yaml(config: dict, output_path: str) -> None: """Export UI state to a STAC-compatible YAML file on disk.""" with open(output_path, "w") as f: diff --git a/frontend/src/api.ts b/frontend/src/api.ts index ca0cfa0..a097f53 100644 --- a/frontend/src/api.ts +++ b/frontend/src/api.ts @@ -83,6 +83,25 @@ export async function exportConfig(config: Record): Promise): Promise { + const resp = await fetch(`${BASE}/api/export-ui-sidecar`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ config }), + }); + if (resp.status === 204) return null; + if (!resp.ok) { + let msg = `HTTP ${resp.status}`; + try { + const err = await resp.json(); + if (err?.error) msg = err.error; + } catch { /* not JSON */ } + throw new Error(msg); + } + return resp.text(); +} + export async function alignToMujoco(data: Record) { const resp = await fetch(`${BASE}/api/align`, { method: "POST", diff --git a/frontend/src/components/Toolbar.tsx b/frontend/src/components/Toolbar.tsx index 185c3ff..455c41e 100644 --- a/frontend/src/components/Toolbar.tsx +++ b/frontend/src/components/Toolbar.tsx @@ -15,6 +15,19 @@ function pickFile(accept: string): Promise { }); } +/** Trigger a browser download for a YAML document. */ +function downloadYaml(body: string, filename: string) { + const blob = new Blob([body], { type: "application/x-yaml" }); + const url = URL.createObjectURL(blob); + const a = document.createElement("a"); + a.href = url; + a.download = filename; + document.body.appendChild(a); + a.click(); + a.remove(); + URL.revokeObjectURL(url); +} + export default function Toolbar() { const setXmlData = useStore((s) => s.setXmlData); const setAcmData = useStore((s) => s.setAcmData); @@ -91,7 +104,7 @@ export default function Toolbar() { for (const m of state.mappings) pairs[m.keypointName] = m.bodyName; const offsetMap: Record = {}; for (const o of state.offsets) offsetMap[o.keypointName] = [o.x, o.y, o.z]; - const config = { + const config: Record = { keypointModelPairs: pairs, keypointInitialOffsets: offsetMap, scaleFactor: state.scaleFactor, @@ -100,24 +113,28 @@ export default function Toolbar() { kpNames: state.acmKeypointNames, segmentScales: state.segmentScales, }; - let yamlBody: string; + if (state.rawTemplate) config._rawTemplate = state.rawTemplate; + + let mainBody: string; + let sidecarBody: string | null; try { - yamlBody = await api.exportConfig(config); + [mainBody, sidecarBody] = await Promise.all([ + api.exportConfig(config), + api.exportUiSidecar(config), + ]); } catch (e) { setIkStatus("Export error: " + (e as Error).message); return; } - // Trigger a browser download — no server-side filesystem write involved. - const blob = new Blob([yamlBody], { type: "application/x-yaml" }); - const url = URL.createObjectURL(blob); - const a = document.createElement("a"); - a.href = url; - a.download = "stac_retarget_config.yaml"; - document.body.appendChild(a); - a.click(); - a.remove(); - URL.revokeObjectURL(url); - setIkStatus("Config downloaded."); + downloadYaml(mainBody, "stac_retarget_config.yaml"); + if (sidecarBody) { + downloadYaml(sidecarBody, "stac_retarget_config.ui.yaml"); + } + setIkStatus( + sidecarBody + ? "Config + UI sidecar downloaded." + : "Config downloaded." + ); }, [setIkStatus]); const handleLoadStacOutput = useCallback(async () => { diff --git a/frontend/src/store.ts b/frontend/src/store.ts index ec9fda4..8d8238d 100644 --- a/frontend/src/store.ts +++ b/frontend/src/store.ts @@ -37,6 +37,10 @@ interface AppState { scaleFactor: number; mocapScaleFactor: number; + // Raw template from a loaded stac-mjx config — used on export to preserve + // fields the UI doesn't manage (N_ITERS, ROOT_OPTIMIZATION_KEYPOINT, ...). + rawTemplate: Record | null; + // Interaction mode: InteractionMode; selectedKeypoint: string | null; @@ -132,6 +136,7 @@ interface AppState { keypointInitialOffsets: Record; scaleFactor: number; mocapScaleFactor: number; + _rawTemplate?: Record; }) => void; } @@ -153,6 +158,7 @@ export const useStore = create()(persist((set) => ({ offsets: [], scaleFactor: 0.9, mocapScaleFactor: 0.01, + rawTemplate: null, mode: "mapping", selectedKeypoint: null, selectedBody: null, @@ -258,6 +264,7 @@ export const useStore = create()(persist((set) => ({ offsets: Object.entries(config.keypointInitialOffsets).map(([kp, [x, y, z]]) => ({ keypointName: kp, x, y, z })), scaleFactor: config.scaleFactor, mocapScaleFactor: config.mocapScaleFactor, + rawTemplate: config._rawTemplate ?? null, }), }), { name: "stac-retarget-ui-state", @@ -268,6 +275,7 @@ export const useStore = create()(persist((set) => ({ mappings: state.mappings, offsets: state.offsets, segmentScales: state.segmentScales, + rawTemplate: state.rawTemplate, // Model transform modelRotationY: state.modelRotationY, modelPosition: state.modelPosition, diff --git a/tests/test_config_io.py b/tests/test_config_io.py index 647ef78..3def0bc 100644 --- a/tests/test_config_io.py +++ b/tests/test_config_io.py @@ -4,7 +4,13 @@ from pathlib import Path import pytest -from backend.config_io import load_stac_yaml, export_stac_yaml +import yaml +from backend.config_io import ( + dump_stac_ui_sidecar, + dump_stac_yaml, + export_stac_yaml, + load_stac_yaml, +) YAML_PATH = os.environ.get( "STAC_KEYPOINTS_CONFIG", @@ -70,3 +76,208 @@ def test_load_stac_yaml_flat_format(tmp_path): assert result["mocapScaleFactor"] == 0.001 spine = result["keypointInitialOffsets"]["SpineF"] assert abs(spine[0] - -0.015) < 1e-6 + + +def test_load_includes_raw_template(tmp_path): + """load_stac_yaml returns the full parsed YAML as _rawTemplate.""" + src = textwrap.dedent( + """ + MJCF_PATH: "models/rodent.xml" + N_ITERS: 6 + SITES_TO_REGULARIZE: + - HandL + - HandR + KEYPOINT_MODEL_PAIRS: + Snout: skull + """ + ) + path = tmp_path / "with_extras.yaml" + path.write_text(src) + + result = load_stac_yaml(str(path)) + raw = result["_rawTemplate"] + assert raw["N_ITERS"] == 6 + assert raw["SITES_TO_REGULARIZE"] == ["HandL", "HandR"] + assert raw["KEYPOINT_MODEL_PAIRS"]["Snout"] == "skull" + + +def test_dump_with_flat_template_preserves_other_fields(tmp_path): + """Template-overlay export keeps non-UI fields (N_ITERS, SITES_TO_REGULARIZE, ...).""" + src = textwrap.dedent( + """ + MJCF_PATH: "models/rodent.xml" + N_ITERS: 6 + N_ITER_Q: 400 + ROOT_OPTIMIZATION_KEYPOINT: SpineL + SITES_TO_REGULARIZE: + - HandL + - HandR + KEYPOINT_MODEL_PAIRS: + Snout: skull + SpineF: vertebra_cervical_5 + KEYPOINT_INITIAL_OFFSETS: + Snout: 0. 0. 0. + SpineF: -0.015 0. 0.0 + KP_NAMES: [Snout, SpineF] + SCALE_FACTOR: 0.9 + MOCAP_SCALE_FACTOR: 0.001 + """ + ) + path = tmp_path / "flat.yaml" + path.write_text(src) + loaded = load_stac_yaml(str(path)) + + # Simulate a UI edit: reassign Snout to a different body, change scale. + loaded["keypointModelPairs"]["Snout"] = "head" + loaded["scaleFactor"] = 1.1 + + out_yaml = dump_stac_yaml(loaded) + out = yaml.safe_load(out_yaml) + + # UI edits applied + assert out["KEYPOINT_MODEL_PAIRS"]["Snout"] == "head" + assert out["SCALE_FACTOR"] == 1.1 + # Preserved from template + assert out["N_ITERS"] == 6 + assert out["N_ITER_Q"] == 400 + assert out["ROOT_OPTIMIZATION_KEYPOINT"] == "SpineL" + assert out["SITES_TO_REGULARIZE"] == ["HandL", "HandR"] + # Shape preserved: flat (not wrapped under `model:`) + assert "model" not in out + + +def test_dump_with_wrapped_template_preserves_shape(tmp_path): + """Wrapped UI-format templates round-trip as wrapped.""" + src = textwrap.dedent( + """ + model: + MJCF_PATH: models/rodent.xml + N_ITERS: 6 + KEYPOINT_MODEL_PAIRS: + Snout: skull + KEYPOINT_INITIAL_OFFSETS: + Snout: 0. 0. 0. + KP_NAMES: [Snout] + SCALE_FACTOR: 0.9 + MOCAP_SCALE_FACTOR: 0.001 + """ + ) + path = tmp_path / "wrapped.yaml" + path.write_text(src) + loaded = load_stac_yaml(str(path)) + + out = yaml.safe_load(dump_stac_yaml(loaded)) + assert "model" in out and "MJCF_PATH" in out["model"] + assert out["model"]["N_ITERS"] == 6 + + +def test_dump_strips_skeleton_editor_from_main_export(tmp_path): + """`skeleton_editor:` never leaks into the main stac-mjx export.""" + src = textwrap.dedent( + """ + MJCF_PATH: "models/rodent.xml" + N_ITERS: 6 + KEYPOINT_MODEL_PAIRS: {Snout: skull} + KEYPOINT_INITIAL_OFFSETS: {Snout: 0. 0. 0.} + KP_NAMES: [Snout] + SCALE_FACTOR: 0.9 + MOCAP_SCALE_FACTOR: 0.001 + skeleton_editor: + segment_scales: + 'SpineF->SpineM': 1.05 + """ + ) + path = tmp_path / "with_ui.yaml" + path.write_text(src) + loaded = load_stac_yaml(str(path)) + + out = yaml.safe_load(dump_stac_yaml(loaded)) + assert "skeleton_editor" not in out + + +def test_dump_without_template_emits_wrapped(): + """No template → UI-internal wrapped format (not a valid stac-mjx config).""" + config = { + "keypointModelPairs": {"Snout": "skull"}, + "keypointInitialOffsets": {"Snout": [0.0, 0.0, 0.0]}, + "scaleFactor": 0.9, + "mocapScaleFactor": 0.01, + "kpNames": ["Snout"], + "xmlPath": "models/rodent.xml", + } + out = yaml.safe_load(dump_stac_yaml(config)) + assert "model" in out + assert out["model"]["KEYPOINT_MODEL_PAIRS"]["Snout"] == "skull" + + +def test_dump_empty_ui_field_does_not_clobber_template(tmp_path): + """Exporting without loaded mocap must not wipe the template's KP_NAMES. + + Reproduces a bug where loading a config without any keypoint data loaded + yielded KP_NAMES=[] on export, overwriting the template's populated list. + """ + src = textwrap.dedent( + """ + MJCF_PATH: "models/rodent.xml" + N_ITERS: 6 + KEYPOINT_MODEL_PAIRS: + Snout: skull + SpineF: vertebra_cervical_5 + KEYPOINT_INITIAL_OFFSETS: + Snout: 0. 0. 0. + SpineF: 0. 0. 0. + KP_NAMES: + - Snout + - SpineF + - SpineM + SCALE_FACTOR: 0.9 + MOCAP_SCALE_FACTOR: 0.001 + """ + ) + path = tmp_path / "with_kp_names.yaml" + path.write_text(src) + loaded = load_stac_yaml(str(path)) + + # Simulate the UI state right after loading config but before loading any + # mocap data: kpNames/keypointModelPairs still present from the template, + # but if the user cleared them (or they were never populated in state), + # we must not clobber what the template already has. + loaded["kpNames"] = [] # UI hasn't loaded mocap → empty + out = yaml.safe_load(dump_stac_yaml(loaded)) + assert out["KP_NAMES"] == ["Snout", "SpineF", "SpineM"] + + +def test_dump_empty_field_when_template_also_empty(): + """If the template has no value either, emit whatever the UI has (incl. empty).""" + config = { + "keypointModelPairs": {"Snout": "skull"}, + "keypointInitialOffsets": {}, + "kpNames": [], + "scaleFactor": 0.9, + "mocapScaleFactor": 0.01, + "xmlPath": "models/rodent.xml", + "_rawTemplate": { + "MJCF_PATH": "models/rodent.xml", + "N_ITERS": 6, + }, + } + out = yaml.safe_load(dump_stac_yaml(config)) + # UI field overrides are applied when template has nothing to preserve. + assert out["KEYPOINT_MODEL_PAIRS"] == {"Snout": "skull"} + assert out["N_ITERS"] == 6 + + +def test_dump_ui_sidecar_none_when_default(): + """Sidecar returns None when there's no UI-only state worth saving.""" + assert dump_stac_ui_sidecar({"segmentScales": {}}) is None + assert dump_stac_ui_sidecar({"segmentScales": {"SpineF->SpineM": 1.0}}) is None + + +def test_dump_ui_sidecar_emits_non_default_scales(): + body = dump_stac_ui_sidecar( + {"segmentScales": {"SpineF->SpineM": 1.05, "HipL->KneeL": 1.0}} + ) + assert body is not None + parsed = yaml.safe_load(body) + # Non-default kept, default dropped. + assert parsed["skeleton_editor"]["segment_scales"] == {"SpineF->SpineM": 1.05}