|
87 | 87 | "num_boids = 256\n", |
88 | 88 | "dt = 0.01\n", |
89 | 89 | "\n", |
| 90 | + "acceleration_max = jnp.inf\n", |
| 91 | + "acceleration_scale = 1.0\n", |
| 92 | + "perception = 0.1\n", |
| 93 | + "separation_distance = 0.025\n", |
| 94 | + "\n", |
90 | 95 | "separation_weight = 4.5\n", |
91 | 96 | "alignment_weight = 0.65\n", |
92 | 97 | "cohesion_weight = 0.75\n", |
93 | | - "perception = 0.1\n", |
94 | | - "separation_distance = 0.025\n", |
95 | | - "acceleration_scale = 1.0\n", |
96 | 98 | "noise_scale = 0.1\n", |
97 | | - "acceleration_max = jnp.inf\n", |
98 | 99 | "\n", |
99 | 100 | "key = jax.random.key(seed)\n", |
100 | 101 | "rngs = nnx.Rngs(seed)" |
|
114 | 115 | "outputs": [], |
115 | 116 | "source": [ |
116 | 117 | "boid_policy = BoidPolicy(\n", |
| 118 | + "\tacceleration_max=acceleration_max,\n", |
| 119 | + "\tacceleration_scale=acceleration_scale,\n", |
| 120 | + "\tperception=perception,\n", |
| 121 | + "\tseparation_distance=separation_distance,\n", |
117 | 122 | "\tseparation_weight=separation_weight,\n", |
118 | 123 | "\talignment_weight=alignment_weight,\n", |
119 | 124 | "\tcohesion_weight=cohesion_weight,\n", |
120 | | - "\tperception=perception,\n", |
121 | | - "\tseparation_distance=separation_distance,\n", |
122 | | - "\tacceleration_scale=acceleration_scale,\n", |
123 | 125 | "\tnoise_scale=noise_scale,\n", |
124 | | - "\tacceleration_max=acceleration_max,\n", |
125 | 126 | "\trngs=rngs,\n", |
126 | 127 | ")\n", |
127 | 128 | "\n", |
|
226 | 227 | "\n", |
227 | 228 | "\tdef __init__(\n", |
228 | 229 | "\t\tself,\n", |
229 | | - "\t\t*,\n", |
230 | 230 | "\t\tnum_neighbors: int = 16, # Number of neighbors to consider\n", |
231 | | - "\t\tperception: float = 0.1, # Perception radius\n", |
232 | 231 | "\t\thidden_features: int = 8, # Hidden layer size from reference\n", |
| 232 | + "\t\t*,\n", |
| 233 | + "\t\tacceleration_max: float = jnp.inf,\n", |
233 | 234 | "\t\tacceleration_scale: float = 10.0, # Scaling factor from reference\n", |
| 235 | + "\t\tperception: float = 0.1, # Perception radius\n", |
234 | 236 | "\t\trngs: nnx.Rngs,\n", |
235 | 237 | "\t):\n", |
236 | | - "\t\t\"\"\"Initialize the boid policy.\"\"\"\n", |
237 | | - "\t\tself.rngs = rngs\n", |
| 238 | + "\t\t\"\"\"Initialize boid policy.\"\"\"\n", |
238 | 239 | "\t\tself.num_neighbors = num_neighbors\n", |
239 | | - "\t\tself.perception = perception\n", |
| 240 | + "\t\tself.acceleration_max = acceleration_max\n", |
240 | 241 | "\t\tself.acceleration_scale = acceleration_scale\n", |
| 242 | + "\t\tself.perception = perception\n", |
| 243 | + "\t\tself.rngs = rngs\n", |
241 | 244 | "\n", |
242 | | - "\t\t# Define the neural network layers similar to BoidNetwork\n", |
243 | 245 | "\t\tself.dense1 = nnx.Linear(4, hidden_features, rngs=rngs)\n", |
244 | 246 | "\t\tself.dense2 = nnx.Linear(hidden_features, hidden_features, rngs=rngs)\n", |
245 | 247 | "\t\tself.dense3 = nnx.Linear(hidden_features, hidden_features, rngs=rngs)\n", |
|
269 | 271 | "\n", |
270 | 272 | "\t\treturn global2local, local2global, global2local_rot, local2global_rot\n", |
271 | 273 | "\n", |
| 274 | + "\tdef _clip_by_norm(self, vector: jax.Array, max_val: float) -> jax.Array:\n", |
| 275 | + "\t\t\"\"\"Limit the magnitude of a vector.\"\"\"\n", |
| 276 | + "\t\tnorm = jnp.linalg.norm(vector)\n", |
| 277 | + "\t\treturn jnp.where(norm > max_val, vector * max_val / norm, vector)\n", |
| 278 | + "\n", |
272 | 279 | "\tdef __call__(self, state: BoidsState, boid_idx: int) -> jax.Array:\n", |
273 | 280 | "\t\t\"\"\"Compute acceleration for a boid based on its neighbors.\n", |
274 | 281 | "\n", |
|
341 | 348 | "\n", |
342 | 349 | "\t\t# Transform back to global frame\n", |
343 | 350 | "\t\tdv_hom = jnp.concatenate([dv_local, jnp.zeros(1)], axis=-1) # 3D homogeneous\n", |
344 | | - "\t\tdv_global = (l2gr @ dv_hom[:, None])[:2, 0] # Back to 2D global coords\n", |
| 351 | + "\t\tacceleration = (l2gr @ dv_hom[:, None])[:2, 0] # Back to 2D global coords\n", |
| 352 | + "\n", |
| 353 | + "\t\t# Limit acceleration\n", |
| 354 | + "\t\tacceleration = self._clip_by_norm(acceleration, self.acceleration_max)\n", |
345 | 355 | "\n", |
346 | | - "\t\treturn dv_global" |
| 356 | + "\t\treturn acceleration" |
347 | 357 | ] |
348 | 358 | }, |
349 | 359 | { |
|
0 commit comments