Skip to content

Commit 93fa8c3

Browse files
committed
Improve type annotations
1 parent c2de50a commit 93fa8c3

38 files changed

Lines changed: 156 additions & 111 deletions

examples/31_boids.ipynb

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,15 @@
8787
"num_boids = 256\n",
8888
"dt = 0.01\n",
8989
"\n",
90+
"acceleration_max = jnp.inf\n",
91+
"acceleration_scale = 1.0\n",
92+
"perception = 0.1\n",
93+
"separation_distance = 0.025\n",
94+
"\n",
9095
"separation_weight = 4.5\n",
9196
"alignment_weight = 0.65\n",
9297
"cohesion_weight = 0.75\n",
93-
"perception = 0.1\n",
94-
"separation_distance = 0.025\n",
95-
"acceleration_scale = 1.0\n",
9698
"noise_scale = 0.1\n",
97-
"acceleration_max = jnp.inf\n",
9899
"\n",
99100
"key = jax.random.key(seed)\n",
100101
"rngs = nnx.Rngs(seed)"
@@ -114,14 +115,14 @@
114115
"outputs": [],
115116
"source": [
116117
"boid_policy = BoidPolicy(\n",
118+
"\tacceleration_max=acceleration_max,\n",
119+
"\tacceleration_scale=acceleration_scale,\n",
120+
"\tperception=perception,\n",
121+
"\tseparation_distance=separation_distance,\n",
117122
"\tseparation_weight=separation_weight,\n",
118123
"\talignment_weight=alignment_weight,\n",
119124
"\tcohesion_weight=cohesion_weight,\n",
120-
"\tperception=perception,\n",
121-
"\tseparation_distance=separation_distance,\n",
122-
"\tacceleration_scale=acceleration_scale,\n",
123125
"\tnoise_scale=noise_scale,\n",
124-
"\tacceleration_max=acceleration_max,\n",
125126
"\trngs=rngs,\n",
126127
")\n",
127128
"\n",
@@ -226,20 +227,21 @@
226227
"\n",
227228
"\tdef __init__(\n",
228229
"\t\tself,\n",
229-
"\t\t*,\n",
230230
"\t\tnum_neighbors: int = 16, # Number of neighbors to consider\n",
231-
"\t\tperception: float = 0.1, # Perception radius\n",
232231
"\t\thidden_features: int = 8, # Hidden layer size from reference\n",
232+
"\t\t*,\n",
233+
"\t\tacceleration_max: float = jnp.inf,\n",
233234
"\t\tacceleration_scale: float = 10.0, # Scaling factor from reference\n",
235+
"\t\tperception: float = 0.1, # Perception radius\n",
234236
"\t\trngs: nnx.Rngs,\n",
235237
"\t):\n",
236-
"\t\t\"\"\"Initialize the boid policy.\"\"\"\n",
237-
"\t\tself.rngs = rngs\n",
238+
"\t\t\"\"\"Initialize boid policy.\"\"\"\n",
238239
"\t\tself.num_neighbors = num_neighbors\n",
239-
"\t\tself.perception = perception\n",
240+
"\t\tself.acceleration_max = acceleration_max\n",
240241
"\t\tself.acceleration_scale = acceleration_scale\n",
242+
"\t\tself.perception = perception\n",
243+
"\t\tself.rngs = rngs\n",
241244
"\n",
242-
"\t\t# Define the neural network layers similar to BoidNetwork\n",
243245
"\t\tself.dense1 = nnx.Linear(4, hidden_features, rngs=rngs)\n",
244246
"\t\tself.dense2 = nnx.Linear(hidden_features, hidden_features, rngs=rngs)\n",
245247
"\t\tself.dense3 = nnx.Linear(hidden_features, hidden_features, rngs=rngs)\n",
@@ -269,6 +271,11 @@
269271
"\n",
270272
"\t\treturn global2local, local2global, global2local_rot, local2global_rot\n",
271273
"\n",
274+
"\tdef _clip_by_norm(self, vector: jax.Array, max_val: float) -> jax.Array:\n",
275+
"\t\t\"\"\"Limit the magnitude of a vector.\"\"\"\n",
276+
"\t\tnorm = jnp.linalg.norm(vector)\n",
277+
"\t\treturn jnp.where(norm > max_val, vector * max_val / norm, vector)\n",
278+
"\n",
272279
"\tdef __call__(self, state: BoidsState, boid_idx: int) -> jax.Array:\n",
273280
"\t\t\"\"\"Compute acceleration for a boid based on its neighbors.\n",
274281
"\n",
@@ -341,9 +348,12 @@
341348
"\n",
342349
"\t\t# Transform back to global frame\n",
343350
"\t\tdv_hom = jnp.concatenate([dv_local, jnp.zeros(1)], axis=-1) # 3D homogeneous\n",
344-
"\t\tdv_global = (l2gr @ dv_hom[:, None])[:2, 0] # Back to 2D global coords\n",
351+
"\t\tacceleration = (l2gr @ dv_hom[:, None])[:2, 0] # Back to 2D global coords\n",
352+
"\n",
353+
"\t\t# Limit acceleration\n",
354+
"\t\tacceleration = self._clip_by_norm(acceleration, self.acceleration_max)\n",
345355
"\n",
346-
"\t\treturn dv_global"
356+
"\t\treturn acceleration"
347357
]
348358
},
349359
{

src/cax/core/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
"""Core Module."""
22

3-
from .cs import ComplexSystem
3+
from .cs import ComplexSystem, Input, State
44

5-
__all__ = ["ComplexSystem"]
5+
__all__ = [
6+
"ComplexSystem",
7+
"Input",
8+
"State",
9+
]

src/cax/core/perceive/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,15 @@
1313
from .conv_perceive import ConvPerceive
1414
from .kernels import grad2_kernel, grad_kernel, identity_kernel, neighbors_kernel
1515
from .moore_perceive import MoorePerceive
16-
from .perceive import Perceive
16+
from .perceive import Perceive, Perception
1717
from .von_neumann_perceive import VonNeumannPerceive
1818

1919
__all__ = [
20-
"ConvPerceive",
21-
"MoorePerceive",
2220
"Perceive",
21+
"Perception",
22+
"MoorePerceive",
2323
"VonNeumannPerceive",
24+
"ConvPerceive",
2425
"identity_kernel",
2526
"neighbors_kernel",
2627
"grad_kernel",

src/cax/core/perceive/conv_perceive.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44

55
from flax import nnx
66

7-
from cax.types import Perception, State
7+
from cax.core import State
88

9-
from .perceive import Perceive
9+
from .perceive import Perceive, Perception
1010

1111

1212
class ConvPerceive(Perceive):

src/cax/core/perceive/moore_perceive.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44

55
import jax.numpy as jnp
66

7-
from cax.types import Perception, State
7+
from cax.core import State
88

9-
from .perceive import Perceive
9+
from .perceive import Perceive, Perception
1010

1111

1212
class MoorePerceive(Perceive):

src/cax/core/perceive/perceive.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,14 @@
66
that process the current state of the environment or system and return a perception.
77
"""
88

9-
from flax import nnx
9+
from flax import nnx, struct
1010

11-
from cax.types import Perception, State
11+
from cax.core import State
12+
13+
14+
@struct.dataclass
15+
class Perception:
16+
"""Perception class."""
1217

1318

1419
class Perceive(nnx.Module):

src/cax/core/perceive/von_neumann_perceive.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44

55
import jax.numpy as jnp
66

7-
from cax.types import Perception, State
7+
from cax.core import State
88

9-
from .perceive import Perceive
9+
from .perceive import Perceive, Perception
1010

1111

1212
class VonNeumannPerceive(Perceive):

src/cax/core/update/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,9 @@
1616
from .residual_update import ResidualUpdate
1717
from .update import Update
1818

19-
__all__ = ["Update", "MLPUpdate", "NCAUpdate", "ResidualUpdate"]
19+
__all__ = [
20+
"Update",
21+
"MLPUpdate",
22+
"ResidualUpdate",
23+
"NCAUpdate",
24+
]

src/cax/core/update/mlp_update.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77
from flax.nnx.nn import initializers
88
from flax.nnx.nn.linear import default_kernel_init
99

10-
from cax.core.update.update import Update
11-
from cax.types import Input, Perception, State
10+
from cax.core import Input, State
11+
from cax.core.perceive import Perception
12+
13+
from .update import Update
1214

1315

1416
class MLPUpdate(Update):

src/cax/core/update/nca_update.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
from flax import nnx
77
from jax import Array
88

9-
from cax.core.update.residual_update import ResidualUpdate
10-
from cax.types import Input, Perception, State
9+
from cax.core import Input, State
10+
from cax.core.perceive import Perception
11+
12+
from .residual_update import ResidualUpdate
1113

1214

1315
class NCAUpdate(ResidualUpdate):

0 commit comments

Comments
 (0)