Skip to content

Commit c941c06

Browse files
committed
test and type hints
1 parent 8735dc3 commit c941c06

File tree

7 files changed

+1749
-212
lines changed

7 files changed

+1749
-212
lines changed

CLAUDE.md

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# CLAUDE.md
2+
3+
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
4+
5+
## Project Overview
6+
7+
Polychrom is an Open2C polymer simulation library designed to build mechanistic models of chromosomes. It simulates biological processes subject to forces or constraints, which are then compared to Hi-C maps, microscopy, and other data sources.
8+
9+
## Development Commands
10+
11+
### Installation
12+
```bash
13+
pip install cython # Required dependency
14+
pip install -r requirements.txt
15+
pip install -e . # Install in development mode
16+
```
17+
18+
### Testing
19+
```bash
20+
pytest # Run all tests
21+
```
22+
23+
### Building Cython Extensions
24+
```bash
25+
python setup.py build_ext --inplace
26+
```
27+
28+
## Architecture
29+
30+
### Core Components
31+
32+
**Simulation Module** (`polychrom/simulation.py`)
33+
- Central `Simulation` class manages the entire simulation lifecycle
34+
- Handles platform setup (CUDA/OpenCL/CPU), integrators, and parameters
35+
- Key methods: `set_data()` loads conformations, `add_force()` adds forces, and `doBlock()` runs simulation steps
36+
37+
**Forces System** (`polychrom/forces.py`, `polychrom/forcekits.py`)
38+
- Forces define polymer behavior: connectivity, confinement, crosslinks, tethering
39+
- Individual forces in `forces.py` are functions that create OpenMM force objects
40+
- Complex force combinations are packaged as "forcekits" (e.g., `polymer_chains`)
41+
- Legacy forces available in `polychrom/legacy/forces.py`
42+
43+
**Data Storage** (`polychrom/hdf5_format.py`)
44+
- HDF5Reporter handles simulation output in HDF5 format
45+
- Backwards compatibility with the legacy format is provided via a legacy reporter
46+
- `polymerutils.load()` function reads both new and old formats
47+
48+
**Starting Conformations** (`polychrom/starting_conformations.py`)
49+
- Functions to generate initial polymer configurations
50+
- Example: `grow_cubic()` creates a cubic lattice conformation
51+
52+
### Key Design Patterns
53+
54+
1. **Force Architecture**: Forces are simple functions that wrap OpenMM force objects, returning the force with a `.name` attribute
55+
2. **Simulation Flow**: Initialize Simulation → Load data → Add forces → Run blocks in loop → Save via reporter
56+
3. **Extensibility**: Users can define custom forces in their scripts following the pattern in `forces.py`
57+
58+
## Important Notes
59+
60+
- OpenMM is the underlying simulation engine (a required dependency that is not listed in requirements.txt)
61+
- Cython extensions in `_polymer_math.pyx` require compilation
62+
- Main use case is loop extrusion simulations (see `examples/loopExtrusion/`)
63+
- Testing uses pytest with configuration in `pytest.ini`

polychrom/hdf5_format.py

Lines changed: 130 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -95,36 +95,64 @@
9595
import glob
9696
import os
9797
import warnings
98+
from typing import Dict, List, Tuple, Optional, Union, Any
9899

99100
import h5py
100101
import numpy as np
101102

102-
DEFAULT_OPTS = {"compression_opts": 9, "compression": "gzip"}
103+
DEFAULT_OPTS: Dict[str, Union[int, str]] = {"compression_opts": 9, "compression": "gzip"}
103104

104105

105-
def _read_h5_group(gr):
106+
def _read_h5_group(gr: h5py.Group) -> Dict[str, Any]:
106107
"""
107108
Reads all attributes of an HDF5 group, and returns a dict of them
109+
110+
Parameters
111+
----------
112+
gr : h5py.Group
113+
HDF5 group to read from
114+
115+
Returns
116+
-------
117+
Dict[str, Any]
118+
Dictionary containing all datasets and attributes from the group
108119
"""
109120
result = {i: j[:] for i, j in gr.items()}
110121
for i, j in gr.attrs.items():
111-
result[i] = j
122+
# Convert bytes to string if it's a bytes string
123+
if isinstance(j, bytes):
124+
try:
125+
result[i] = j.decode('utf-8')
126+
except UnicodeDecodeError:
127+
result[i] = j
128+
else:
129+
result[i] = j
112130
return result
113131

114132

115-
def _convert_to_hdf5_array(data):
133+
def _convert_to_hdf5_array(data: Any) -> Tuple[Optional[str], Optional[Union[np.ndarray, Any]]]:
116134
"""
117135
Attempts to convert data to HDF5 compatible array
118136
or to an HDF5 attribute compatible entity (str, number)
119137
120138
Does its best at determining if this is a "normal"
121139
object (str, int, float), or an array.
122140
123-
Right now, if something got converted to a numpy object,
141+
Right now, if something got converted to a numpy object dtype,
124142
it is discarded and not saved in any way.
125143
We could think about pickling those cases, or JSONing them...
144+
145+
Parameters
146+
----------
147+
data : Any
148+
Data to convert to HDF5-compatible format
149+
150+
Returns
151+
-------
152+
Tuple[Optional[str], Optional[Union[np.ndarray, Any]]]
153+
Tuple of (datatype, converted_data) where datatype is "item", "ndarray", or None
126154
"""
127-
if type(data) == str:
155+
if isinstance(data, str):
128156
data = np.array(data, dtype="S")
129157
data = np.array(data)
130158

@@ -137,29 +165,45 @@ def _convert_to_hdf5_array(data):
137165
return "ndarray", data
138166

139167

140-
def _write_group(dataDict, group, dset_opts=None):
168+
def _write_group(
169+
dataDict: Dict[str, Any],
170+
group: h5py.Group,
171+
dset_opts: Optional[Dict[str, Any]] = None
172+
) -> None:
141173
"""
142174
Writes a dictionary of elements to an HDF5 group
143175
Puts all "items" into attrs, and all ndarrays into datasets
144176
145-
dset_opts is a dictionary of arguments passed to create_dataset function
146-
(compression would be here for example). By default set to DEFAULT_OPTS
177+
Parameters
178+
----------
179+
dataDict : Dict[str, Any]
180+
Dictionary of data to write
181+
group : h5py.Group
182+
HDF5 group to write to
183+
dset_opts : Optional[Dict[str, Any]]
184+
Dictionary of arguments passed to create_dataset function
185+
(compression would be here for example). By default set to DEFAULT_OPTS
147186
"""
148187
if dset_opts is None:
149188
dset_opts = DEFAULT_OPTS
150189
for name, data in dataDict.items():
151190
datatype, converted = _convert_to_hdf5_array(data)
152191
if datatype is None:
153-
warnings.warn(f"Could not convert record {name}")
192+
warnings.warn(f"Could not convert record {name} of type {type(data)}")
154193
elif datatype == "item":
155-
group.attrs[name] = data
194+
group.attrs[name] = converted # Use converted instead of data
156195
elif datatype == "ndarray":
157-
group.create_dataset(name, data=data, **dset_opts)
196+
group.create_dataset(name, data=converted, **dset_opts) # Use converted
158197
else:
159-
raise ValueError("Unknown datatype")
198+
raise ValueError(f"Unknown datatype: {datatype}")
160199

161200

162-
def list_URIs(folder, empty_error=True, read_error=True, return_dict=False):
201+
def list_URIs(
202+
folder: str,
203+
empty_error: bool = True,
204+
read_error: bool = True,
205+
return_dict: bool = False
206+
) -> Union[List[str], Dict[int, str]]:
163207
"""
164208
Makes a list of URIs (path-like records for each block). for a trajectory folder
165209
Now we store multiple blocks per file, and URI is a
@@ -206,8 +250,9 @@ def list_URIs(folder, empty_error=True, read_error=True, return_dict=False):
206250
except Exception:
207251
if read_error:
208252
raise ValueError(f"Cannot read file {file}")
209-
sted = os.path.split(file)[-1].split("_")[1].split(".h5")[0]
210-
st, end = [int(i) for i in sted.split("-")]
253+
# Extract start and end block numbers from filename like "blocks_1-50.h5"
254+
filename_parts = os.path.basename(file).split("_")[1].split(".h5")[0]
255+
st, end = [int(i) for i in filename_parts.split("-")]
211256
for i in range(st, end + 1):
212257
if i in filenames:
213258
raise ValueError(f"Block {i} exists more than once")
@@ -218,49 +263,83 @@ def list_URIs(folder, empty_error=True, read_error=True, return_dict=False):
218263
return {int(i[0]): i[1] for i in sorted(filenames.items(), key=lambda x: int(x[0]))}
219264

220265

221-
def load_URI(dset_path):
266+
def load_URI(dset_path: str) -> Dict[str, Any]:
222267
"""
223-
Loads a single block of the simulation using address provided by list_filenames
224-
dset_path should be
268+
Loads a single block of the simulation using address provided by list_URIs
225269
226-
/path/to/trajectory/folder/blocks_X-Y.h5::Z
227-
228-
where Z is the block number
270+
Parameters
271+
----------
272+
dset_path : str
273+
Path in format: /path/to/trajectory/folder/blocks_X-Y.h5::Z
274+
where Z is the block number
275+
276+
Returns
277+
-------
278+
Dict[str, Any]
279+
Dictionary containing the block data
229280
"""
281+
if "::" not in dset_path:
282+
raise ValueError(f"Invalid URI format: {dset_path}. Expected format: filename.h5::block_number")
230283

231284
fname, group = dset_path.split("::")
232285
with h5py.File(fname, mode="r") as myfile:
233286
return _read_h5_group(myfile[group])
234287

235288

236-
def save_hdf5_file(filename, data_dict, dset_opts=None, mode="w"):
289+
def save_hdf5_file(
290+
filename: str,
291+
data_dict: Dict[str, Any],
292+
dset_opts: Optional[Dict[str, Any]] = None,
293+
mode: str = "w"
294+
) -> None:
237295
"""
238296
Saves data_dict to filename
297+
298+
Parameters
299+
----------
300+
filename : str
301+
Path to the HDF5 file to save
302+
data_dict : Dict[str, Any]
303+
Dictionary of data to save
304+
dset_opts : Optional[Dict[str, Any]]
305+
Options for dataset creation (e.g., compression)
306+
mode : str
307+
File opening mode (default "w")
239308
"""
240309
if dset_opts is None:
241310
dset_opts = DEFAULT_OPTS
242311
with h5py.File(filename, mode=mode) as file:
243312
_write_group(data_dict, file, dset_opts=dset_opts)
244313

245314

246-
def load_hdf5_file(fname):
315+
def load_hdf5_file(fname: str) -> Dict[str, Any]:
247316
"""
248-
Loads a saved HDF5 files, reading all datasets and attributes.
317+
Loads a saved HDF5 file, reading all datasets and attributes.
249318
We save arrays as datasets, and regular types as attributes in HDF5
319+
320+
Parameters
321+
----------
322+
fname : str
323+
Path to the HDF5 file to load
324+
325+
Returns
326+
-------
327+
Dict[str, Any]
328+
Dictionary containing all data from the file
250329
"""
251330
with h5py.File(fname, mode="r") as myfile:
252331
return _read_h5_group(myfile)
253332

254333

255-
class HDF5Reporter(object):
334+
class HDF5Reporter:
256335
def __init__(
257336
self,
258-
folder,
259-
max_data_length=50,
260-
h5py_dset_opts=None,
261-
overwrite=False,
262-
blocks_only=False,
263-
check_exists=True,
337+
folder: str,
338+
max_data_length: int = 50,
339+
h5py_dset_opts: Optional[Dict[str, Any]] = None,
340+
overwrite: bool = False,
341+
blocks_only: bool = False,
342+
check_exists: bool = True,
264343
):
265344
"""
266345
Creates a reporter object that saves a trajectory to a folder
@@ -289,23 +368,23 @@ def __init__(
289368

290369
if h5py_dset_opts is None:
291370
h5py_dset_opts = DEFAULT_OPTS
292-
self.prefixes = [
371+
self.prefixes: List[str] = [
293372
"blocks",
294373
"applied_forces",
295374
"initArgs",
296375
"starting_conformation",
297376
"energy_minimization",
298377
"forcekit_polymer_chains",
299378
] # these are used for inferring if a file belongs to a trajectory or not
300-
self.counter = {} # initializing all the options and dictionaries
301-
self.datas = {}
302-
self.max_data_length = max_data_length
303-
self.h5py_dset_opts = h5py_dset_opts
304-
self.folder = folder
305-
self.blocks_only = blocks_only
379+
self.counter: Dict[str, int] = {} # initializing all the options and dictionaries
380+
self.datas: Dict[int, Dict[str, Any]] = {}
381+
self.max_data_length: int = max_data_length
382+
self.h5py_dset_opts: Dict[str, Any] = h5py_dset_opts
383+
self.folder: str = folder
384+
self.blocks_only: bool = blocks_only
306385

307386
if not os.path.exists(folder):
308-
os.mkdir(folder)
387+
os.makedirs(folder, exist_ok=True)
309388

310389
if overwrite:
311390
for the_file in os.listdir(folder):
@@ -316,7 +395,8 @@ def __init__(
316395
os.remove(file_path)
317396
else:
318397
raise IOError(
319-
"Subfolder in traj folder; not deleting. Ensure folder is " "correct and delete manually. "
398+
f"Subfolder {file_path} in traj folder; not deleting. "
399+
"Ensure folder is correct and delete manually."
320400
)
321401

322402
if check_exists:
@@ -326,7 +406,11 @@ def __init__(
326406
if the_file.startswith(prefix):
327407
raise RuntimeError(f"folder {folder} is not empty: set check_exists=False to ignore")
328408

329-
def continue_trajectory(self, continue_from=None, continue_max_delete=5):
409+
def continue_trajectory(
410+
self,
411+
continue_from: Optional[int] = None,
412+
continue_max_delete: int = 5
413+
) -> Tuple[int, Dict[str, Any]]:
330414
"""
331415
Continues a simulation in a current folder (i.e. continues from the last block, or the block you specify).
332416
By default, takes the last block. Otherwise, takes the continue_from block
@@ -376,7 +460,7 @@ def continue_trajectory(self, continue_from=None, continue_max_delete=5):
376460

377461
todelete = np.nonzero(uri_inds >= continue_from)[0]
378462
if len(todelete) > continue_max_delete:
379-
raise ValueError("Refusing to delete {uris_delete} blocks - set continue_max_delete accordingly")
463+
raise ValueError(f"Refusing to delete {len(todelete)} blocks - set continue_max_delete accordingly")
380464

381465
fnames_delete = np.unique(uri_fnames[todelete])
382466
inds_tosave = np.nonzero((uri_fnames == uri_fnames[ind]) * (uri_inds <= ind))[0]
@@ -405,7 +489,7 @@ def continue_trajectory(self, continue_from=None, continue_max_delete=5):
405489

406490
return uri_inds[ind], newdata
407491

408-
def report(self, name, values):
492+
def report(self, name: str, values: Dict[str, Any]) -> None:
409493
"""
410494
Semi-internal method to be called when you need to report something
411495
@@ -434,7 +518,8 @@ def report(self, name, values):
434518
self.dump_data()
435519
self.counter[name] = count + 1
436520

437-
def dump_data(self):
521+
def dump_data(self) -> None:
522+
"""Writes accumulated block data to disk and clears the buffer"""
438523
if len(self.datas) > 0:
439524
cmin = min(self.datas.keys())
440525
cmax = max(self.datas.keys())

0 commit comments

Comments
 (0)