|
169 | 169 | " \"recordSet\": [\n", |
170 | 170 | " {\n", |
171 | 171 | " \"@type\": \"ml:RecordSet\",\n", |
172 | | - " \"name\": \"record_set_fashion_mnist\",\n", |
| 172 | + " \"name\": \"fashion_mnist\",\n", |
173 | 173 | " \"description\": \"fashion_mnist - 'fashion_mnist' subset\\n\\nAdditional information:\\n- 2 splits: train, test\",\n", |
174 | 174 | " \"field\": [\n", |
175 | 175 | " {\n", |
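The renamed `fashion_mnist` record set ID is what downstream consumers use to pull records from the Croissant metadata. A minimal sketch with the `mlcroissant` library, assuming a local copy of the JSON-LD file (the path is hypothetical, not from this notebook):

```python
import mlcroissant as mlc

# Hypothetical path to the Croissant JSON-LD file described in the diff above.
dataset = mlc.Dataset(jsonld="fashion_mnist.json")

# Records are requested by the record set ID, which this change renames to "fashion_mnist".
records = dataset.records(record_set="fashion_mnist")
for i, record in enumerate(records):
    print(record)
    if i >= 2:
        break
```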
|
244 | 244 | " - [dataset(fashion_mnist)] Property \"https://schema.org/citation\" is recommended, but does not exist.\n", |
245 | 245 | " - [dataset(fashion_mnist)] Property \"https://schema.org/license\" is recommended, but does not exist.\n", |
246 | 246 | " - [dataset(fashion_mnist)] Property \"https://schema.org/version\" is recommended, but does not exist.\n", |
247 | | - "WARNING:absl:Using custom data configuration record_set_fashion_mnist\n" |
| 247 | + "WARNING:absl:Using custom data configuration fashion_mnist\n" |
248 | 248 | ] |
249 | 249 | } |
250 | 250 | ], |
|
253 | 253 | "\n", |
254 | 254 | "builder = tfds.core.dataset_builders.CroissantBuilder(\n", |
255 | 255 | " jsonld=local_croissant_file,\n", |
256 | | - " record_set_ids=[\"record_set_fashion_mnist\"],\n", |
| 256 | + " record_set_ids=[\"fashion_mnist\"],\n", |
257 | 257 | " file_format='array_record',\n", |
258 | 258 | " data_dir=data_dir,\n", |
259 | 259 | ")" |
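For context, the builder constructed above behaves like any other TFDS builder, and `record_set_ids` determines both the config name (visible in the warning above) and the `fashion_mnist/...` feature keys used later in the notebook. A hedged sketch of the typical follow-up calls, which are not shown in this hunk:

```python
# Build the ArrayRecord files, then open the splits as random-access data sources.
builder.download_and_prepare()
data_source = builder.as_data_source()
train, test = data_source["train"], data_source["test"]

# Each example is a dict keyed by "<record_set_id>/<field_id>".
print(train[0].keys())  # e.g. dict_keys(['fashion_mnist/image', 'fashion_mnist/label'])
```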
|
383 | 383 | "output_type": "stream", |
384 | 384 | "text": [ |
385 | 385 | "\u001b[01;34m/tmp/croissant/fashion_mnist\u001b[0m\n", |
386 | | - "└── \u001b[01;34mrecord_set_fashion_mnist\u001b[0m\n", |
| 386 | + "└── \u001b[01;34mfashion_mnist\u001b[0m\n", |

387 | 387 | " └── \u001b[01;34m1.0.0\u001b[0m\n", |
388 | 388 | " ├── dataset_info.json\n", |
389 | 389 | " ├── fashion_mnist-default.array_record-00000-of-00001\n", |
|
522 | 522 | " image = image.view(image.size()[0], -1).to(torch.float32)\n", |
523 | 523 | " return self.classifier(image)\n", |
524 | 524 | "\n", |
525 | | - "shape = train[0][\"record_set_fashion_mnist/image\"].shape\n", |
| 525 | + "shape = train[0][\"fashion_mnist/image\"].shape\n", |
526 | 526 | "num_classes = 10\n", |
527 | 527 | "model = LinearClassifier(shape, num_classes)\n", |
528 | 528 | "optimizer = torch.optim.Adam(model.parameters())\n", |
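`LinearClassifier` itself is defined outside this hunk; below is a minimal sketch consistent with the `forward` method shown above. The constructor details are an assumption, not the notebook's exact code:

```python
import math
import torch
from torch import nn

class LinearClassifier(nn.Module):
    """Single linear layer over the flattened image (illustrative sketch)."""

    def __init__(self, shape, num_classes):
        super().__init__()
        # `shape` is the per-example image shape, e.g. (28, 28) for Fashion-MNIST.
        self.classifier = nn.Linear(math.prod(shape), num_classes)

    def forward(self, image):
        # Flatten each image in the batch and cast to float for the linear layer.
        image = image.view(image.size()[0], -1).to(torch.float32)
        return self.classifier(image)
```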
|
531 | 531 | "print('Training...')\n", |
532 | 532 | "model.train()\n", |
533 | 533 | "for example in tqdm(train_loader):\n", |
534 | | - " image = example['record_set_fashion_mnist/image']\n", |
535 | | - " label = example['record_set_fashion_mnist/label']\n", |
| 534 | + " image = example['fashion_mnist/image']\n", |
| 535 | + " label = example['fashion_mnist/label']\n", |
536 | 536 | " prediction = model(image)\n", |
537 | 537 | " loss = loss_function(prediction, label)\n", |
538 | 538 | " optimizer.zero_grad()\n", |
|
544 | 544 | "num_examples = 0\n", |
545 | 545 | "true_positives = 0\n", |
546 | 546 | "for example in tqdm(test_loader):\n", |
547 | | - " image = example['record_set_fashion_mnist/image']\n", |
548 | | - " label = example['record_set_fashion_mnist/label']\n", |
| 547 | + " image = example['fashion_mnist/image']\n", |
| 548 | + " label = example['fashion_mnist/label']\n", |
549 | 549 | " prediction = model(image)\n", |
550 | 550 | " num_examples += image.shape[0]\n", |
551 | 551 | " predicted_label = prediction.argmax(dim=1)\n", |
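The evaluation hunk ends at `argmax`; the remaining bookkeeping compares predicted labels with ground truth and divides by the example count. A toy, self-contained illustration of that step, unrelated to the notebook's actual data:

```python
import torch

# Fake logits for a batch of 3 images over 10 classes (illustrative only).
prediction = torch.tensor([[0.1, 2.0, 0.3] + [0.0] * 7,
                           [1.5, 0.2, 0.1] + [0.0] * 7,
                           [0.0, 0.1, 3.0] + [0.0] * 7])
label = torch.tensor([1, 0, 1])

predicted_label = prediction.argmax(dim=1)                 # tensor([1, 0, 2])
true_positives = (predicted_label == label).sum().item()   # 2 correct predictions
accuracy = true_positives / label.shape[0]                 # ~0.67
print(accuracy)
```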
|