Skip to content

Commit 56d45fc

Browse files
committed
Increment version
1 parent 348b97e commit 56d45fc

File tree

4 files changed

+44
-27
lines changed

4 files changed

+44
-27
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@ repos:
1010
- id: check-yaml
1111
- id: debug-statements
1212
- id: end-of-file-fixer
13-
- id: no-commit-to-branch
14-
args: [--branch, main]
1513
- id: requirements-txt-fixer
1614
- id: trailing-whitespace
1715

@@ -20,7 +18,7 @@ repos:
2018
hooks:
2119
- id: mypy
2220
args: ["--ignore-missing-imports"]
23-
files: "(fll)"
21+
files: "(blaxbird)"
2422

2523
- repo: https://github.com/astral-sh/ruff-pre-commit
2624
rev: v0.3.0

examples/cifar10_flow_matching/main.py

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import argparse
22
import os
33

4+
import dataloader
45
import jax
56
import matplotlib.pyplot as plt
67
import numpy as np
@@ -12,7 +13,6 @@
1213
from jax.experimental import mesh_utils
1314

1415
import blaxbird
15-
import dataloader
1616
from blaxbird import get_default_checkpointer, train_fn
1717
from blaxbird.experimental import rfm
1818

@@ -38,19 +38,19 @@ def visualize_hook(sample_fn, val_iter, hook_every_n_steps, log_to_wandb):
3838

3939
def convert_batch_to_image_grid(image_batch):
4040
reshaped = (
41-
image_batch.reshape(n_row, n_col, *img_size)
42-
.transpose([0, 2, 1, 3, 4])
43-
.reshape(n_row * img_size[0], n_col * img_size[1], img_size[2])
41+
image_batch.reshape(n_row, n_col, *img_size)
42+
.transpose([0, 2, 1, 3, 4])
43+
.reshape(n_row * img_size[0], n_col * img_size[1], img_size[2])
4444
)
4545
return (reshaped + 1.0) / 2.0
4646

4747
def plot(images):
4848
fig = plt.figure(figsize=(16, 6))
4949
ax = fig.add_subplot(1, 1, 1)
5050
ax.imshow(
51-
images,
52-
interpolation="nearest",
53-
cmap="gray",
51+
images,
52+
interpolation="nearest",
53+
cmap="gray",
5454
)
5555
plt.axis("off")
5656
plt.tight_layout()
@@ -61,26 +61,31 @@ def fn(step, *, model, **kwargs):
6161
return
6262
all_samples = []
6363
for i, batch in enumerate(val_iter):
64-
samples = sample_fn(model, jr.fold_in(jr.key(step), i), sample_shape=batch["inputs"].shape)
65-
all_samples.append(samples)
66-
if len(all_samples) * all_samples[0].shape[0] >= n_row * n_col:
67-
break
68-
all_samples = np.concatenate(all_samples, axis=0)[:(n_row * n_col)]
64+
samples = sample_fn(
65+
model, jr.fold_in(jr.key(step), i), sample_shape=batch["inputs"].shape
66+
)
67+
all_samples.append(samples)
68+
if len(all_samples) * all_samples[0].shape[0] >= n_row * n_col:
69+
break
70+
all_samples = np.concatenate(all_samples, axis=0)[: (n_row * n_col)]
6971
all_samples = convert_batch_to_image_grid(all_samples)
7072
fig = plot(all_samples)
7173
if jax.process_index() == 0 and log_to_wandb:
72-
wandb.log({"images": wandb.Image(fig)}, step=step)
74+
wandb.log({"images": wandb.Image(fig)}, step=step)
7375

7476
return fn
7577

7678

77-
def get_hooks(sample_fn, val_itr, hook_every_n_steps, log_to_wandb ):
79+
def get_hooks(sample_fn, val_itr, hook_every_n_steps, log_to_wandb):
7880
return [visualize_hook(sample_fn, val_itr, hook_every_n_steps, log_to_wandb)]
7981

8082

8183
def get_train_and_val_itrs(rng_key, outfolder):
8284
return dataloader.data_loaders(
83-
rng_key, outfolder, split=["train[:90%]", "train[90%:]"], shuffle=[True, False],
85+
rng_key,
86+
outfolder,
87+
split=["train[:90%]", "train[90%:]"],
88+
shuffle=[True, False],
8489
)
8590

8691

@@ -92,14 +97,19 @@ def run(n_steps, eval_every_n_steps, n_eval_batches, dit_type, log_to_wandb):
9297
jr.key(0), os.path.join(outfolder, "data")
9398
)
9499

95-
model = getattr(blaxbird.experimental, dit_type)(image_size=(32, 32, 3), rngs=nnx.rnglib.Rngs(jr.key(1)))
100+
model = getattr(blaxbird.experimental, dit_type)(
101+
image_size=(32, 32, 3), rngs=nnx.rnglib.Rngs(jr.key(1))
102+
)
96103
train_step, val_step, sample_fn = rfm()
97104
optimizer = get_optimizer(model)
98105

99106
save_fn, _, restore_last_fn = get_default_checkpointer(
100-
os.path.join(outfolder, "checkpoints"), save_every_n_steps=eval_every_n_steps
107+
os.path.join(outfolder, "checkpoints"),
108+
save_every_n_steps=eval_every_n_steps,
101109
)
102-
hooks = get_hooks(sample_fn, val_itr, eval_every_n_steps, log_to_wandb) + [save_fn]
110+
hooks = get_hooks(sample_fn, val_itr, eval_every_n_steps, log_to_wandb) + [
111+
save_fn
112+
]
103113

104114
model_sharding, data_sharding = get_sharding()
105115
model, optimizer = restore_last_fn(model, optimizer)
@@ -121,7 +131,15 @@ def run(n_steps, eval_every_n_steps, n_eval_batches, dit_type, log_to_wandb):
121131
parser.add_argument("--n-steps", type=int, default=1_000)
122132
parser.add_argument("--eval-every-n-steps", type=int, default=50)
123133
parser.add_argument("--n-eval-batches", type=int, default=10)
124-
parser.add_argument("--dit", type=str, choices=["SmallDiT", "BaseDiT"], default="SmallDiT")
134+
parser.add_argument(
135+
"--dit", type=str, choices=["SmallDiT", "BaseDiT"], default="SmallDiT"
136+
)
125137
parser.add_argument("--log-to-wandb", action="store_true")
126138
args = parser.parse_args()
127-
run(args.n_steps, args.eval_every_n_steps, args.n_eval_batches, args.dit, args.log_to_wandb)
139+
run(
140+
args.n_steps,
141+
args.eval_every_n_steps,
142+
args.n_eval_batches,
143+
args.dit,
144+
args.log_to_wandb,
145+
)

examples/mnist_classification/main.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,10 @@
88
from flax import nnx
99
from jax import random as jr
1010
from jax.experimental import mesh_utils
11+
from model import CNN, train_step, val_step
1112

1213
from blaxbird import get_default_checkpointer, train_fn
1314

14-
from model import CNN, train_step, val_step
15-
1615

1716
def get_optimizer(model, lr=1e-4):
1817
tx = optax.adamw(lr)
@@ -83,7 +82,8 @@ def run(n_steps, eval_every_n_steps, n_eval_batches):
8382
optimizer = get_optimizer(model)
8483

8584
save_fn, _, restore_last_fn = get_default_checkpointer(
86-
os.path.join(outfolder, "checkpoints"), save_every_n_steps=eval_every_n_steps
85+
os.path.join(outfolder, "checkpoints"),
86+
save_every_n_steps=eval_every_n_steps,
8787
)
8888
hooks = get_hooks(val_itr, eval_every_n_steps) + [save_fn]
8989

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,15 @@ addopts = "-v --doctest-modules --cov=./blaxbird --cov-report=xml"
6868
[tool.ruff]
6969
indent-width = 2
7070
line-length = 80
71-
exclude = ["*_test.py", "docs/**", "examples/**"]
7271

7372
[tool.ruff.lint]
7473
select = ["D", "E", "F", "W", "I001"]
7574
extend-select = [
7675
"UP", "I", "PL", "S"
7776
]
7877
ignore = ["S101", "ANN101", "PLR2044", "PLR0913"]
78+
exclude = ["*_test.py", "docs/**", "examples/**"]
79+
7980

8081
[tool.ruff.lint.pydocstyle]
8182
convention= 'google'

0 commit comments

Comments
 (0)