Skip to content

Commit 8f0df90

Browse files
committed
refactor(moseq_train)
1 parent ba2d3c4 commit 8f0df90

File tree

1 file changed

+126
-62
lines changed

1 file changed

+126
-62
lines changed

element_moseq/moseq_train.py

Lines changed: 126 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@
1919
import yaml
2020
from element_interface.utils import find_full_path
2121

22+
# Configure JAX for better compatibility with DataJoint/DeepHash
23+
os.environ["JAX_ENABLE_X64"] = "False"
24+
os.environ["JAX_ARRAY"] = "False" # Use legacy array API for better compatibility
25+
os.environ["JAX_DYNAMIC_SHAPES"] = "False"
26+
2227
from .plotting import viz_utils
2328
from .readers import kpms_reader
2429

@@ -1490,7 +1495,7 @@ class File(dj.Part):
14901495
-> master
14911496
file_name : varchar(255) # Name of the output file (e.g. 'checkpoint.h5', 'model_data.pkl').
14921497
---
1493-
file : filepath@moseq-train-processed # Path to the file in the processed data directory.
1498+
file_path : filepath@moseq-train-processed # Path to the file in the processed data directory.
14941499
"""
14951500

14961501
class Plots(dj.Part):
@@ -1523,6 +1528,10 @@ def make(self, key):
15231528
5. Reindex syllable labels by frequency.
15241529
6. Calculate fitting duration and insert results.
15251530
"""
1531+
import jax
1532+
1533+
jax.config.update("jax_enable_x64", True)
1534+
15261535
from keypoint_moseq import (
15271536
estimate_sigmasq_loc,
15281537
fit_model,
@@ -1548,6 +1557,10 @@ def make(self, key):
15481557
)
15491558

15501559
if task_mode == "trigger":
1560+
import pickle
1561+
1562+
from keypoint_moseq import load_checkpoint
1563+
15511564
pca_path = (PCAFit.File & key & 'file_name="pca.p"').fetch1("file_path")
15521565
pca = load_pca(Path(pca_path).parent.as_posix())
15531566
coordinates, confidences = (PreProcessing & key).fetch1(
@@ -1561,73 +1574,131 @@ def make(self, key):
15611574
metadata = pickle.load(open(metadata_path, "rb"))
15621575
average_frame_rate = (PreProcessing & key).fetch1("average_frame_rate")
15631576

1564-
kpms_dj_config_path = (PreProcessing.ConfigFile & key).fetch1("config_file")
1565-
kpms_dj_config_dict = kpms_reader.load_kpms_dj_config(
1566-
config_path=kpms_dj_config_path
1577+
kpms_dj_config_abs_path = (PreProcessing.ConfigFile & key).fetch1(
1578+
"config_file"
15671579
)
1568-
# Update kpms_dj_config file in disk with new latent_dim and kappa values
1569-
kpms_dj_config_dict = kpms_reader.update_kpms_dj_config(
1570-
config_dict=kpms_dj_config_dict,
1571-
latent_dim=full_latent_dim,
1572-
kappa=full_kappa,
1573-
sigmasq_loc=estimate_sigmasq_loc(
1580+
1581+
kpms_dj_config_dict_for_save = kpms_reader.load_kpms_dj_config(
1582+
config_path=kpms_dj_config_abs_path, build_indexes=False
1583+
)
1584+
sigmasq_loc_val = float(
1585+
estimate_sigmasq_loc(
15741586
data["Y"], data["mask"], filter_size=average_frame_rate
1575-
),
1587+
)
15761588
)
1577-
1578-
# Initialize the model
1579-
model = init_model(
1580-
data=data, metadata=metadata, pca=pca, **kpms_dj_config_dict
1589+
kpms_dj_config_dict_for_save = kpms_reader.update_kpms_dj_config(
1590+
config_dict=kpms_dj_config_dict_for_save,
1591+
config_path=kpms_dj_config_abs_path,
1592+
latent_dim=int(full_latent_dim),
1593+
kappa=float(full_kappa),
1594+
sigmasq_loc=sigmasq_loc_val,
15811595
)
1582-
# Update the model hyperparameters
1583-
model = update_hypparams(
1584-
model,
1585-
kappa=float(full_kappa.item()),
1586-
latent_dim=int(full_latent_dim.item()),
1596+
1597+
kpms_dj_config_dict = kpms_reader.load_kpms_dj_config(
1598+
config_path=kpms_dj_config_abs_path, build_indexes=True
15871599
)
1600+
15881601
# Determine model directory name for outputs
15891602
if model_name is None or not str(model_name).strip():
15901603
model_name = f"latent_dim_{full_latent_dim.item()}_kappa_{full_kappa.item()}_iters_{full_num_iterations.item()}"
15911604
else:
15921605
model_name = str(model_name)
15931606

1607+
# Try to load pre-fit model for the same latent_dim and kappa values
1608+
pre_model = None
1609+
# More optimal: check existence before fetching to avoid try/except
1610+
pre_model_key_query = (
1611+
PreFitTask
1612+
& {"kpset_id": key["kpset_id"], "bodyparts_id": key["bodyparts_id"]}
1613+
& {
1614+
"pre_kappa": key["full_kappa"],
1615+
"pre_latent_dim": key["full_latent_dim"],
1616+
}
1617+
)
1618+
if pre_model_key_query:
1619+
pre_model_key = pre_model_key_query.fetch1("KEY")
1620+
pre_model_file = (
1621+
PreFit.File & pre_model_key & 'file_name="checkpoint.h5"'
1622+
).fetch1("file_path")
1623+
with open(pre_model_file, "rb") as f:
1624+
pre_model = pickle.load(f)
1625+
logger.info(
1626+
f"Using PreFit model {pre_model_key_query} as warm start for FullFit"
1627+
)
1628+
15941629
execution_time = datetime.now(timezone.utc)
1630+
1631+
# Initialize model: Use PreFit if available, otherwise initialize fresh
1632+
try:
1633+
if pre_model is not None:
1634+
model_to_fit = pre_model
1635+
else:
1636+
# Only initialize fresh model if no PreFit available
1637+
model_to_fit = init_model(
1638+
data=data, metadata=metadata, pca=pca, **kpms_dj_config_dict
1639+
)
1640+
# Update the model hyperparameters
1641+
model_to_fit = update_hypparams(
1642+
model_to_fit,
1643+
kappa=float(full_kappa.item()),
1644+
latent_dim=int(full_latent_dim.item()),
1645+
)
1646+
except Exception as e:
1647+
raise ValueError(f"Model initialization failed: {e}")
1648+
15951649
# Fit the model
1596-
model, model_name = fit_model(
1597-
model=model,
1598-
model_name=model_name,
1599-
data=data,
1600-
metadata=metadata,
1601-
project_dir=kpms_project_output_dir.as_posix(),
1602-
ar_only=False,
1603-
num_iters=full_num_iterations,
1604-
generate_progress_plots=True, # saved to {project_dir}/{model_name}/plots/
1605-
save_every_n_iters=10,
1606-
)
1650+
try:
16071651

1608-
# Reindex the syllables in the checkpoint file
1609-
reindex_syllables_in_checkpoint(
1610-
project_dir=kpms_project_output_dir.as_posix(),
1611-
model_name=model_name,
1612-
)
1652+
model, model_name = fit_model(
1653+
model=model_to_fit,
1654+
model_name=model_name,
1655+
data=data,
1656+
metadata=metadata,
1657+
project_dir=kpms_project_output_dir.as_posix(),
1658+
ar_only=False,
1659+
num_iters=full_num_iterations,
1660+
generate_progress_plots=True,
1661+
save_every_n_iters=1, # TODO: to change to a higher value
1662+
verbose=False,
1663+
) # checkpoint will be saved at project_dir/model_name
1664+
except Exception as e:
1665+
raise ValueError(f"FullFit training failed: {e}")
1666+
1667+
try:
1668+
# Reindex the syllables in the checkpoint file
1669+
reindex_syllables_in_checkpoint(
1670+
project_dir=kpms_project_output_dir.as_posix(),
1671+
model_name=model_name,
1672+
)
1673+
except Exception as e:
1674+
raise ValueError(
1675+
f"Reindexing syllables failed due to FullFit training failure: {e}"
1676+
)
16131677

16141678
# Create a PNG version of the PDF progress plot
1615-
png_path, pdf_path = viz_utils.copy_pdf_to_png(
1616-
kpms_project_output_dir, model_name
1617-
)
1618-
# Define model_name_full_path for checkpoint file search
16191679
model_name_full_path = find_full_path(kpms_project_output_dir, model_name)
1680+
pdf_path = model_name_full_path / "fitting_progress.pdf"
1681+
png_path = model_name_full_path / "fitting_progress.png"
1682+
1683+
if pdf_path.exists():
1684+
png_path, pdf_path = viz_utils.copy_pdf_to_png(
1685+
kpms_project_output_dir, model_name
1686+
)
1687+
else:
1688+
logger.warning(f"No progress PDF found at {pdf_path}")
16201689
else:
16211690
# Load mode must specify a model_name
16221691
if model_name is None or not str(model_name).strip():
1623-
raise ValueError("model_name is required when task_mode='load'")
1692+
raise ValueError("`model_name` is required when task_mode='load'")
1693+
16241694
model_name_full_path = find_full_path(kpms_project_output_dir, model_name)
16251695
pdf_path = model_name_full_path / "fitting_progress.pdf"
16261696
png_path = model_name_full_path / "fitting_progress.png"
16271697

16281698
# Get the path to the updated config file
1629-
kpms_dj_config_path = kpms_reader._kpms_dj_config_path(kpms_project_output_dir)
1630-
1699+
kpms_dj_config_abs_path = kpms_reader._kpms_dj_config_path(
1700+
kpms_project_output_dir
1701+
)
16311702
if not pdf_path.exists():
16321703
raise FileNotFoundError(f"PreFit PDF progress plot not found at {pdf_path}")
16331704
if not png_path.exists():
@@ -1644,21 +1715,21 @@ def make(self, key):
16441715
f"No checkpoint files found in {model_name_full_path}"
16451716
)
16461717

1647-
completion_time = datetime.now(timezone.utc)
1648-
1649-
if task_mode == "trigger":
1650-
duration_seconds = (completion_time - execution_time).total_seconds()
1651-
else:
1652-
duration_seconds = None
1653-
1654-
# Save model dictionary as pickle file
1718+
# Save model dictionary as pickle file in the model directory
16551719
model_data_filename = "model_data.pkl"
16561720
model_data_file = model_name_full_path / model_data_filename
16571721
with open(model_data_file, "wb") as f:
16581722
pickle.dump(model, f)
16591723

16601724
file_paths = [checkpoint_file, model_data_file]
16611725

1726+
completion_time = datetime.now(timezone.utc)
1727+
duration_seconds = (
1728+
(completion_time - execution_time).total_seconds()
1729+
if task_mode == "trigger"
1730+
else None
1731+
)
1732+
16621733
self.insert1(
16631734
{
16641735
**key,
@@ -1676,21 +1747,19 @@ def make(self, key):
16761747
{
16771748
**key,
16781749
"file_name": file.name,
1679-
"file": file.as_posix(),
1750+
"file_path": file.as_posix(),
16801751
}
16811752
for file in file_paths
16821753
]
16831754
)
16841755

1685-
# Insert config file
16861756
self.ConfigFile.insert1(
16871757
{
16881758
**key,
1689-
"config_file": kpms_dj_config_path,
1759+
"config_file": kpms_dj_config_abs_path,
16901760
}
16911761
)
16921762

1693-
# Insert plots
16941763
self.Plots.insert1(
16951764
{
16961765
**key,
@@ -1711,7 +1780,6 @@ class ModelScore(dj.Computed):
17111780
-> FullFit
17121781
---
17131782
score=NULL : float # Model score (MLL for single model)
1714-
std_error=NULL : float # Standard error of the model score
17151783
"""
17161784

17171785
def make(self, key):
@@ -1722,28 +1790,24 @@ def make(self, key):
17221790
# Get checkpoint file for this specific model
17231791
checkpoint_file = (
17241792
FullFit.File & key & 'file_name LIKE "%checkpoint.h5"'
1725-
).fetch1("file")
1793+
).fetch1("file_path")
17261794

17271795
# Load the checkpoint to get model data
17281796
model, data, _, _ = load_checkpoint(path=checkpoint_file)
17291797

1730-
# Compute marginal log likelihood for this model
1798+
# Compute marginal log likelihood for single model
17311799
mask = jnp.array(data["mask"])
17321800
x = jnp.array(model["states"]["x"])
17331801
Ab = jnp.array(model["params"]["Ab"])
17341802
Q = jnp.array(model["params"]["Q"])
17351803
pi = jnp.array(model["params"]["pi"])
1736-
1737-
# Compute marginal log likelihood - this is the correct metric for single models
17381804
mll = marginal_log_likelihood(mask, x, Ab, Q, pi)
17391805
score = float(mll) # Store as "score" - this is MLL
1740-
std_error = 0.0 # No standard error for single model MLL
17411806

17421807
self.insert1(
17431808
{
17441809
**key,
17451810
"score": score,
1746-
"std_error": std_error,
17471811
}
17481812
)
17491813

0 commit comments

Comments
 (0)