Skip to content

Commit c0f6dbf

Browse files
committed
Add GitHub actions
Add a GitHub action to compute code coverage
1 parent a5debe4 commit c0f6dbf

6 files changed

Lines changed: 113 additions & 32 deletions

File tree

.github/workflows/ci.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,7 @@ jobs:
7777
- name: Run tests
7878
run: |
7979
hatch run test:test
80+
- name: Upload coverage reports to Codecov
81+
uses: codecov/codecov-action@v3
82+
env:
83+
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}

README.md

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
[![status](http://www.repostatus.org/badges/latest/concept.svg)](http://www.repostatus.org/#concept)
44
[![ci](https://github.com/dirmeier/surjectors/actions/workflows/ci.yaml/badge.svg)](https://github.com/dirmeier/surjectors/actions/workflows/ci.yaml)
5+
[![codecov](https://codecov.io/gh/dirmeier/surjectors/branch/main/graph/badge.svg?token=dn1xNBSalZ)](https://codecov.io/gh/dirmeier/surjectors)
56
[![version](https://img.shields.io/pypi/v/surjectors.svg?colorB=black&style=flat)](https://pypi.org/project/surjectors/)
67

78
> Surjection layers for density estimation with normalizing flows
@@ -19,13 +20,34 @@ Surjectors makes use of
1920
- Optax for gradient-based optimization,
2021
- JAX for autodiff and XLA computation.
2122

22-
## Documentation
23+
## Examples
2324

24-
Documentation can be found [here](https://surjectors.readthedocs.io/en/latest/).
25+
You can, for instance, construct a simple normalizing flow like this:
26+
27+
```python
28+
import distrax
29+
from jax import numpy as jnp
30+
from surjectors import Slice, LULinear, Chain
31+
from surjectors import TransformedDistribution
32+
from surjectors.nn import make_mlp
33+
34+
def decoder_fn(n_dim):
35+
def _fn(z):
36+
params = make_mlp([32, 32, n_dim * 2])(z)
37+
means, log_scales = jnp.split(params, 2, -1)
38+
return distrax.Independent(distrax.Normal(means, jnp.exp(log_scales)))
39+
return _fn
40+
41+
base_distribution = distrax.Normal(jnp.zeros(5), jnp.ones(5))
42+
transform = Chain([Slice(10, decoder_fn(10)), LULinear(5)])
43+
pushforward = TransformedDistribution(base_distribution, transform)
44+
```
2545

26-
## Examples
46+
More self-contained examples can be found in [examples](https://github.com/dirmeier/surjectors/tree/main/examples).
47+
48+
## Documentation
2749

28-
You can find several self-contained examples on how to use the algorithms in `examples`.
50+
Documentation can be found [here](https://surjectors.readthedocs.io/en/latest/).
2951

3052
## Installation
3153

@@ -47,11 +69,11 @@ pip install git+https://github.com/dirmeier/surjectors@<RELEASE>
4769
## Contributing
4870

4971
Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled
50-
`"good first issue" <https://github.com/dirmeier/surjectors/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22>`_.
72+
[good first issue](https://github.com/dirmeier/surjectors/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22).
5173

5274
In order to contribute:
5375

54-
1) Clone Surjectors and install `hatch` via `pip install hatch`,
76+
1) Clone `Surjectors` and install `hatch` via `pip install hatch`,
5577
2) create a new branch locally `git checkout -b feature/my-new-feature` or `git checkout -b issue/fixes-bug`,
5678
3) implement your contribution and ideally a test case,
5779
4) test it by calling `hatch run test` on the (Unix) command line,

docs/index.rst

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,26 +14,27 @@ Surjectors makes use of
1414
- Optax for gradient-based optimization,
1515
- JAX for autodiff and XLA computation.
1616

17-
Example usage
18-
-------------
17+
Example
18+
-------
1919

2020
You can, for instance, construct a simple normalizing flow like this:
2121

2222
>>> import distrax
23-
>>> from jax import random as jr, numpy as jnp
23+
>>> from jax import numpy as jnp
2424
>>> from surjectors import Slice, LULinear, Chain
2525
>>> from surjectors import TransformedDistribution
26+
>>> from surjectors.nn import make_mlp
2627
>>>
2728
>>> def decoder_fn(n_dim):
2829
>>> def _fn(z):
29-
>>> params = make_mlp([4, 4, n_dim * 2])(z)
30-
>>> mu, log_scale = jnp.split(params, 2, -1)
31-
>>> return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale)))
30+
>>> params = make_mlp([32, 32, n_dim * 2])(z)
31+
>>> means, log_scales = jnp.split(params, 2, -1)
32+
>>> return distrax.Independent(distrax.Normal(means, jnp.exp(log_scales)))
3233
>>> return _fn
3334
>>>
34-
>>> base_distribution = distrax.Normal(jno.zeros(5), jnp.ones(1))
35-
>>> flow = Chain([Slice(10, decoder_fn(10)), LULinear(5)])
36-
>>> pushforward = TransformedDistribution(base_distribution, flow)
35+
>>> base_distribution = distrax.Normal(jnp.zeros(5), jnp.ones(1))
36+
>>> transform = Chain([Slice(10, decoder_fn(10)), LULinear(5)])
37+
>>> pushforward = TransformedDistribution(base_distribution, transform)
3738

3839
The flow is constructed using three objects: a base distribution, a transformation, and a transformed distribution.
3940

@@ -64,7 +65,7 @@ Contributions in the form of pull requests are more than welcome. A good way to
6465

6566
In order to contribute:
6667

67-
1) Clone Surjectors and install :code:`hatch` via :code:`pip install hatch`,
68+
1) Clone :code:`Surjectors` and install :code:`hatch` via :code:`pip install hatch`,
6869
2) create a new branch locally :code:`git checkout -b feature/my-new-feature` or :code:`git checkout -b issue/fixes-bug`,
6970
3) implement your contribution and ideally a test case,
7071
4) test it by calling :code:`hatch run test` on the (Unix) command line,

docs/notebooks/introduction.ipynb

Lines changed: 68 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
{
44
"cell_type": "markdown",
55
"id": "4f6f4229-6b15-4e2b-89af-8957708479d7",
6-
"metadata": {},
6+
"metadata": {
7+
"pycharm": {
8+
"name": "#%% md\n"
9+
}
10+
},
711
"source": [
812
"# Constructing normalizing flows\n",
913
"\n",
@@ -15,7 +19,10 @@
1519
"execution_count": 4,
1620
"id": "9497e202-3f4e-4602-9e90-669545f18816",
1721
"metadata": {
18-
"tags": []
22+
"tags": [],
23+
"pycharm": {
24+
"name": "#%%\n"
25+
}
1926
},
2027
"outputs": [],
2128
"source": [
@@ -28,7 +35,11 @@
2835
{
2936
"cell_type": "markdown",
3037
"id": "bfd96e3f-ef84-454f-a611-0cdc0a629d2d",
31-
"metadata": {},
38+
"metadata": {
39+
"pycharm": {
40+
"name": "#%% md\n"
41+
}
42+
},
3243
"source": [
3344
"## How to construct a Haiku module\n",
3445
"\n",
@@ -40,7 +51,10 @@
4051
"execution_count": 30,
4152
"id": "c9db6163-6cb8-4968-8e3d-90d35a5094cf",
4253
"metadata": {
43-
"tags": []
54+
"tags": [],
55+
"pycharm": {
56+
"name": "#%%\n"
57+
}
4458
},
4559
"outputs": [],
4660
"source": [
@@ -53,7 +67,10 @@
5367
"execution_count": 61,
5468
"id": "f78c2fc0-f4f6-476d-8b15-12b4a874ce1d",
5569
"metadata": {
56-
"tags": []
70+
"tags": [],
71+
"pycharm": {
72+
"name": "#%%\n"
73+
}
5774
},
5875
"outputs": [],
5976
"source": [
@@ -88,7 +105,10 @@
88105
"cell_type": "markdown",
89106
"id": "b41e89a1-6df4-4069-a606-4f2412b9dc6a",
90107
"metadata": {
91-
"tags": []
108+
"tags": [],
109+
"pycharm": {
110+
"name": "#%% md\n"
111+
}
92112
},
93113
"source": [
94114
"Constructing a Haiku module needs to be done within a `hk.transform` block. This can either be done by providing a function like here and an object. In our case we are using `hk.transform` on `pushforward(**kwargs)` which calls\n",
@@ -98,7 +118,11 @@
98118
{
99119
"cell_type": "markdown",
100120
"id": "689f867c-2259-4713-81b6-5352837cb342",
101-
"metadata": {},
121+
"metadata": {
122+
"pycharm": {
123+
"name": "#%% md\n"
124+
}
125+
},
102126
"source": [
103127
"We can now initialize the flow. Let's define a random data set first and then initialize the parameters."
104128
]
@@ -108,7 +132,10 @@
108132
"execution_count": 65,
109133
"id": "cf91f2ce-6ba9-438c-95a7-ec9dcfbfea17",
110134
"metadata": {
111-
"tags": []
135+
"tags": [],
136+
"pycharm": {
137+
"name": "#%%\n"
138+
}
112139
},
113140
"outputs": [
114141
{
@@ -150,7 +177,11 @@
150177
{
151178
"cell_type": "markdown",
152179
"id": "4c8f8b8e-4e81-4b54-a1c4-72027db33e1c",
153-
"metadata": {},
180+
"metadata": {
181+
"pycharm": {
182+
"name": "#%% md\n"
183+
}
184+
},
154185
"source": [
155186
"The only trainable paramaters that are flow defines are the weights of the MLP. The MLP is used to compute the conditional probability density inside the `decoder_fn` function. \n",
156187
"The `Slice` surjector itself doesn't have paramters."
@@ -159,7 +190,11 @@
159190
{
160191
"cell_type": "markdown",
161192
"id": "84d8368a-2c75-4e72-a3b4-45f4cfc71fbf",
162-
"metadata": {},
193+
"metadata": {
194+
"pycharm": {
195+
"name": "#%% md\n"
196+
}
197+
},
163198
"source": [
164199
"We can now test the flow. Let's sample some data first."
165200
]
@@ -169,7 +204,10 @@
169204
"execution_count": 71,
170205
"id": "77ae282d-9e30-49c9-98b1-6f9494b34a21",
171206
"metadata": {
172-
"tags": []
207+
"tags": [],
208+
"pycharm": {
209+
"name": "#%%\n"
210+
}
173211
},
174212
"outputs": [
175213
{
@@ -196,7 +234,11 @@
196234
{
197235
"cell_type": "markdown",
198236
"id": "04fd204d-f0c2-468b-a64d-ec292b071f34",
199-
"metadata": {},
237+
"metadata": {
238+
"pycharm": {
239+
"name": "#%% md\n"
240+
}
241+
},
200242
"source": [
201243
"As mentioned above, in order to dispatch to a method, we just provide a keyword argument. In this case this is `method='sample'`. Computing the log probability of the data can be done, by changing the method argument to `log_prob`."
202244
]
@@ -206,7 +248,10 @@
206248
"execution_count": 72,
207249
"id": "2190c369-4242-45ec-9ac9-b9fa5bccc1dd",
208250
"metadata": {
209-
"tags": []
251+
"tags": [],
252+
"pycharm": {
253+
"name": "#%%\n"
254+
}
210255
},
211256
"outputs": [
212257
{
@@ -227,7 +272,11 @@
227272
{
228273
"cell_type": "markdown",
229274
"id": "8c046ffd-a2b9-4348-b8f3-85e978b7998f",
230-
"metadata": {},
275+
"metadata": {
276+
"pycharm": {
277+
"name": "#%% md\n"
278+
}
279+
},
231280
"source": [
232281
"## How to construct `TransformedDistribution` objects\n",
233282
"\n",
@@ -238,7 +287,11 @@
238287
"cell_type": "code",
239288
"execution_count": null,
240289
"id": "10c16d0c-0158-4845-b192-dd97967d7eb6",
241-
"metadata": {},
290+
"metadata": {
291+
"pycharm": {
292+
"name": "#%%\n"
293+
}
294+
},
242295
"outputs": [],
243296
"source": []
244297
}

docs/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ seaborn
77
sphinx
88
sphinx-autobuild
99
sphinx-book-theme
10+
sphinx-copybutton
1011
sphinx-math-dollar
1112
sphinx_autodoc_typehints
1213
sphinx_design

surjectors/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
surjectors: Surjection layers for density estimation with normalizing flows
33
"""
44

5-
__version__ = "0.2.4"
5+
__version__ = "0.3.0"
66

77
from surjectors._src.bijectors.lu_linear import LULinear
88
from surjectors._src.bijectors.masked_autoregressive import MaskedAutoregressive

0 commit comments

Comments
 (0)