Skip to content

Commit a327734

Browse files
authored
Merge pull request #3 from HugoFara/feat-template-overlay-export
feat(config_io): template-overlay export + UI sidecar
2 parents c7cdc91 + 808e880 commit a327734

6 files changed

Lines changed: 406 additions & 38 deletions

File tree

backend/app.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,12 @@
1717
from backend.mujoco_utils import compute_body_transforms, extract_model_geometry
1818
from backend.acm_processing import load_acm_trials, load_single_matfile, apply_retargeting
1919
from backend.alignment import align_acm_to_mujoco
20-
from backend.config_io import load_stac_yaml, dump_stac_yaml, load_stac_output_h5
20+
from backend.config_io import (
21+
load_stac_yaml,
22+
dump_stac_yaml,
23+
dump_stac_ui_sidecar,
24+
load_stac_output_h5,
25+
)
2126
from backend.frame_selector import suggest_frames
2227
from backend.stac_runner import run_quick_stac
2328

@@ -194,6 +199,22 @@ async def export_config(data: dict):
194199
return PlainTextResponse(body, media_type="application/x-yaml")
195200

196201

202+
@app.post("/api/export-ui-sidecar")
203+
async def export_ui_sidecar(data: dict):
204+
"""Serialize UI-only state (skeleton editor, ...) as a sidecar YAML.
205+
206+
Returns 204 when there's nothing to save, so the frontend can skip the
207+
download.
208+
"""
209+
try:
210+
body = dump_stac_ui_sidecar(data["config"])
211+
except Exception as e:
212+
return JSONResponse({"error": str(e)}, status_code=500)
213+
if body is None:
214+
return PlainTextResponse("", status_code=204)
215+
return PlainTextResponse(body, media_type="application/x-yaml")
216+
217+
197218
@app.post("/api/load-stac-output")
198219
async def load_stac_output(file: UploadFile = File(None), path: str = Query(None)):
199220
"""Load STAC output H5 from a server-side path or an uploaded file."""

backend/config_io.py

Lines changed: 114 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""STAC YAML config and H5 import/export."""
22
from __future__ import annotations
3+
import copy
34
from pathlib import Path
45
import yaml
56
import numpy as np
@@ -10,6 +11,21 @@
1011
# (e.g. configs/model/rodent.yaml). Used to detect flat vs. wrapped shapes.
1112
_MODEL_FIELD_MARKERS = ("KEYPOINT_MODEL_PAIRS", "KP_NAMES", "MJCF_PATH")
1213

14+
# Fields that the UI owns and will overwrite on export.
15+
_UI_MANAGED_FIELDS = (
16+
"MJCF_PATH",
17+
"SCALE_FACTOR",
18+
"MOCAP_SCALE_FACTOR",
19+
"KP_NAMES",
20+
"KEYPOINT_MODEL_PAIRS",
21+
"KEYPOINT_INITIAL_OFFSETS",
22+
)
23+
24+
25+
def _is_flat(raw: dict) -> bool:
26+
"""True if `raw` looks like a flat stac-mjx model config."""
27+
return any(k in raw for k in _MODEL_FIELD_MARKERS)
28+
1329

1430
def _extract_model_section(raw: dict) -> dict:
1531
"""Return the dict containing model-level fields from a loaded YAML.
@@ -20,13 +36,40 @@ def _extract_model_section(raw: dict) -> dict:
2036
the file into the `model` namespace during composition.
2137
- Wrapped: the UI's own export, where everything is nested under `model:`.
2238
"""
23-
if any(k in raw for k in _MODEL_FIELD_MARKERS):
39+
if _is_flat(raw):
2440
return raw
2541
return raw.get("model", {})
2642

2743

44+
def _offsets_to_yaml(offsets: dict) -> dict:
45+
"""Convert [x, y, z] offsets to space-separated strings (stac-mjx format)."""
46+
return {kp: f"{v[0]} {v[1]} {v[2]}" for kp, v in offsets.items()}
47+
48+
49+
def _ui_managed_fields(config: dict) -> dict:
50+
"""Build the model-level dict of fields the UI owns, in canonical order."""
51+
return {
52+
"MJCF_PATH": config.get("xmlPath", ""),
53+
"SCALE_FACTOR": config.get("scaleFactor", 0.9),
54+
"MOCAP_SCALE_FACTOR": config.get("mocapScaleFactor", 0.01),
55+
"KP_NAMES": config.get(
56+
"kpNames", list(config.get("keypointModelPairs", {}).keys())
57+
),
58+
"KEYPOINT_MODEL_PAIRS": config.get("keypointModelPairs", {}),
59+
"KEYPOINT_INITIAL_OFFSETS": _offsets_to_yaml(
60+
config.get("keypointInitialOffsets", {})
61+
),
62+
}
63+
64+
2865
def load_stac_yaml(path: str) -> dict:
29-
"""Load STAC config YAML and return normalized dict for the UI."""
66+
"""Load STAC config YAML and return normalized dict for the UI.
67+
68+
Returns:
69+
Dict with UI-normalized fields (keypointModelPairs, keypointInitialOffsets,
70+
scaleFactor, mocapScaleFactor, kpNames, xmlPath) plus `_rawTemplate`:
71+
the full parsed YAML, for template-overlay export.
72+
"""
3073
with open(path) as f:
3174
raw = yaml.safe_load(f) or {}
3275
model = _extract_model_section(raw)
@@ -46,33 +89,82 @@ def load_stac_yaml(path: str) -> dict:
4689
"mocapScaleFactor": float(model.get("MOCAP_SCALE_FACTOR", 0.01)),
4790
"kpNames": list(model.get("KP_NAMES", [])),
4891
"xmlPath": model.get("MJCF_PATH", ""),
92+
"_rawTemplate": raw,
4993
}
5094

5195

96+
def _is_empty(v) -> bool:
97+
"""Treat None and empty containers/strings as 'no UI data to contribute'."""
98+
if v is None:
99+
return True
100+
if isinstance(v, (list, dict, str)):
101+
return len(v) == 0
102+
return False
103+
104+
105+
def _overlay_onto_template(template: dict, ui_fields: dict) -> dict:
106+
"""Overlay UI-managed fields onto a template, preserving its shape.
107+
108+
- Flat template → overlay at top level, preserving key order (UI fields
109+
replace existing keys in place; new keys appended).
110+
- Wrapped template → overlay under raw["model"].
111+
- UI-only sections like `skeleton_editor` are stripped.
112+
- Empty UI values (e.g. KP_NAMES=[] when no keypoints were loaded) do not
113+
clobber a populated template field — otherwise exporting without
114+
loading mocap would wipe the template's keypoint list.
115+
"""
116+
out = copy.deepcopy(template)
117+
out.pop("skeleton_editor", None)
118+
119+
target = out if _is_flat(out) else out.setdefault("model", {})
120+
121+
for field in _UI_MANAGED_FIELDS:
122+
value = ui_fields[field]
123+
if _is_empty(value) and not _is_empty(target.get(field)):
124+
continue
125+
target[field] = value
126+
return out
127+
128+
52129
def dump_stac_yaml(config: dict) -> str:
53-
"""Serialize UI state to STAC-compatible YAML and return it as a string."""
54-
offsets_str = {}
55-
for kp, vals in config.get("keypointInitialOffsets", {}).items():
56-
offsets_str[kp] = f"{vals[0]} {vals[1]} {vals[2]}"
57-
yaml_dict = {
58-
"model": {
59-
"MJCF_PATH": config.get("xmlPath", ""),
60-
"SCALE_FACTOR": config.get("scaleFactor", 0.9),
61-
"MOCAP_SCALE_FACTOR": config.get("mocapScaleFactor", 0.01),
62-
"KP_NAMES": config.get("kpNames", list(config.get("keypointModelPairs", {}).keys())),
63-
"KEYPOINT_MODEL_PAIRS": config.get("keypointModelPairs", {}),
64-
"KEYPOINT_INITIAL_OFFSETS": offsets_str,
65-
},
66-
}
67-
# Include segment scales if any are non-default
68-
segment_scales = config.get("segmentScales", {})
69-
if segment_scales:
70-
non_default = {k: v for k, v in segment_scales.items() if abs(v - 1.0) > 0.001}
71-
if non_default:
72-
yaml_dict["skeleton_editor"] = {"segment_scales": non_default}
130+
"""Serialize UI state to STAC-compatible YAML and return it as a string.
131+
132+
If `config` carries `_rawTemplate` (from a prior `load_stac_yaml`), overlay
133+
the UI's edits onto it so fields the UI doesn't manage (N_ITERS,
134+
ROOT_OPTIMIZATION_KEYPOINT, SITES_TO_REGULARIZE, ...) are preserved.
135+
136+
Without a template, emit a UI-wrapped shape (nested under `model:`). That
137+
shape is the UI's internal round-trip format and is NOT a drop-in
138+
stac-mjx config — use template-overlay for that.
139+
"""
140+
ui_fields = _ui_managed_fields(config)
141+
template = config.get("_rawTemplate")
142+
if template:
143+
yaml_dict = _overlay_onto_template(template, ui_fields)
144+
else:
145+
yaml_dict = {"model": dict(ui_fields)}
73146
return yaml.dump(yaml_dict, default_flow_style=False, sort_keys=False)
74147

75148

149+
def dump_stac_ui_sidecar(config: dict) -> str | None:
150+
"""Serialize UI-only state (skeleton editor) to its own YAML.
151+
152+
Returns None when there's nothing to save — the caller should skip the
153+
sidecar download in that case rather than emitting an empty file.
154+
"""
155+
segment_scales = config.get("segmentScales", {})
156+
non_default = {
157+
k: v for k, v in segment_scales.items() if abs(v - 1.0) > 0.001
158+
}
159+
if not non_default:
160+
return None
161+
return yaml.dump(
162+
{"skeleton_editor": {"segment_scales": non_default}},
163+
default_flow_style=False,
164+
sort_keys=False,
165+
)
166+
167+
76168
def export_stac_yaml(config: dict, output_path: str) -> None:
77169
"""Export UI state to a STAC-compatible YAML file on disk."""
78170
with open(output_path, "w") as f:

frontend/src/api.ts

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,25 @@ export async function exportConfig(config: Record<string, unknown>): Promise<str
8383
return resp.text();
8484
}
8585

86+
/** UI-only sidecar (skeleton editor, ...). Returns null when there's nothing to save. */
87+
export async function exportUiSidecar(config: Record<string, unknown>): Promise<string | null> {
88+
const resp = await fetch(`${BASE}/api/export-ui-sidecar`, {
89+
method: "POST",
90+
headers: { "Content-Type": "application/json" },
91+
body: JSON.stringify({ config }),
92+
});
93+
if (resp.status === 204) return null;
94+
if (!resp.ok) {
95+
let msg = `HTTP ${resp.status}`;
96+
try {
97+
const err = await resp.json();
98+
if (err?.error) msg = err.error;
99+
} catch { /* not JSON */ }
100+
throw new Error(msg);
101+
}
102+
return resp.text();
103+
}
104+
86105
export async function alignToMujoco(data: Record<string, unknown>) {
87106
const resp = await fetch(`${BASE}/api/align`, {
88107
method: "POST",

frontend/src/components/Toolbar.tsx

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,19 @@ function pickFile(accept: string): Promise<File | null> {
1515
});
1616
}
1717

18+
/** Trigger a browser download for a YAML document. */
19+
function downloadYaml(body: string, filename: string) {
20+
const blob = new Blob([body], { type: "application/x-yaml" });
21+
const url = URL.createObjectURL(blob);
22+
const a = document.createElement("a");
23+
a.href = url;
24+
a.download = filename;
25+
document.body.appendChild(a);
26+
a.click();
27+
a.remove();
28+
URL.revokeObjectURL(url);
29+
}
30+
1831
export default function Toolbar() {
1932
const setXmlData = useStore((s) => s.setXmlData);
2033
const setAcmData = useStore((s) => s.setAcmData);
@@ -91,7 +104,7 @@ export default function Toolbar() {
91104
for (const m of state.mappings) pairs[m.keypointName] = m.bodyName;
92105
const offsetMap: Record<string, [number, number, number]> = {};
93106
for (const o of state.offsets) offsetMap[o.keypointName] = [o.x, o.y, o.z];
94-
const config = {
107+
const config: Record<string, unknown> = {
95108
keypointModelPairs: pairs,
96109
keypointInitialOffsets: offsetMap,
97110
scaleFactor: state.scaleFactor,
@@ -100,24 +113,28 @@ export default function Toolbar() {
100113
kpNames: state.acmKeypointNames,
101114
segmentScales: state.segmentScales,
102115
};
103-
let yamlBody: string;
116+
if (state.rawTemplate) config._rawTemplate = state.rawTemplate;
117+
118+
let mainBody: string;
119+
let sidecarBody: string | null;
104120
try {
105-
yamlBody = await api.exportConfig(config);
121+
[mainBody, sidecarBody] = await Promise.all([
122+
api.exportConfig(config),
123+
api.exportUiSidecar(config),
124+
]);
106125
} catch (e) {
107126
setIkStatus("Export error: " + (e as Error).message);
108127
return;
109128
}
110-
// Trigger a browser download — no server-side filesystem write involved.
111-
const blob = new Blob([yamlBody], { type: "application/x-yaml" });
112-
const url = URL.createObjectURL(blob);
113-
const a = document.createElement("a");
114-
a.href = url;
115-
a.download = "stac_retarget_config.yaml";
116-
document.body.appendChild(a);
117-
a.click();
118-
a.remove();
119-
URL.revokeObjectURL(url);
120-
setIkStatus("Config downloaded.");
129+
downloadYaml(mainBody, "stac_retarget_config.yaml");
130+
if (sidecarBody) {
131+
downloadYaml(sidecarBody, "stac_retarget_config.ui.yaml");
132+
}
133+
setIkStatus(
134+
sidecarBody
135+
? "Config + UI sidecar downloaded."
136+
: "Config downloaded."
137+
);
121138
}, [setIkStatus]);
122139

123140
const handleLoadStacOutput = useCallback(async () => {

frontend/src/store.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ interface AppState {
3737
scaleFactor: number;
3838
mocapScaleFactor: number;
3939

40+
// Raw template from a loaded stac-mjx config — used on export to preserve
41+
// fields the UI doesn't manage (N_ITERS, ROOT_OPTIMIZATION_KEYPOINT, ...).
42+
rawTemplate: Record<string, unknown> | null;
43+
4044
// Interaction
4145
mode: InteractionMode;
4246
selectedKeypoint: string | null;
@@ -132,6 +136,7 @@ interface AppState {
132136
keypointInitialOffsets: Record<string, [number, number, number]>;
133137
scaleFactor: number;
134138
mocapScaleFactor: number;
139+
_rawTemplate?: Record<string, unknown>;
135140
}) => void;
136141
}
137142

@@ -153,6 +158,7 @@ export const useStore = create<AppState>()(persist((set) => ({
153158
offsets: [],
154159
scaleFactor: 0.9,
155160
mocapScaleFactor: 0.01,
161+
rawTemplate: null,
156162
mode: "mapping",
157163
selectedKeypoint: null,
158164
selectedBody: null,
@@ -258,6 +264,7 @@ export const useStore = create<AppState>()(persist((set) => ({
258264
offsets: Object.entries(config.keypointInitialOffsets).map(([kp, [x, y, z]]) => ({ keypointName: kp, x, y, z })),
259265
scaleFactor: config.scaleFactor,
260266
mocapScaleFactor: config.mocapScaleFactor,
267+
rawTemplate: config._rawTemplate ?? null,
261268
}),
262269
}), {
263270
name: "stac-retarget-ui-state",
@@ -268,6 +275,7 @@ export const useStore = create<AppState>()(persist((set) => ({
268275
mappings: state.mappings,
269276
offsets: state.offsets,
270277
segmentScales: state.segmentScales,
278+
rawTemplate: state.rawTemplate,
271279
// Model transform
272280
modelRotationY: state.modelRotationY,
273281
modelPosition: state.modelPosition,

0 commit comments

Comments
 (0)