Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion backend/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@
from backend.mujoco_utils import compute_body_transforms, extract_model_geometry
from backend.acm_processing import load_acm_trials, load_single_matfile, apply_retargeting
from backend.alignment import align_acm_to_mujoco
from backend.config_io import load_stac_yaml, dump_stac_yaml, load_stac_output_h5
from backend.config_io import (
load_stac_yaml,
dump_stac_yaml,
dump_stac_ui_sidecar,
load_stac_output_h5,
)
from backend.frame_selector import suggest_frames
from backend.stac_runner import run_quick_stac

Expand Down Expand Up @@ -194,6 +199,22 @@ async def export_config(data: dict):
return PlainTextResponse(body, media_type="application/x-yaml")


@app.post("/api/export-ui-sidecar")
async def export_ui_sidecar(data: dict):
"""Serialize UI-only state (skeleton editor, ...) as a sidecar YAML.

Returns 204 when there's nothing to save, so the frontend can skip the
download.
"""
try:
body = dump_stac_ui_sidecar(data["config"])
except Exception as e:
return JSONResponse({"error": str(e)}, status_code=500)
if body is None:
return PlainTextResponse("", status_code=204)
return PlainTextResponse(body, media_type="application/x-yaml")


@app.post("/api/load-stac-output")
async def load_stac_output(file: UploadFile = File(None), path: str = Query(None)):
"""Load STAC output H5 from a server-side path or an uploaded file."""
Expand Down
136 changes: 114 additions & 22 deletions backend/config_io.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""STAC YAML config and H5 import/export."""
from __future__ import annotations
import copy
from pathlib import Path
import yaml
import numpy as np
Expand All @@ -10,6 +11,21 @@
# (e.g. configs/model/rodent.yaml). Used to detect flat vs. wrapped shapes.
_MODEL_FIELD_MARKERS = ("KEYPOINT_MODEL_PAIRS", "KP_NAMES", "MJCF_PATH")

# Fields that the UI owns and will overwrite on export.
_UI_MANAGED_FIELDS = (
"MJCF_PATH",
"SCALE_FACTOR",
"MOCAP_SCALE_FACTOR",
"KP_NAMES",
"KEYPOINT_MODEL_PAIRS",
"KEYPOINT_INITIAL_OFFSETS",
)


def _is_flat(raw: dict) -> bool:
"""True if `raw` looks like a flat stac-mjx model config."""
return any(k in raw for k in _MODEL_FIELD_MARKERS)


def _extract_model_section(raw: dict) -> dict:
"""Return the dict containing model-level fields from a loaded YAML.
Expand All @@ -20,13 +36,40 @@ def _extract_model_section(raw: dict) -> dict:
the file into the `model` namespace during composition.
- Wrapped: the UI's own export, where everything is nested under `model:`.
"""
if any(k in raw for k in _MODEL_FIELD_MARKERS):
if _is_flat(raw):
return raw
return raw.get("model", {})


def _offsets_to_yaml(offsets: dict) -> dict:
"""Convert [x, y, z] offsets to space-separated strings (stac-mjx format)."""
return {kp: f"{v[0]} {v[1]} {v[2]}" for kp, v in offsets.items()}


def _ui_managed_fields(config: dict) -> dict:
"""Build the model-level dict of fields the UI owns, in canonical order."""
return {
"MJCF_PATH": config.get("xmlPath", ""),
"SCALE_FACTOR": config.get("scaleFactor", 0.9),
"MOCAP_SCALE_FACTOR": config.get("mocapScaleFactor", 0.01),
"KP_NAMES": config.get(
"kpNames", list(config.get("keypointModelPairs", {}).keys())
),
"KEYPOINT_MODEL_PAIRS": config.get("keypointModelPairs", {}),
"KEYPOINT_INITIAL_OFFSETS": _offsets_to_yaml(
config.get("keypointInitialOffsets", {})
),
}


def load_stac_yaml(path: str) -> dict:
"""Load STAC config YAML and return normalized dict for the UI."""
"""Load STAC config YAML and return normalized dict for the UI.

Returns:
Dict with UI-normalized fields (keypointModelPairs, keypointInitialOffsets,
scaleFactor, mocapScaleFactor, kpNames, xmlPath) plus `_rawTemplate`:
the full parsed YAML, for template-overlay export.
"""
with open(path) as f:
raw = yaml.safe_load(f) or {}
model = _extract_model_section(raw)
Expand All @@ -46,33 +89,82 @@ def load_stac_yaml(path: str) -> dict:
"mocapScaleFactor": float(model.get("MOCAP_SCALE_FACTOR", 0.01)),
"kpNames": list(model.get("KP_NAMES", [])),
"xmlPath": model.get("MJCF_PATH", ""),
"_rawTemplate": raw,
}


def _is_empty(v) -> bool:
"""Treat None and empty containers/strings as 'no UI data to contribute'."""
if v is None:
return True
if isinstance(v, (list, dict, str)):
return len(v) == 0
return False


def _overlay_onto_template(template: dict, ui_fields: dict) -> dict:
"""Overlay UI-managed fields onto a template, preserving its shape.

- Flat template → overlay at top level, preserving key order (UI fields
replace existing keys in place; new keys appended).
- Wrapped template → overlay under raw["model"].
- UI-only sections like `skeleton_editor` are stripped.
- Empty UI values (e.g. KP_NAMES=[] when no keypoints were loaded) do not
clobber a populated template field — otherwise exporting without
loading mocap would wipe the template's keypoint list.
"""
out = copy.deepcopy(template)
out.pop("skeleton_editor", None)

target = out if _is_flat(out) else out.setdefault("model", {})

for field in _UI_MANAGED_FIELDS:
value = ui_fields[field]
if _is_empty(value) and not _is_empty(target.get(field)):
continue
target[field] = value
return out


def dump_stac_yaml(config: dict) -> str:
"""Serialize UI state to STAC-compatible YAML and return it as a string."""
offsets_str = {}
for kp, vals in config.get("keypointInitialOffsets", {}).items():
offsets_str[kp] = f"{vals[0]} {vals[1]} {vals[2]}"
yaml_dict = {
"model": {
"MJCF_PATH": config.get("xmlPath", ""),
"SCALE_FACTOR": config.get("scaleFactor", 0.9),
"MOCAP_SCALE_FACTOR": config.get("mocapScaleFactor", 0.01),
"KP_NAMES": config.get("kpNames", list(config.get("keypointModelPairs", {}).keys())),
"KEYPOINT_MODEL_PAIRS": config.get("keypointModelPairs", {}),
"KEYPOINT_INITIAL_OFFSETS": offsets_str,
},
}
# Include segment scales if any are non-default
segment_scales = config.get("segmentScales", {})
if segment_scales:
non_default = {k: v for k, v in segment_scales.items() if abs(v - 1.0) > 0.001}
if non_default:
yaml_dict["skeleton_editor"] = {"segment_scales": non_default}
"""Serialize UI state to STAC-compatible YAML and return it as a string.

If `config` carries `_rawTemplate` (from a prior `load_stac_yaml`), overlay
the UI's edits onto it so fields the UI doesn't manage (N_ITERS,
ROOT_OPTIMIZATION_KEYPOINT, SITES_TO_REGULARIZE, ...) are preserved.

Without a template, emit a UI-wrapped shape (nested under `model:`). That
shape is the UI's internal round-trip format and is NOT a drop-in
stac-mjx config — use template-overlay for that.
"""
ui_fields = _ui_managed_fields(config)
template = config.get("_rawTemplate")
if template:
yaml_dict = _overlay_onto_template(template, ui_fields)
else:
yaml_dict = {"model": dict(ui_fields)}
return yaml.dump(yaml_dict, default_flow_style=False, sort_keys=False)


def dump_stac_ui_sidecar(config: dict) -> str | None:
"""Serialize UI-only state (skeleton editor) to its own YAML.

Returns None when there's nothing to save — the caller should skip the
sidecar download in that case rather than emitting an empty file.
"""
segment_scales = config.get("segmentScales", {})
non_default = {
k: v for k, v in segment_scales.items() if abs(v - 1.0) > 0.001
}
if not non_default:
return None
return yaml.dump(
{"skeleton_editor": {"segment_scales": non_default}},
default_flow_style=False,
sort_keys=False,
)


def export_stac_yaml(config: dict, output_path: str) -> None:
"""Export UI state to a STAC-compatible YAML file on disk."""
with open(output_path, "w") as f:
Expand Down
19 changes: 19 additions & 0 deletions frontend/src/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,25 @@ export async function exportConfig(config: Record<string, unknown>): Promise<str
return resp.text();
}

/** UI-only sidecar (skeleton editor, ...). Returns null when there's nothing to save. */
export async function exportUiSidecar(config: Record<string, unknown>): Promise<string | null> {
const resp = await fetch(`${BASE}/api/export-ui-sidecar`, {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ config }),
});
if (resp.status === 204) return null;
if (!resp.ok) {
let msg = `HTTP ${resp.status}`;
try {
const err = await resp.json();
if (err?.error) msg = err.error;
} catch { /* not JSON */ }
throw new Error(msg);
}
return resp.text();
}

export async function alignToMujoco(data: Record<string, unknown>) {
const resp = await fetch(`${BASE}/api/align`, {
method: "POST",
Expand Down
45 changes: 31 additions & 14 deletions frontend/src/components/Toolbar.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,19 @@ function pickFile(accept: string): Promise<File | null> {
});
}

/** Trigger a browser download for a YAML document. */
function downloadYaml(body: string, filename: string) {
const blob = new Blob([body], { type: "application/x-yaml" });
const url = URL.createObjectURL(blob);
const a = document.createElement("a");
a.href = url;
a.download = filename;
document.body.appendChild(a);
a.click();
a.remove();
URL.revokeObjectURL(url);
}

export default function Toolbar() {
const setXmlData = useStore((s) => s.setXmlData);
const setAcmData = useStore((s) => s.setAcmData);
Expand Down Expand Up @@ -91,7 +104,7 @@ export default function Toolbar() {
for (const m of state.mappings) pairs[m.keypointName] = m.bodyName;
const offsetMap: Record<string, [number, number, number]> = {};
for (const o of state.offsets) offsetMap[o.keypointName] = [o.x, o.y, o.z];
const config = {
const config: Record<string, unknown> = {
keypointModelPairs: pairs,
keypointInitialOffsets: offsetMap,
scaleFactor: state.scaleFactor,
Expand All @@ -100,24 +113,28 @@ export default function Toolbar() {
kpNames: state.acmKeypointNames,
segmentScales: state.segmentScales,
};
let yamlBody: string;
if (state.rawTemplate) config._rawTemplate = state.rawTemplate;

let mainBody: string;
let sidecarBody: string | null;
try {
yamlBody = await api.exportConfig(config);
[mainBody, sidecarBody] = await Promise.all([
api.exportConfig(config),
api.exportUiSidecar(config),
]);
} catch (e) {
setIkStatus("Export error: " + (e as Error).message);
return;
}
// Trigger a browser download — no server-side filesystem write involved.
const blob = new Blob([yamlBody], { type: "application/x-yaml" });
const url = URL.createObjectURL(blob);
const a = document.createElement("a");
a.href = url;
a.download = "stac_retarget_config.yaml";
document.body.appendChild(a);
a.click();
a.remove();
URL.revokeObjectURL(url);
setIkStatus("Config downloaded.");
downloadYaml(mainBody, "stac_retarget_config.yaml");
if (sidecarBody) {
downloadYaml(sidecarBody, "stac_retarget_config.ui.yaml");
}
setIkStatus(
sidecarBody
? "Config + UI sidecar downloaded."
: "Config downloaded."
);
}, [setIkStatus]);

const handleLoadStacOutput = useCallback(async () => {
Expand Down
8 changes: 8 additions & 0 deletions frontend/src/store.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ interface AppState {
scaleFactor: number;
mocapScaleFactor: number;

// Raw template from a loaded stac-mjx config — used on export to preserve
// fields the UI doesn't manage (N_ITERS, ROOT_OPTIMIZATION_KEYPOINT, ...).
rawTemplate: Record<string, unknown> | null;

// Interaction
mode: InteractionMode;
selectedKeypoint: string | null;
Expand Down Expand Up @@ -132,6 +136,7 @@ interface AppState {
keypointInitialOffsets: Record<string, [number, number, number]>;
scaleFactor: number;
mocapScaleFactor: number;
_rawTemplate?: Record<string, unknown>;
}) => void;
}

Expand All @@ -153,6 +158,7 @@ export const useStore = create<AppState>()(persist((set) => ({
offsets: [],
scaleFactor: 0.9,
mocapScaleFactor: 0.01,
rawTemplate: null,
mode: "mapping",
selectedKeypoint: null,
selectedBody: null,
Expand Down Expand Up @@ -258,6 +264,7 @@ export const useStore = create<AppState>()(persist((set) => ({
offsets: Object.entries(config.keypointInitialOffsets).map(([kp, [x, y, z]]) => ({ keypointName: kp, x, y, z })),
scaleFactor: config.scaleFactor,
mocapScaleFactor: config.mocapScaleFactor,
rawTemplate: config._rawTemplate ?? null,
}),
}), {
name: "stac-retarget-ui-state",
Expand All @@ -268,6 +275,7 @@ export const useStore = create<AppState>()(persist((set) => ({
mappings: state.mappings,
offsets: state.offsets,
segmentScales: state.segmentScales,
rawTemplate: state.rawTemplate,
// Model transform
modelRotationY: state.modelRotationY,
modelPosition: state.modelPosition,
Expand Down
Loading
Loading