Commit 848238a

docs: update readme references and add modular trajectory harvester
1 parent ef707cc commit 848238a

4 files changed

Lines changed: 125 additions & 21 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -47,6 +47,7 @@ wandb/
 .pytest_cache/
 .coverage
 htmlcov/
+tests/artifacts/
 
 # Streamlit
 .streamlit/

README.md

Lines changed: 60 additions & 21 deletions
@@ -8,6 +8,8 @@
 
 DT-Circuits is a research framework for mechanistic interpretability of Decision Transformers, focused on causal analysis, sparse feature decomposition, and circuit-level understanding of sequential decision-making agents.
 
+**Live Interactive Demo:** [DT-Explorer on Hugging Face Spaces](https://huggingface.co/spaces/sadhumitha-s/DT-Explorer)
+
 ---
 
 ## Table of Contents
@@ -17,6 +19,7 @@ DT-Circuits is a research framework for mechanistic interpretability of Decision
 - [Project Structure](#project-structure)
 - [Installation and Usage](#installation-and-usage)
 - [Documentation](#documentation)
+- [Foundational Research & References](#foundational-research--references)
 - [Citation](#citation)
 - [License](#license)
 

@@ -167,49 +170,72 @@ sae:
 
 ---
 
-## Installation and Usage
+## Execution Modes: Installation and Usage
+
+There are two primary ways to run and interact with the **DT-Circuits** framework depending on your research needs:
+
+---
+
+### Way 1: Interactive Cloud Demo (Hugging Face Spaces)
+
+For instant visual exploration, path intervention, and alignment auditing without any local workspace preparation, launch the web dashboard directly:
+
+* **Demo Link:** [DT-Explorer on Hugging Face Spaces](https://huggingface.co/spaces/sadhumitha-s/DT-Explorer)
+
+> [!NOTE]
+> **Concise Demo Constraints:**
+> * **CPU-Bound Resources:** Runs on standard free-tier CPU instances (2 vCPUs, 16 GB RAM); high-overhead operations like ACDC scans may show higher latency than on a local GPU workspace.
+> * **Sliced Dataset:** Trajectory datasets are sliced down to a lightweight demo set under a **10 MB limit** (defined in [deploy.sh](scripts/deploy.sh#L19-L33)) to fit storage and memory constraints.
+> * **Read-Only / Ephemeral Container:** Uses pre-baked static weights (`mini_dt.pt`) and pre-trained SAE checkpoints. Training new models or writing persistent states is disabled.
+
+---
+
+### Way 2: Clone and Run Locally (Full Pipeline)
+
+For full end-to-end research, customized hyperparameter tuning, local data harvesting, and GPU-accelerated model or SAE training, run the workspace on your machine.
 
-### Setup
+#### Local Environment Setup
+First, clone the repository, set up a virtual environment, and install dependencies:
 ```bash
+git clone https://github.com/sadhumitha-s/DT-Circuits
+cd DT-Circuits
+
 python -m venv venv
 source venv/bin/activate
+
 pip install -r requirements.txt
 ```
 
-### Dashboard Execution
-You can access the hosted version on Hugging Face Spaces instantly, or run it locally:
+#### Option 2.1: Simple Workflows via Makefile
+The workspace includes a standardized [Makefile](Makefile) to orchestrate common research pipelines with single commands:
 
-* **Live Hosted Space:** [DT-Explorer Web App](https://sadhumitha-s-dt-explorer.hf.space) (No local installation needed!)
-* **Local Run:** Launch the dashboard on your machine (it will initialize with a random model if no trained weights are detected):
-```bash
-streamlit run src/dashboard/app.py
-```
+```bash
+make setup # Set up local environment & install requirements
+make train # Run the full end-to-end pipeline (Data harvesting -> DT -> SAE training)
+make dashboard # Run the Streamlit visualization dashboard locally
+```
 
-### Workflow
+#### Option 2.2: Granular Control via Bash & Python
+For research flexibility, execute each step of the pipeline manually using granular terminal scripts:
 
-1. **Data Harvesting & Model Training**
+1. **Trajectories & Model Training**
+Harvest teacher trajectories and train the target Decision Transformer (`HookedDT`):
 ```bash
 python scripts/train_dt.py
 ```
 
-2. **SAE Training**
+2. **TopK Sparse Autoencoder (SAE) Training**
+Train sparse autoencoders on target activation layers:
 ```bash
 python scripts/train_sae.py
 ```
 
-3. **Interpretability Analysis**
+3. **Interactive Analysis**
+Launch the Streamlit visualization engine locally to run audits with custom weights:
 ```bash
 streamlit run src/dashboard/app.py
 ```
 
-### Alternative: Makefile
-Common tasks can also be executed via `make`:
-```bash
-make setup # Install dependencies
-make train # Run full training pipeline (DT + SAE)
-make dashboard # Launch DT-Explorer
-```
-
 ---
 
 ## Documentation
@@ -222,6 +248,19 @@ Detailed technical documentation for specific modules:
 
 ---
 
+## Foundational Research & References
+
+This framework implements and builds upon the following foundational methodologies:
+
+* **Decision Transformers**: [Chen et al., 2021](https://arxiv.org/abs/2106.01345): Reinforcement learning as sequence modeling.
+* **Transformer Circuits**: [Elhage et al., 2021](https://transformer-circuits.pub/2021/framework/index.html): Mathematical foundations of mechanistic interpretability.
+* **ACDC (Automated Circuit Discovery)**: [Conmy et al., 2023](https://arxiv.org/abs/2304.14997): Algorithmic discovery of subgraphs.
+* **Sparse Autoencoders (SAEs)**: [Bricken et al., 2023](https://transformer-circuits.pub/2023/monosemantic-features/index.html) (monosemantic features) & [Gao et al., 2024](https://arxiv.org/abs/2406.04093) (TopK SAEs).
+* **Activation Steering**: [Turner et al., 2023](https://arxiv.org/abs/2308.10248): Control via residual stream vector additions.
+* **Path Patching**: [Goldowsky-Dill et al., 2023](https://arxiv.org/abs/2304.05969): Inter-component causal mediation.
+
+---
+
 ## Citation
 
 ```bibtex

src/data/__init__.py

Whitespace-only changes.

src/data/harvester.py

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
+import os
+import gymnasium as gym
+import torch
+import numpy as np
+from minigrid.wrappers import FlatObsWrapper
+from stable_baselines3 import PPO
+from tqdm import tqdm
+
+class PPOHarvester:
+    """
+    Utility to run a 'Teacher' PPO agent to collect high-quality state-action-reward triplets.
+    """
+    def __init__(self, env_id="MiniGrid-Empty-8x8-v0", model_path=None):
+        self.env_id = env_id
+        self.env = FlatObsWrapper(gym.make(env_id, render_mode="rgb_array"))
+        if model_path and os.path.exists(model_path):
+            self.model = PPO.load(model_path, env=self.env)
+        else:
+            print(f"No model found at {model_path}. Training a new one for collection...")
+            self.model = PPO("MlpPolicy", self.env, verbose=1)
+            self.model.learn(total_timesteps=20000)
+            if model_path:
+                self.model.save(model_path)
+
+    def collect_trajectories(self, num_episodes=100):
+        trajectories = []
+        for i in tqdm(range(num_episodes), desc="Collecting trajectories"):
+            obs, _ = self.env.reset(seed=42 + i)  # distinct, reproducible seed per episode
+            done = False
+            truncated = False
+            episode = {
+                "observations": [],
+                "actions": [],
+                "rewards": [],
+                "dones": []
+            }
+            while not (done or truncated):
+                action, _states = self.model.predict(obs, deterministic=False)  # sample from the policy
+                next_obs, reward, done, truncated, info = self.env.step(action)
+
+                episode["observations"].append(obs)
+                episode["actions"].append(action)
+                episode["rewards"].append(reward)
+                episode["dones"].append(done)  # termination flag only; truncation is not recorded
+
+                obs = next_obs
+
+            # Convert to numpy arrays
+            for key in episode:
+                episode[key] = np.array(episode[key])
+
+            trajectories.append(episode)
+
+        return trajectories
+
+    def save_trajectories(self, trajectories, file_path):
+        os.makedirs(os.path.dirname(file_path) or ".", exist_ok=True)  # bare filenames have no dirname
+        torch.save(trajectories, file_path)
+        print(f"Saved {len(trajectories)} trajectories to {file_path}")
+
+if __name__ == "__main__":
+    harvester = PPOHarvester(model_path="ppo_minigrid_teacher.zip")
+    trajs = harvester.collect_trajectories(num_episodes=50)
+    harvester.save_trajectories(trajs, "data/trajectories.pt")
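
Each harvested episode stores a per-step `rewards` array, and the Decision Transformer formulation cited in the README (Chen et al., 2021) conditions every timestep on the return-to-go, i.e. the sum of rewards from that step to the end of the episode. The following is a minimal sketch of that computation over a plain reward sequence, kept dependency-free for illustration. The names `returns_to_go` and `gamma` are hypothetical and not part of this commit, and `scripts/train_dt.py` may compute this differently.

```python
from typing import List, Sequence


def returns_to_go(rewards: Sequence[float], gamma: float = 1.0) -> List[float]:
    """Return R_t = sum over k >= t of gamma**(k - t) * rewards[k], for every step t."""
    rtg = [0.0] * len(rewards)
    running = 0.0
    # Walk backwards so each step reuses the suffix sum of the step after it.
    for t in range(len(rewards) - 1, -1, -1):
        running = float(rewards[t]) + gamma * running
        rtg[t] = running
    return rtg


# Hypothetical reward sequence standing in for one episode's "rewards" entry.
example_rewards = [0.0, 0.5, 0.0, 1.0]
print(returns_to_go(example_rewards))
```

With the default `gamma=1.0` this is the undiscounted suffix sum. Passing an episode's `rewards` array (for example `episode["rewards"].tolist()`) should work the same way, assuming it holds one scalar reward per step.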
