1919import yaml
2020from element_interface .utils import find_full_path
2121
22+ # Configure JAX for better compatibility with DataJoint/DeepHash
23+ os .environ ["JAX_ENABLE_X64" ] = "False"
24+ os .environ ["JAX_ARRAY" ] = "False" # Use legacy array API for better compatibility
25+ os .environ ["JAX_DYNAMIC_SHAPES" ] = "False"
26+
2227from .plotting import viz_utils
2328from .readers import kpms_reader
2429
@@ -1490,7 +1495,7 @@ class File(dj.Part):
14901495 -> master
14911496 file_name : varchar(255) # Name of the output file (e.g. 'checkpoint.h5', 'model_data.pkl').
14921497 ---
1493- file : filepath@moseq-train-processed # Path to the file in the processed data directory.
1498+ file_path : filepath@moseq-train-processed # Path to the file in the processed data directory.
14941499 """
14951500
14961501 class Plots (dj .Part ):
@@ -1523,6 +1528,10 @@ def make(self, key):
15231528 5. Reindex syllable labels by frequency.
15241529 6. Calculate fitting duration and insert results.
15251530 """
1531+ import jax
1532+
1533+ jax .config .update ("jax_enable_x64" , True )
1534+
15261535 from keypoint_moseq import (
15271536 estimate_sigmasq_loc ,
15281537 fit_model ,
@@ -1548,6 +1557,10 @@ def make(self, key):
15481557 )
15491558
15501559 if task_mode == "trigger" :
1560+ import pickle
1561+
1562+ from keypoint_moseq import load_checkpoint
1563+
15511564 pca_path = (PCAFit .File & key & 'file_name="pca.p"' ).fetch1 ("file_path" )
15521565 pca = load_pca (Path (pca_path ).parent .as_posix ())
15531566 coordinates , confidences = (PreProcessing & key ).fetch1 (
@@ -1561,73 +1574,131 @@ def make(self, key):
15611574 metadata = pickle .load (open (metadata_path , "rb" ))
15621575 average_frame_rate = (PreProcessing & key ).fetch1 ("average_frame_rate" )
15631576
1564- kpms_dj_config_path = (PreProcessing .ConfigFile & key ).fetch1 ("config_file" )
1565- kpms_dj_config_dict = kpms_reader .load_kpms_dj_config (
1566- config_path = kpms_dj_config_path
1577+ kpms_dj_config_abs_path = (PreProcessing .ConfigFile & key ).fetch1 (
1578+ "config_file"
15671579 )
1568- # Update kpms_dj_config file in disk with new latent_dim and kappa values
1569- kpms_dj_config_dict = kpms_reader .update_kpms_dj_config (
1570- config_dict = kpms_dj_config_dict ,
1571- latent_dim = full_latent_dim ,
1572- kappa = full_kappa ,
1573- sigmasq_loc = estimate_sigmasq_loc (
1580+
1581+ kpms_dj_config_dict_for_save = kpms_reader .load_kpms_dj_config (
1582+ config_path = kpms_dj_config_abs_path , build_indexes = False
1583+ )
1584+ sigmasq_loc_val = float (
1585+ estimate_sigmasq_loc (
15741586 data ["Y" ], data ["mask" ], filter_size = average_frame_rate
1575- ),
1587+ )
15761588 )
1577-
1578- # Initialize the model
1579- model = init_model (
1580- data = data , metadata = metadata , pca = pca , ** kpms_dj_config_dict
1589+ kpms_dj_config_dict_for_save = kpms_reader .update_kpms_dj_config (
1590+ config_dict = kpms_dj_config_dict_for_save ,
1591+ config_path = kpms_dj_config_abs_path ,
1592+ latent_dim = int (full_latent_dim ),
1593+ kappa = float (full_kappa ),
1594+ sigmasq_loc = sigmasq_loc_val ,
15811595 )
1582- # Update the model hyperparameters
1583- model = update_hypparams (
1584- model ,
1585- kappa = float (full_kappa .item ()),
1586- latent_dim = int (full_latent_dim .item ()),
1596+
1597+ kpms_dj_config_dict = kpms_reader .load_kpms_dj_config (
1598+ config_path = kpms_dj_config_abs_path , build_indexes = True
15871599 )
1600+
15881601 # Determine model directory name for outputs
15891602 if model_name is None or not str (model_name ).strip ():
15901603 model_name = f"latent_dim_{ full_latent_dim .item ()} _kappa_{ full_kappa .item ()} _iters_{ full_num_iterations .item ()} "
15911604 else :
15921605 model_name = str (model_name )
15931606
1607+ # Try to load pre-fit model for the same latent_dim and kappa values
1608+ pre_model = None
1609+ # More optimal: check existence before fetching to avoid try/except
1610+ pre_model_key_query = (
1611+ PreFitTask
1612+ & {"kpset_id" : key ["kpset_id" ], "bodyparts_id" : key ["bodyparts_id" ]}
1613+ & {
1614+ "pre_kappa" : key ["full_kappa" ],
1615+ "pre_latent_dim" : key ["full_latent_dim" ],
1616+ }
1617+ )
1618+ if pre_model_key_query :
1619+ pre_model_key = pre_model_key_query .fetch1 ("KEY" )
1620+ pre_model_file = (
1621+ PreFit .File & pre_model_key & 'file_name="checkpoint.h5"'
1622+ ).fetch1 ("file_path" )
1623+ with open (pre_model_file , "rb" ) as f :
1624+ pre_model = pickle .load (f )
1625+ logger .info (
1626+ f"Using PreFit model { pre_model_key_query } as warm start for FullFit"
1627+ )
1628+
15941629 execution_time = datetime .now (timezone .utc )
1630+
1631+ # Initialize model: Use PreFit if available, otherwise initialize fresh
1632+ try :
1633+ if pre_model is not None :
1634+ model_to_fit = pre_model
1635+ else :
1636+ # Only initialize fresh model if no PreFit available
1637+ model_to_fit = init_model (
1638+ data = data , metadata = metadata , pca = pca , ** kpms_dj_config_dict
1639+ )
1640+ # Update the model hyperparameters
1641+ model_to_fit = update_hypparams (
1642+ model_to_fit ,
1643+ kappa = float (full_kappa .item ()),
1644+ latent_dim = int (full_latent_dim .item ()),
1645+ )
1646+ except Exception as e :
1647+ raise ValueError (f"Model initialization failed: { e } " )
1648+
15951649 # Fit the model
1596- model , model_name = fit_model (
1597- model = model ,
1598- model_name = model_name ,
1599- data = data ,
1600- metadata = metadata ,
1601- project_dir = kpms_project_output_dir .as_posix (),
1602- ar_only = False ,
1603- num_iters = full_num_iterations ,
1604- generate_progress_plots = True , # saved to {project_dir}/{model_name}/plots/
1605- save_every_n_iters = 10 ,
1606- )
1650+ try :
16071651
1608- # Reindex the syllables in the checkpoint file
1609- reindex_syllables_in_checkpoint (
1610- project_dir = kpms_project_output_dir .as_posix (),
1611- model_name = model_name ,
1612- )
1652+ model , model_name = fit_model (
1653+ model = model_to_fit ,
1654+ model_name = model_name ,
1655+ data = data ,
1656+ metadata = metadata ,
1657+ project_dir = kpms_project_output_dir .as_posix (),
1658+ ar_only = False ,
1659+ num_iters = full_num_iterations ,
1660+ generate_progress_plots = True ,
1661+ save_every_n_iters = 1 , # TODO: to change to a higher value
1662+ verbose = False ,
1663+ ) # checkpoint will be saved at project_dir/model_name
1664+ except Exception as e :
1665+ raise ValueError (f"FullFit training failed: { e } " )
1666+
1667+ try :
1668+ # Reindex the syllables in the checkpoint file
1669+ reindex_syllables_in_checkpoint (
1670+ project_dir = kpms_project_output_dir .as_posix (),
1671+ model_name = model_name ,
1672+ )
1673+ except Exception as e :
1674+ raise ValueError (
1675+ f"Reindexing syllables failed due to FullFit training failure: { e } "
1676+ )
16131677
16141678 # Create a PNG version fo the PDF progress plot
1615- png_path , pdf_path = viz_utils .copy_pdf_to_png (
1616- kpms_project_output_dir , model_name
1617- )
1618- # Define model_name_full_path for checkpoint file search
16191679 model_name_full_path = find_full_path (kpms_project_output_dir , model_name )
1680+ pdf_path = model_name_full_path / "fitting_progress.pdf"
1681+ png_path = model_name_full_path / "fitting_progress.png"
1682+
1683+ if pdf_path .exists ():
1684+ png_path , pdf_path = viz_utils .copy_pdf_to_png (
1685+ kpms_project_output_dir , model_name
1686+ )
1687+ else :
1688+ logger .warning (f"No progress PDF found at { pdf_path } " )
16201689 else :
16211690 # Load mode must specify a model_name
16221691 if model_name is None or not str (model_name ).strip ():
1623- raise ValueError ("model_name is required when task_mode='load'" )
1692+ raise ValueError ("`model_name` is required when task_mode='load'" )
1693+
16241694 model_name_full_path = find_full_path (kpms_project_output_dir , model_name )
16251695 pdf_path = model_name_full_path / "fitting_progress.pdf"
16261696 png_path = model_name_full_path / "fitting_progress.png"
16271697
16281698 # Get the path to the updated config file
1629- kpms_dj_config_path = kpms_reader ._kpms_dj_config_path (kpms_project_output_dir )
1630-
1699+ kpms_dj_config_abs_path = kpms_reader ._kpms_dj_config_path (
1700+ kpms_project_output_dir
1701+ )
16311702 if not pdf_path .exists ():
16321703 raise FileNotFoundError (f"PreFit PDF progress plot not found at { pdf_path } " )
16331704 if not png_path .exists ():
@@ -1644,21 +1715,21 @@ def make(self, key):
16441715 f"No checkpoint files found in { model_name_full_path } "
16451716 )
16461717
1647- completion_time = datetime .now (timezone .utc )
1648-
1649- if task_mode == "trigger" :
1650- duration_seconds = (completion_time - execution_time ).total_seconds ()
1651- else :
1652- duration_seconds = None
1653-
1654- # Save model dictionary as pickle file
1718+ # Save model dictionary as pickle file in the model directory
16551719 model_data_filename = "model_data.pkl"
16561720 model_data_file = model_name_full_path / model_data_filename
16571721 with open (model_data_file , "wb" ) as f :
16581722 pickle .dump (model , f )
16591723
16601724 file_paths = [checkpoint_file , model_data_file ]
16611725
1726+ completion_time = datetime .now (timezone .utc )
1727+ duration_seconds = (
1728+ (completion_time - execution_time ).total_seconds ()
1729+ if task_mode == "trigger"
1730+ else None
1731+ )
1732+
16621733 self .insert1 (
16631734 {
16641735 ** key ,
@@ -1676,21 +1747,19 @@ def make(self, key):
16761747 {
16771748 ** key ,
16781749 "file_name" : file .name ,
1679- "file " : file .as_posix (),
1750+ "file_path " : file .as_posix (),
16801751 }
16811752 for file in file_paths
16821753 ]
16831754 )
16841755
1685- # Insert config file
16861756 self .ConfigFile .insert1 (
16871757 {
16881758 ** key ,
1689- "config_file" : kpms_dj_config_path ,
1759+ "config_file" : kpms_dj_config_abs_path ,
16901760 }
16911761 )
16921762
1693- # Insert plots
16941763 self .Plots .insert1 (
16951764 {
16961765 ** key ,
@@ -1711,7 +1780,6 @@ class ModelScore(dj.Computed):
17111780 -> FullFit
17121781 ---
17131782 score=NULL : float # Model score (MLL for single model)
1714- std_error=NULL : float # Standard error of the model score
17151783 """
17161784
17171785 def make (self , key ):
@@ -1722,28 +1790,24 @@ def make(self, key):
17221790 # Get checkpoint file for this specific model
17231791 checkpoint_file = (
17241792 FullFit .File & key & 'file_name LIKE "%checkpoint.h5"'
1725- ).fetch1 ("file " )
1793+ ).fetch1 ("file_path " )
17261794
17271795 # Load the checkpoint to get model data
17281796 model , data , _ , _ = load_checkpoint (path = checkpoint_file )
17291797
1730- # Compute marginal log likelihood for this model
1798+ # Compute marginal log likelihood for single model
17311799 mask = jnp .array (data ["mask" ])
17321800 x = jnp .array (model ["states" ]["x" ])
17331801 Ab = jnp .array (model ["params" ]["Ab" ])
17341802 Q = jnp .array (model ["params" ]["Q" ])
17351803 pi = jnp .array (model ["params" ]["pi" ])
1736-
1737- # Compute marginal log likelihood - this is the correct metric for single models
17381804 mll = marginal_log_likelihood (mask , x , Ab , Q , pi )
17391805 score = float (mll ) # Store as "score" - this is MLL
1740- std_error = 0.0 # No standard error for single model MLL
17411806
17421807 self .insert1 (
17431808 {
17441809 ** key ,
17451810 "score" : score ,
1746- "std_error" : std_error ,
17471811 }
17481812 )
17491813
0 commit comments