Skip to content

Commit 1c60cab

Browse files
committed
test publish
1 parent 8a1fcca commit 1c60cab

File tree

3 files changed

+46
-3
lines changed

3 files changed

+46
-3
lines changed

.github/workflows/publish.yml

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
name: Publish
2+
3+
on:
4+
push:
5+
branches:
6+
- main
7+
tags:
8+
- '0.0.*'
9+
paths:
10+
- transformer_flows/**
11+
- pyproject.toml
12+
13+
jobs:
14+
pypi:
15+
name: Publish to PyPI
16+
runs-on: ubuntu-latest
17+
# Environment and permissions trusted publishing.
18+
environment:
19+
# Create this environment in the GitHub repository under Settings -> Environments
20+
name: release
21+
permissions:
22+
id-token: write
23+
steps:
24+
- uses: actions/checkout@v4
25+
- uses: astral-sh/setup-uv@v3
26+
- run: uv build
27+
# Check that basic features work and we didn't miss to include crucial files
28+
- name: Smoke test (wheel)
29+
run: uv run --isolated --no-project -p 3.13 --with dist/*.whl tests/smoke_test.py
30+
- name: Smoke test (source distribution)
31+
run: uv run --isolated --no-project -p 3.13 --with dist/*.tar.gz tests/smoke_test.py
32+
- run: uv publish --trusted-publishing always
33+
34+
github-release:
35+
name: GitHub Release
36+
needs:
37+
- pypi
38+
runs-on: ubuntu-latest
39+
40+
permissions:
41+
contents: write # IMPORTANT: mandatory for making GitHub Releases
42+
id-token: write # IMPORTANT: mandatory for sigstore

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@ data/
1010
__pycache__/
1111
tests/
1212
grfs.py
13-
guidance.py
13+
guidance.py
14+
test.yml

transformer_flow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from attention import MultiheadAttention, self_attention
2626

2727

28-
typecheck = lambda x: x #jaxtyped(typechecker=typechecker)
28+
typecheck = jaxtyped(typechecker=typechecker)
2929

3030

3131
MetricsDict = dict[
@@ -1652,7 +1652,7 @@ def train(
16521652
# Sharding: data and model
16531653
sharding: Optional[NamedSharding] = None,
16541654
replicated_sharding: Optional[NamedSharding] = None,
1655-
save_fn: Callable[[Optional[str], TransformerFlow], None]
1655+
save_fn: Callable[[Optional[str], TransformerFlow], None] = None
16561656
) -> TransformerFlow:
16571657

16581658
print("n_params={:.3E}".format(count_parameters(model)))

0 commit comments

Comments
 (0)