Commit 992bb5b

Created using Colaboratory
1 parent 0441144 commit 992bb5b

File tree

1 file changed: +163 -3 lines changed

MLP_Image_Train_Inference_JAX.ipynb

Lines changed: 163 additions & 3 deletions
@@ -4,7 +4,7 @@
   "metadata": {
     "colab": {
       "provenance": [],
-      "authorship_tag": "ABX9TyO/VV28tLyXQmzWWnHKkzuo",
+      "authorship_tag": "ABX9TyOKcfrMg1MLg77aTV44y/4b",
       "include_colab_link": true
     },
     "kernelspec": {
@@ -4494,14 +4494,14 @@
           "name": "stderr",
           "text": [
             "\n",
-            " 99%|█████████▊| 994/1009 [29:52<00:27, 1.80s/it]\u001b[A"
+            " 52%|█████▏    | 523/1009 [14:53<13:28, 1.66s/it]\u001b[A"
           ]
         },
         {
           "output_type": "stream",
           "name": "stdout",
           "text": [
-            "loss: 1486.7937 <<< \n"
+            "loss: 1168.78 <<< \n"
           ]
         }
       ]
@@ -4583,6 +4583,166 @@
       },
       "execution_count": null,
       "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "**ensemble**"
+      ],
+      "metadata": {
+        "id": "ka3_468K9xRZ"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "#✅\n",
+        "!python -m pip install -q -U flax\n",
+        "import optax\n",
+        "from flax.training import train_state\n",
+        "import jax.numpy as jnp\n",
+        "import jax\n",
+        "\n",
+        "@functools.partial(jax.pmap, static_broadcasted_argnums=(1, 2))\n",
+        "def create_train_state(r_key, shape, learning_rate):\n",
+        "  print(shape)  # Printed once per compilation; shape is static.\n",
+        "  model = MLPModel()\n",
+        "  variables = model.init(r_key, jnp.ones(shape))\n",
+        "  optimizer = optax.adam(learning_rate)\n",
+        "  return train_state.TrainState.create(\n",
+        "      apply_fn=model.apply,\n",
+        "      tx=optimizer,\n",
+        "      params=variables['params']\n",
+        "  )\n",
+        "\n",
+        "learning_rate = 1e-4\n",
+        "batch_size = 64\n",
+        "\n",
+        "model = MLPModel()  # Instantiate the model"
+      ],
+      "metadata": {
+        "id": "Aaat0R0q9Z7F"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "@functools.partial(jax.pmap, axis_name='ensemble')\n",
+        "def apply_model(state, images, labels):\n",
+        "  # Per-device loss and gradients for one ensemble member.\n",
+        "  def loss_fn(params):\n",
+        "    logits = MLPModel().apply({'params': params}, images)\n",
+        "    loss = image_difference_loss(logits, labels)\n",
+        "    return loss, logits\n",
+        "\n",
+        "  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)\n",
+        "  (loss, logits), grads = grad_fn(state.params)\n",
+        "  return grads, loss\n",
+        "\n",
+        "@jax.pmap\n",
+        "def update_model(state, grads):\n",
+        "  return state.apply_gradients(grads=grads)"
+      ],
+      "metadata": {
+        "id": "QyGP4Fmf-q7q"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "def train_epoch(state, train_ds, batch_size, rng):\n",
+        "  train_ds_size = len(train_ds['image'])\n",
+        "  steps_per_epoch = train_ds_size // batch_size\n",
+        "\n",
+        "  perms = jax.random.permutation(rng, train_ds_size)\n",
+        "  perms = perms[:steps_per_epoch * batch_size]  # Drop the incomplete final batch.\n",
+        "  perms = perms.reshape((steps_per_epoch, batch_size))\n",
+        "\n",
+        "  epoch_loss = []\n",
+        "\n",
+        "  for perm in perms:\n",
+        "    batch_images = jax_utils.replicate(train_ds['image'][perm, ...])\n",
+        "    batch_labels = jax_utils.replicate(train_ds['label'][perm, ...])\n",
+        "    grads, loss = apply_model(state, batch_images, batch_labels)\n",
+        "    state = update_model(state, grads)\n",
+        "    epoch_loss.append(jax_utils.unreplicate(loss))\n",
+        "  train_loss = np.mean(epoch_loss)\n",
+        "  return state, train_loss"
+      ],
+      "metadata": {
+        "id": "QQUj3Y3LA9A1"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "train_ds, test_ds = get_datasets()\n",
+        "test_ds = jax_utils.replicate(test_ds)\n",
+        "rng = jax.random.PRNGKey(0)\n",
+        "\n",
+        "rng, init_rng = jax.random.split(rng)\n",
+        "\n",
+        "HxW, Channels = next(batches)[0].shape\n",
+        "state = create_train_state(jax.random.split(init_rng, jax.device_count()), (HxW, Channels), learning_rate)\n",
+        "\n",
+        "for epoch in range(1, num_epochs + 1):\n",
+        "  rng, input_rng = jax.random.split(rng)\n",
+        "  state, train_loss = train_epoch(state, train_ds, batch_size, input_rng)\n",
+        "\n",
+        "  # _, test_loss = jax_utils.unreplicate(apply_model(state, test_ds['image'], test_ds['label']))\n",
+        "\n",
+        "  logging.info('epoch:% 3d, train_loss: %.4f' % (epoch, train_loss))"
+      ],
+      "metadata": {
+        "id": "X-CttLscBnDQ"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "correct = total = 0\n",
+        "for batch in ds.as_numpy_iterator():\n",
+        "  preds = flax.jax_utils.pad_shard_unpad(get_preds)(\n",
+        "      vs, batch['image'], min_device_batch=per_device_batch_size)\n",
+        "  total += len(batch['image'])\n",
+        "  correct += (batch['label'] == preds.argmax(axis=-1)).sum()"
+      ],
+      "metadata": {
+        "id": "I_orMqbuD3LL"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "def eval_step(metrics, variables, batch):\n",
+        "  print('retrigger compilation', {k: v.shape for k, v in batch.items()})\n",
+        "  preds = model.apply(variables, batch['image'])\n",
+        "  correct = (batch['mask'] & (batch['label'] == preds.argmax(axis=-1))).sum()\n",
+        "  total = batch['mask'].sum()\n",
+        "  return dict(\n",
+        "      correct=metrics['correct'] + jax.lax.psum(correct, axis_name='batch'),\n",
+        "      total=metrics['total'] + jax.lax.psum(total, axis_name='batch'),\n",
+        "  )\n",
+        "\n",
+        "eval_step = jax.pmap(eval_step, axis_name='batch')\n",
+        "eval_step = flax.jax_utils.pad_shard_unpad(\n",
+        "    eval_step, static_argnums=(0, 1), static_return=True)"
+      ],
+      "metadata": {
+        "id": "RxhJjRZLD5P-"
+      },
+      "execution_count": null,
+      "outputs": []
     }
   ]
 }
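Note: the new cells call functools.partial, jax_utils.replicate / jax_utils.unreplicate, np.mean and logging.info without importing them in this diff, and they rely on MLPModel, image_difference_loss, get_datasets, batches and num_epochs defined in earlier cells of the notebook. A minimal sketch of the assumed preamble, with placeholder values that are not taken from the notebook:

import functools
import logging

import numpy as np
import flax
from flax import jax_utils  # replicate / unreplicate / pad_shard_unpad

# Placeholder only; the notebook sets its own value in an earlier cell.
num_epochs = 10
# MLPModel, image_difference_loss, get_datasets and `batches` are likewise
# assumed to be defined earlier in MLP_Image_Train_Inference_JAX.ipynb.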

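The final added cell only constructs the pmapped eval_step and its pad_shard_unpad wrapper; the diff does not show the loop that drives it. A hedged usage sketch in the style of the Flax full-dataset-evaluation pattern follows; ds, vs and the per-batch 'mask' field are assumptions about names provided elsewhere in the notebook:

# Running counters are replicated across devices; eval_step psums per batch.
metrics = flax.jax_utils.replicate(dict(
    correct=jnp.array(0, dtype=jnp.int32),
    total=jnp.array(0, dtype=jnp.int32),
))
# vs: model variables replicated across devices, e.g. flax.jax_utils.replicate(variables).
for batch in ds.as_numpy_iterator():        # batch: {'image', 'label', 'mask'}
  # pad_shard_unpad pads and shards only `batch`; metrics and vs pass through unchanged.
  metrics = eval_step(metrics, vs, batch)
metrics = flax.jax_utils.unreplicate(metrics)
accuracy = metrics['correct'] / metrics['total']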