aind-disrnn-dispatcher

The dispatcher capsule in the AIND-disRNN MLOps stack


Usage

This capsule uses Hydra to compose configurations from the files under code/config/. It supports arbitrary combinations of data sources (mice or synthetic) and models (disRNN or other baselines), and allows flexible parameter overrides and parameter sweeps. The resulting job specifications are distributed to the downstream wrapper capsule for parallel training.
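With Hydra, the top-level config typically selects the default config groups via a defaults list. A hypothetical sketch of what code/config/config.yaml might contain, with field names mirroring the example outputs in this README (the actual file may differ):

```yaml
# Hypothetical sketch of code/config/config.yaml (not the actual file).
defaults:
  - data: mice      # default data-source group (code/config/data/mice.yaml)
  - model: disrnn   # default model group (code/config/model/disrnn.yaml)
  - _self_

job_id: 0
seed: 42
wandb:
  entity: AIND-disRNN
```

Overrides such as `data=synthetic` swap out an entire group file, while dotted overrides such as `model.training.lr=0.01` change a single leaf value.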

Examples (direct Python invocation)

  1. Single run with the default parameters (data=mice, model=disrnn)

    Uses the defaults declared in config/config.yaml:

    python code/run_capsule.py

    Explicit form:

    python code/run_capsule.py data=mice model=disrnn

    Here is the output:

    {
        "data": {
            "source": "mice",
            "subject_ids": [
                774212
            ],
            "multisubject": false,
            "ignore_policy": "exclude",
            "features": {
                "animal_response": "prev choice",
                "rewarded": "prev reward"
            }
        },
        "model": {
            "type": "disrnn",
            "architecture": {
                "latent_size": 5,
                "update_net_n_units_per_layer": 16,
                "update_net_n_layers": 8,
                "choice_net_n_units_per_layer": 4,
                "choice_net_n_layers": 1,
                "activation": "leaky_relu"
            },
            "penalties": {
                "beta": 0.01,
                "latent_penalty": 0.01,
                "choice_net_latent_penalty": 0.01,
                "update_net_obs_penalty": 0.01,
                "update_net_latent_penalty": 0.01
            },
            "training": {
                "n_steps": 3000,
                "n_warmup_steps": 1000,
                "lr": 0.001,
                "eval_every_n": 100,
                "loss": "penalized_categorical",
                "loss_param": 1.0
            }
        },
        "job_id": 0,
        "seed": 42,
        "wandb": {
            "entity": "AIND-disRNN"
        }
    }
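    Downstream code can treat each emitted job specification as plain JSON and pull hyperparameters by key path. A small illustrative sketch (the abridged spec below is taken from the output above; the parsing code is not part of the capsule):

```python
import json

# Job specification, abridged from the dispatcher output above.
job_json = """
{
    "data": {"source": "mice", "subject_ids": [774212]},
    "model": {
        "type": "disrnn",
        "architecture": {"latent_size": 5},
        "training": {"n_steps": 3000, "lr": 0.001}
    },
    "job_id": 0,
    "seed": 42
}
"""

job = json.loads(job_json)

# A downstream wrapper can dispatch on the model type and read
# hyperparameters by nested key, e.g. model -> training -> lr.
print(job["model"]["type"])            # disrnn
print(job["model"]["training"]["lr"])  # 0.001
```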
  2. Fit a baseline RL model to a synthetic dataset generated by an RL agent in an uncoupled-block task

    python code/run_capsule.py data=synthetic data.synthetic.agent=rl data.synthetic.task=uncoupled_block model=baseline_rl

    The output:

    {
        "data": {
            "synthetic": {
                "task": "uncoupled_block",
                "agent": "rl"
            },
            "source": "synthetic",
            "num_trials": 1000
        },
        "model": {
            "type": "baseline_rl",
            "agent_class": "ForagerQLearning",
            "agent_kwargs": {
                "number_of_learning_rate": 2,
                "number_of_forget_rate": 1,
                "choice_kernel": "none",
                "action_selection": "softmax"
            },
            "fit_kwargs": {
                "DE_kwargs": {
                    "polish": true,
                    "seed": 42
                }
            }
        },
        "job_id": 0,
        "seed": 42,
        "wandb": {
            "entity": "AIND-disRNN"
        }
    }
  3. Hydra multirun sweep over the penalty weight beta and the training learning rate

    Use -m (or --multirun) to launch a Cartesian product of overrides:

    python code/run_capsule.py -m model.penalties.beta=0.0001,0.001,0.01 model.training.lr=0.0001,0.001

    Results will be placed under results/ using Hydra-generated job IDs.
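Hydra's basic sweeper expands comma-separated override values into a Cartesian product, so the command above launches 3 × 2 = 6 jobs. The expansion can be sketched with the standard library (the override strings mirror the command above; this is an illustration of the sweep semantics, not the dispatcher's internals):

```python
from itertools import product

# Comma-separated values from the multirun command above.
betas = [0.0001, 0.001, 0.01]
lrs = [0.0001, 0.001]

# The basic sweeper enumerates the Cartesian product of all swept overrides,
# producing one override list per job.
jobs = [
    [f"model.penalties.beta={b}", f"model.training.lr={lr}"]
    for b, lr in product(betas, lrs)
]

print(len(jobs))  # 6
print(jobs[0])    # ['model.penalties.beta=0.0001', 'model.training.lr=0.0001']
```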

Additional notes

  • All composed configurations are serialized to JSON at .hydra/config.json inside the run directory.
  • To call programmatically from Python:
    from run_capsule import generate_jobs_with_args
    generate_jobs_with_args(["data=mice", "model=disrnn", "job_id=42"])
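Since the composed configuration is serialized as plain JSON, job specs round-trip with the standard library alone. A minimal sketch of the read/write pattern (the directory layout mimics the .hydra/config.json convention noted above; the spec contents are illustrative):

```python
import json
from pathlib import Path
from tempfile import TemporaryDirectory

# Illustrative job spec; in the real capsule this is the composed Hydra config.
spec = {"data": {"source": "mice"}, "model": {"type": "disrnn"}, "job_id": 0, "seed": 42}

with TemporaryDirectory() as run_dir:
    # Mimic the run-directory layout: <run_dir>/.hydra/config.json
    out = Path(run_dir) / ".hydra" / "config.json"
    out.parent.mkdir(parents=True)
    out.write_text(json.dumps(spec, indent=4))

    # Round-trip: the serialized spec reloads unchanged.
    loaded = json.loads(out.read_text())

assert loaded == spec
```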

Usage in Code Ocean

  • Enter the whole override string in the app panel and hit "Run"

    (screenshot: app panel with the override string)
  • You'll see all generated jobs like this

    (screenshot: list of generated jobs)