44 "metadata" : {
55 "colab" : {
66 "provenance" : [],
7- "authorship_tag" : " ABX9TyO/VV28tLyXQmzWWnHKkzuo " ,
7+ "authorship_tag" : " ABX9TyOKcfrMg1MLg77aTV44y/4b " ,
88 "include_colab_link" : true
99 },
1010 "kernelspec" : {
44944494 "name" : " stderr" ,
44954495 "text" : [
44964496 " \n " ,
4497- " 99 %|█████████▊| 994 /1009 [29:52<00:27 , 1.80s /it]\u001b [A"
4497+ " 52 %|█████▏ | 523 /1009 [14:53<13:28 , 1.66s /it]\u001b [A"
44984498 ]
44994499 },
45004500 {
45014501 "output_type" : " stream" ,
45024502 "name" : " stdout" ,
45034503 "text" : [
4504- " loss: 1486.7937 <<< \n "
4504+ " loss: 1168.78 <<< \n "
45054505 ]
45064506 }
45074507 ]
45834583 },
45844584 "execution_count" : null ,
45854585 "outputs" : []
4586+ },
4587+ {
4588+ "cell_type" : " markdown" ,
4589+ "source" : [
4590+ " **ensemble**"
4591+ ],
4592+ "metadata" : {
4593+ "id" : " ka3_468K9xRZ"
4594+ }
4595+ },
4596+ {
4597+ "cell_type" : " code" ,
4598+ "source" : [
4599+ " #✅\n " ,
4600+ " !python -m pip install -q -U flax\n " ,
4601+ " import optax\n " ,
4602+ " from flax.training import train_state\n " ,
4603+ " import jax.numpy as jnp\n " ,
4604+ " import jax\n " ,
4605+ " \n " ,
4606+ " @functools.partial(jax.pmap, static_broadcasted_argnums=(1, 2))\n " ,
4607+ " def Create_train_state(r_key, shape, learning_rate ):\n " ,
4608+ " print(shape)\n " ,
4609+ " model = MLPModel()\n " ,
4610+ " variables = model.init(r_key, jnp.ones(shape)) \n " ,
4611+ " optimizer = optax.adam(learning_rate) \n " ,
4612+ " return train_state.TrainState.create(\n " ,
4613+ " apply_fn = model.apply,\n " ,
4614+ " tx=optimizer,\n " ,
4615+ " params=variables['params']\n " ,
4616+ " )\n " ,
4617+ " \n " ,
4618+ " learning_rate = 1e-4\n " ,
4619+ " batch_size_no = 64\n " ,
4620+ " \n " ,
4621+ " model = MLPModel() # Instantiate the Model"
4622+ ],
4623+ "metadata" : {
4624+ "id" : " Aaat0R0q9Z7F"
4625+ },
4626+ "execution_count" : null ,
4627+ "outputs" : []
4628+ },
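4629+ {
4630+ "cell_type" : " markdown" ,
4631+ "source" : [
4632+ " A quick sanity check (our addition, not in the original notebook): because `create_train_state` is pmapped, every leaf of `state.params` should carry a leading axis of size `jax.device_count()`, i.e. one independent set of weights per ensemble member. The dummy input shape below is a placeholder."
4633+ ],
4634+ "metadata" : {}
4635+ },
4636+ {
4637+ "cell_type" : " code" ,
4638+ "source" : [
4639+ " # Sketch: confirm the pmapped init produced one replica per device.\n " ,
4640+ " n_devices = jax.device_count()\n " ,
4641+ " keys = jax.random.split(jax.random.PRNGKey(0), n_devices)\n " ,
4642+ " dummy_shape = (1024, 3) # placeholder (H*W, C); the real shape comes from the data pipeline\n " ,
4643+ " demo_state = create_train_state(keys, dummy_shape, learning_rate)\n " ,
4644+ " # Every leaf should have a leading axis of length n_devices.\n " ,
4645+ " print(jax.tree_util.tree_map(jnp.shape, demo_state.params))"
4646+ ],
4647+ "metadata" : {},
4648+ "execution_count" : null ,
4649+ "outputs" : []
4650+ },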
4629+ {
4630+ "cell_type" : " code" ,
4631+ "source" : [
4632+ " @functools.partial(jax.pmap, axis_name='ensemble')\n " ,
4633+ " def apply_model(state, batch: jnp.asarray):\n " ,
4634+ " image, label = batch\n " ,
4635+ " def loss_fn(params):\n " ,
4636+ " logits = MLPModel().apply({'params': params}, image)\n " ,
4637+ " loss = image_difference_loss(logits, label);\n " ,
4638+ " return loss, logits\n " ,
4639+ " \n " ,
4640+ " grad_fn = jax.value_and_grad(loss_fn, has_aux=True)\n " ,
4641+ " (loss, logits), grads = grad_fn(state.params)\n " ,
4642+ " return grads, loss\n " ,
4643+ " \n " ,
4644+ " @jax.pmap\n " ,
4645+ " def update_model(state, grads):\n " ,
4646+ " return state.apply_gradients(grads=grads)"
4647+ ],
4648+ "metadata" : {
4649+ "id" : " QyGP4Fmf-q7q"
4650+ },
4651+ "execution_count" : null ,
4652+ "outputs" : []
4653+ },
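4654+ {
4655+ "cell_type" : " markdown" ,
4656+ "source" : [
4657+ " `image_difference_loss` is defined earlier in the notebook and does not appear in this diff. If you run these cells in isolation, the stand-in below keeps `apply_model` runnable; a plain mean-squared pixel difference is our assumption, not necessarily the original definition."
4658+ ],
4659+ "metadata" : {}
4660+ },
4661+ {
4662+ "cell_type" : " code" ,
4663+ "source" : [
4664+ " # Hypothetical stand-in for the notebook's image_difference_loss;\n " ,
4665+ " # assumes predictions and targets are same-shaped image tensors.\n " ,
4666+ " def image_difference_loss(preds, targets):\n " ,
4667+ " return jnp.mean((preds - targets) ** 2)"
4668+ ],
4669+ "metadata" : {},
4670+ "execution_count" : null ,
4671+ "outputs" : []
4672+ },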
4654+ {
4655+ "cell_type" : " code" ,
4656+ "source" : [
4657+ " def train_epoch(state, train_ds, batch_size, rng):\n " ,
4658+ " train_ds_size = len(train_ds['image'])\n " ,
4659+ " steps_per_epoch = train_ds_size // batch_size\n " ,
4660+ " \n " ,
4661+ " perms = jax.random.permutation(rng, len(train_ds['image']))\n " ,
4662+ " perms = perms[:steps_per_epoch * batch_size]\n " ,
4663+ " perms = perms.reshape((steps_per_epoch, batch_size))\n " ,
4664+ " \n " ,
4665+ " epoch_loss = []\n " ,
4666+ " \n " ,
4667+ " for perm in perms:\n " ,
4668+ " batch_images = jax_utils.replicate(train_ds['image'][perm, ...])\n " ,
4669+ " batch_labels = jax_utils.replicate(train_ds['label'][perm, ...])\n " ,
4670+ " grads, loss = apply_model(state, batch_images, batch_labels)\n " ,
4671+ " state = update_model(state, grads)\n " ,
4672+ " epoch_loss.append(jax_utils.unreplicate(loss))\n " ,
4673+ " train_loss = np.mean(epoch_loss)\n " ,
4674+ " return state, train_loss"
4675+ ],
4676+ "metadata" : {
4677+ "id" : " QQUj3Y3LA9A1"
4678+ },
4679+ "execution_count" : null ,
4680+ "outputs" : []
4681+ },
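4682+ {
4683+ "cell_type" : " markdown" ,
4684+ "source" : [
4685+ " How the per-device batching works: `jax_utils.replicate` adds a leading device axis, so every ensemble member sees the same batch, while `unreplicate` takes the first device's copy back off. A minimal illustration (our addition):"
4686+ ],
4687+ "metadata" : {}
4688+ },
4689+ {
4690+ "cell_type" : " code" ,
4691+ "source" : [
4692+ " # replicate: (64, 784) -> (jax.device_count(), 64, 784)\n " ,
4693+ " x = np.zeros((64, 784))\n " ,
4694+ " xr = jax_utils.replicate(x)\n " ,
4695+ " print(xr.shape)\n " ,
4696+ " # unreplicate: back to (64, 784)\n " ,
4697+ " print(jax_utils.unreplicate(xr).shape)"
4698+ ],
4699+ "metadata" : {},
4700+ "execution_count" : null ,
4701+ "outputs" : []
4702+ },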
4682+ {
4683+ "cell_type" : " code" ,
4684+ "source" : [
4685+ " train_ds, test_ds = get_datasets()\n " ,
4686+ " test_ds = jax_utils.replicate(test_ds)\n " ,
4687+ " rng = jax.random.PRNGKey(0)\n " ,
4688+ " \n " ,
4689+ " rng, init_rng = jax.random.split(rng)\n " ,
4690+ " \n " ,
4691+ " HxW, Channels = next(batches)[0].shape\n " ,
4692+ " state = create_train_state(jax.random.split(init_rng, jax.device_count()),(HxW, Channels),learning_rate)\n " ,
4693+ " \n " ,
4694+ " for epoch in range(1, num_epochs + 1):\n " ,
4695+ " rng, input_rng = jax.random.split(rng)\n " ,
4696+ " state, train_loss = train_epoch(state, train_ds, batch_size, input_rng)\n " ,
4697+ " \n " ,
4698+ " # _, test_loss = jax_utils.unreplicate(apply_model(state, test_ds['image'], test_ds['label']))\n " ,
4699+ " \n " ,
4700+ " logging.info('epoch:% 3d, train_loss: %.4f ' % (epoch, train_loss))"
4701+ ],
4702+ "metadata" : {
4703+ "id" : " X-CttLscBnDQ"
4704+ },
4705+ "execution_count" : null ,
4706+ "outputs" : []
4707+ },
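4708+ {
4709+ "cell_type" : " markdown" ,
4710+ "source" : [
4711+ " The accuracy loop below follows Flax's `pad_shard_unpad` pattern but references `ds`, `vs`, `get_preds` and `per_device_batch_size`, none of which appear in this diff. A minimal sketch of the missing glue (names hypothetical), assuming `model`, `HxW` and `Channels` from earlier cells and that `ds` is a tf.data dataset of image/label batches like `test_ds`:"
4712+ ],
4713+ "metadata" : {}
4714+ },
4715+ {
4716+ "cell_type" : " code" ,
4717+ "source" : [
4718+ " # Hypothetical glue for the evaluation loop below.\n " ,
4719+ " import flax\n " ,
4720+ " \n " ,
4721+ " per_device_batch_size = 64\n " ,
4722+ " variables = model.init(jax.random.PRNGKey(1), jnp.ones((HxW, Channels)))\n " ,
4723+ " vs = jax_utils.replicate(variables) # already sharded, passed through pad_shard_unpad\n " ,
4724+ " \n " ,
4725+ " @jax.pmap\n " ,
4726+ " def get_preds(variables, inputs):\n " ,
4727+ " return model.apply(variables, inputs)"
4728+ ],
4729+ "metadata" : {},
4730+ "execution_count" : null ,
4731+ "outputs" : []
4732+ },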
4708+ {
4709+ "cell_type" : " code" ,
4710+ "source" : [
4711+ " correct = total = 0\n " ,
4712+ " for batch in ds.as_numpy_iterator():\n " ,
4713+ " preds = flax.jax_utils.pad_shard_unpad(get_preds)(\n " ,
4714+ " vs, batch['image'], min_device_batch=per_device_batch_size)\n " ,
4715+ " total += len(batch['image'])\n " ,
4716+ " correct += (batch['label'] == preds.argmax(axis=-1)).sum()"
4717+ ],
4718+ "metadata" : {
4719+ "id" : " I_orMqbuD3LL"
4720+ },
4721+ "execution_count" : null ,
4722+ "outputs" : []
4723+ },
4724+ {
4725+ "cell_type" : " code" ,
4726+ "source" : [
4727+ " def eval_step(metrics, variables, batch):\n " ,
4728+ " print('retrigger compilation', {k: v.shape for k, v in batch.items()})\n " ,
4729+ " preds = model.apply(variables, batch['image'])\n " ,
4730+ " correct = (batch['mask'] & (batch['label'] == preds.argmax(axis=-1))).sum()\n " ,
4731+ " total = batch['mask'].sum()\n " ,
4732+ " return dict(\n " ,
4733+ " correct=metrics['correct'] + jax.lax.psum(correct, axis_name='batch'),\n " ,
4734+ " total=metrics['total'] + jax.lax.psum(total, axis_name='batch'),\n " ,
4735+ " )\n " ,
4736+ " \n " ,
4737+ " eval_step = jax.pmap(eval_step, axis_name='batch')\n " ,
4738+ " eval_step = flax.jax_utils.pad_shard_unpad(\n " ,
4739+ " eval_step, static_argnums=(0, 1), static_return=True)"
4740+ ],
4741+ "metadata" : {
4742+ "id" : " RxhJjRZLD5P-"
4743+ },
4744+ "execution_count" : null ,
4745+ "outputs" : []
45864746 },
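4746+ {
4747+ "cell_type" : " markdown" ,
4748+ "source" : [
4749+ " Usage sketch for the wrapped `eval_step` (our addition, following the same Flax pattern): the running metrics start out replicated, are updated once per batch, and are unreplicated at the end. Assumes `ds` and `vs` from the cells above, and that each batch carries the `mask` produced by the padding."
4750+ ],
4751+ "metadata" : {}
4752+ },
4753+ {
4754+ "cell_type" : " code" ,
4755+ "source" : [
4756+ " # Run the padded, sharded evaluation over the whole dataset.\n " ,
4757+ " metrics = flax.jax_utils.replicate(dict(\n " ,
4758+ " correct=jnp.array(0, jnp.int32),\n " ,
4759+ " total=jnp.array(0, jnp.int32)))\n " ,
4760+ " for batch in ds.as_numpy_iterator():\n " ,
4761+ " metrics = eval_step(metrics, vs, batch)\n " ,
4762+ " metrics = flax.jax_utils.unreplicate(metrics)\n " ,
4763+ " print('accuracy:', metrics['correct'] / metrics['total'])"
4764+ ],
4765+ "metadata" : {},
4766+ "execution_count" : null ,
4767+ "outputs" : []
4768+ }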
45874747 ]
45884748}