Skip to content

Commit a8ed43b

Browse files
authored
Merge pull request #10 from andrewklayk/main
A lovely piece of work!
2 parents 83f3b34 + f9547a1 commit a8ed43b

122 files changed

Lines changed: 11763 additions & 17339 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.gitignore

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,22 @@
11
# Dataset and saved models
2-
32
experiments/utils/raw_data/
43
experiments/utils/exp_results
54
experiments/utils/saved_models
65
experiments/outputs
76
experiments/conf
8-
data/
7+
benchmark/data
98
examples/data
109
experiments/data
10+
.hydra
1111
.vscode/
1212
plots/
1313
outputs/
14-
14+
rci_jupyter_setup.txt
15+
requirements_rci.txt
16+
*.csv
17+
benchmark/results
18+
benchmark/cache
19+
benchmark/data
1520

1621
# Byte-compiled / optimized / DLL files
1722
__pycache__/

README.md

Lines changed: 72 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,8 @@ The toolkit implements algorithms for constrained training of neural networks ba
1010
1. [Basic installation instructions](#basic-installation-instructions)
1111
2. [Using the toolkit](#using-the-toolkit)
1212
3. [Extending the toolkit](#extending-the-toolkit)
13-
4. [Reproducing the Benchmark](#reproducing-the-benchmark)
14-
5. [License and terms of use](#license-and-terms-of-use)
15-
6. [References](#references)
13+
4. [License and terms of use](#license-and-terms-of-use)
14+
5. [References](#references)
1615

1716
humancompatible-train is still under active development! If you find bugs or have feature
1817
requests, please file a
@@ -32,28 +31,32 @@ The only dependencies of this package are `numpy` and `torch`.
3231

3332
The toolkit implements algorithms for constrained training of neural networks based on PyTorch.
3433

35-
The algorithms follow the `dual_step()` - `step()` framework: taking inspiration from PyTorch, the `dual_step()` does updates related to the dual parameters and prepares for the primal update (by, e.g., saving constraint gradients), and `step()` updates the primal parameters.
34+
The algorithms are intended for use in tandem with classic PyTorch optimizers, calculating the Lagrangian and keeping track of the dual variables.
35+
36+
<!-- The algorithms follow the `dual_step()` - `step()` framework: taking inspiration from PyTorch, the `dual_step()` does updates related to the dual parameters and prepares for the primal update (by, e.g., saving constraint gradients), and `step()` updates the primal parameters. -->
3637

3738
In general, your code using `humancompatible-train` would look something like this:
3839

3940
```python
41+
optimizer = torch.optim.Adam(model.parameters(), ...)
42+
dual_optimizer = humancompatible.train.dual_optim.ALM(...)
43+
4044
for inputs, labels in dataloader:
41-
# inference
45+
# evaluate objective
4246
outputs = model(inputs)
43-
# calculate constraints and grads
44-
for constraint in constraints:
45-
c_eval = constraint(outputs, labels)
46-
c_eval.backwards(retain_grad=True)
47-
# depending on optimizer, update dual parameters / save constraint gradient / both
48-
optimizer.dual_step(c_eval)
49-
optimizer.zero_grad()
50-
# calculate objective
51-
loss = criterion(outputs,labels)
52-
loss.backwards()
47+
loss = criterion(outputs, labels)
48+
# evaluate tensor of constraints
49+
constraints = <eval_your_constraints>(inputs, labels)
50+
# evaluate lagrangian and update dual variables
51+
lgr = dual_optimizer.forward_update(loss, constraints)
52+
# backward pass and step
53+
lgr.backward()
5354
optimizer.step()
5455
optimizer.zero_grad()
5556
```
5657

58+
The key difference is calculating the lagrangian using **`lgr = forward_update(loss, constraints)`**, and then running **`lgr.backward()`** instead of `loss.backward()`.
59+
5760
Our idea is to
5861

5962
1. Deviate minimally from the usual PyTorch workflow
@@ -63,22 +66,70 @@ Our idea is to
6366

6467
You are invited to check out our new API presented in notebooks in the `examples` folder.
6568

66-
The example notebooks have additional dependencies, such as `fairret`. To install those, run
69+
The example notebooks have additional dependencies for data and plotting, such as `fairret`. To install those, run
6770

6871
```
6972
pip install humancompatible-train[examples]
7073
```
7174

72-
*The legacy API used for the benchmark is presented in `examples/_old_/algorithm_demo.ipynb` and `examples/_old_/constraint_demo.ipynb`.*
73-
7475
## Extending the toolkit
7576

76-
### Adding new code
77-
7877
**To add a new algorithm**, you can subclass the PyTorch ```Optimizer``` class and proceed following the API guideline presented above.
7978

8079
## Reproducing the Benchmark
8180

81+
The code for benchmarking constrained regularization algorithms is available in the `benchmark` directory.
82+
83+
### Installation instructions
84+
85+
1. Create a virtual environment
86+
87+
**bash** (Linux)
88+
89+
```
90+
python3.11 -m venv fairbenchenv
91+
source fairbenchenv/bin/activate
92+
```
93+
94+
**cmd** (Windows)
95+
96+
```
97+
python -m venv fairbenchenv
98+
fairbenchenv\Scripts\activate.bat
99+
```
100+
101+
2. Install from source.
102+
103+
```
104+
git clone https://github.com/humancompatible/train.git
105+
cd train
106+
pip install -r requirements.txt
107+
pip install .
108+
```
109+
110+
### Usage instructions
111+
112+
The benchmark offers two families of datasets: Folktables and Dutch, several pre-defined constraints, and several constrained optimization algorithms: `ALM` (smoothed and non-smoothed), `SPBM`, and Switching Subgradient; we are currently working to add Stochastic Ghost within the new framework as well.
113+
114+
To run an experiment, run:
115+
116+
```
117+
python run_benchmark.py --dataset <DATASET> [folktables, dutch] --task <TYPE OF CONSTRAINT> [loss, equalized_odds_pairwise, equalized_odds_vec, weight_norm] --n_runs <NUMBER OF RUNS OF EACH METHOD> --n_epochs <NUMBER OF EPOCHS PER RUN>
118+
```
119+
120+
The constraint options are:
121+
122+
- `loss`: constraint(s) on the absolute difference between the classification loss on each group and the overall classification loss;
123+
- `equalized_odds_pairwise`: constraint(s) on the absolute difference between the positive rate between each group;
124+
- `equalized_odds_vec`: constraint on the Positive Rate of each group as defined by `fairret.NormLoss`;
125+
- `weight_norm`: constraint on the Frobenius norm of the weights and biases of each layer of the neural network.
126+
127+
The benchmarking code (all of which is contained in the `benchmark` directory) is easy to parse and extend with other datasets and constraints.
128+
129+
130+
<!--
131+
## Reproducing the Benchmark
132+
82133
The code used in [our benchmark paper](https://arxiv.org/abs/2507.04033) is not migrated to the new API yet (WIP).
83134
84135
### Basic installation instructions
@@ -122,11 +173,6 @@ pip install -e .
122173
123174
after installing requirements.txt; otherwise, the algorithm will run slower. However, this is not supported on MacOS and may fail on some Windows devices.
124175
125-
<!-- Install via pip -->
126-
<!-- ``` -->
127-
<!-- pip install folktables -->
128-
<!-- ``` -->
129-
130176
### Running the algorithms
131177
132178
The benchmark comprises the following algorithms:
@@ -149,19 +195,6 @@ python run_folktables.py data=folktables alg=fairret # baseline, fairness with r
149195
150196
Each command will start 10 runs of the `alg`, 30 seconds each.
151197
The results will be saved to `experiments/utils/saved_models` and `experiments/utils/exp_results`.
152-
<!-- In the repository, we include the configuration needed to reproduce the experiments in the paper. To do so, go to `experiments` and run `python run_folktables.py data=folktables alg=sslalm`. -->
153-
<!-- Repeat for the other algorithms by changing the `alg` parameter. -->
154-
155-
This repository uses [Hydra](https://hydra.cc/) to manage parameters; see `experiments/conf` for configuration files.
156-
157-
- To change the parameters of the experiment, such as the number of runs for each algorithm, run time, the dataset used (*note: for now supports only Folktables*) - use `experiment.yaml`.
158-
- To change the dataset settings - such as file location - or do dataset-specific adjustments - such as the configuration of the protected attributes - use `data/{dataset_name}.yaml`
159-
- To change algorithm hyperparameters, use `alg/{algorithm_name}.yaml`.
160-
- To change constraint hyperparameters, use `constraint/{constraint_name}.yaml`
161-
162-
<!-- ; it is installed as one of the dependencies. -->
163-
<!-- To learn more about using Hydra, please check out the [official tutorial](https://hydra.cc/docs/tutorials/basic/your_first_app). -->
164-
165198
### Producing plots
166199
167200
The plots and tables like the ones in the paper can be produced using the two notebooks. `experiments/algo_plots.ipynb` houses the convergence plots, and `experiments/model_plots.ipynb` - all the others.
@@ -176,25 +209,14 @@ It provides code to download data from the American Community Survey
176209
The data itself is governed by the terms of use provided by the Census Bureau.
177210
For more information, see <https://www.census.gov/data/developers/about/terms-of-service.html>
178211
179-
<!-- ## Cite this work -->
212+
-->
180213

181-
<!-- If you use this work, we encourage you to cite our paper, and the folktables dataset [[1]](#1). -->
182214

183-
<!-- ``` -->
184-
<!-- @article{ding2021retiring, -->
185-
<!-- title={Retiring Adult: New Datasets for Fair Machine Learning}, -->
186-
<!-- author={Ding, Frances and Hardt, Moritz and Miller, John and Schmidt, Ludwig}, -->
187-
<!-- journal={Advances in Neural Information Processing Systems}, -->
188-
<!-- volume={34}, -->
189-
<!-- year={2021} -->
190-
<!-- } -->
191-
<!-- ``` -->
192215

193216
## Future work
194217

195218
- Add more algorithms
196219
- Add more examples from different fields where constrained training of DNNs is employed
197-
- Migrate the benchmark to the new API
198220

199221
## References
200222

0 commit comments

Comments
 (0)