Skip to content

Commit d65b6ce

Browse files
authored
Fix end-to-end tests (#672)
1 parent db53257 commit d65b6ce

File tree

2 files changed

+12
-12
lines changed

2 files changed

+12
-12
lines changed

python/mlcroissant/mlcroissant/_src/datasets_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,9 +193,9 @@ def test_load_from_huggingface():
193193
url = "https://huggingface.co/api/datasets/mnist/croissant"
194194
dataset = datasets.Dataset(url)
195195
has_one_record = False
196-
for record in dataset.records(record_set="record_set_mnist"):
197-
assert record["record_set_mnist/label"] == 7
198-
assert isinstance(record["record_set_mnist/image"], deps.PIL_Image.Image)
196+
for record in dataset.records(record_set="mnist"):
197+
assert record["mnist/label"] == 7
198+
assert isinstance(record["mnist/image"], deps.PIL_Image.Image)
199199
has_one_record = True
200200
break
201201
assert has_one_record, (

python/mlcroissant/recipes/tfds_croissant_builder.ipynb

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@
169169
" \"recordSet\": [\n",
170170
" {\n",
171171
" \"@type\": \"ml:RecordSet\",\n",
172-
" \"name\": \"record_set_fashion_mnist\",\n",
172+
" \"name\": \"fashion_mnist\",\n",
173173
" \"description\": \"fashion_mnist - 'fashion_mnist' subset\\n\\nAdditional information:\\n- 2 splits: train, test\",\n",
174174
" \"field\": [\n",
175175
" {\n",
@@ -244,7 +244,7 @@
244244
" - [dataset(fashion_mnist)] Property \"https://schema.org/citation\" is recommended, but does not exist.\n",
245245
" - [dataset(fashion_mnist)] Property \"https://schema.org/license\" is recommended, but does not exist.\n",
246246
" - [dataset(fashion_mnist)] Property \"https://schema.org/version\" is recommended, but does not exist.\n",
247-
"WARNING:absl:Using custom data configuration record_set_fashion_mnist\n"
247+
"WARNING:absl:Using custom data configuration fashion_mnist\n"
248248
]
249249
}
250250
],
@@ -253,7 +253,7 @@
253253
"\n",
254254
"builder = tfds.core.dataset_builders.CroissantBuilder(\n",
255255
" jsonld=local_croissant_file,\n",
256-
" record_set_ids=[\"record_set_fashion_mnist\"],\n",
256+
" record_set_ids=[\"fashion_mnist\"],\n",
257257
" file_format='array_record',\n",
258258
" data_dir=data_dir,\n",
259259
")"
@@ -383,7 +383,7 @@
383383
"output_type": "stream",
384384
"text": [
385385
"\u001b[01;34m/tmp/croissant/fashion_mnist\u001b[0m\n",
386-
"└── \u001b[01;34mrecord_set_fashion_mnist\u001b[0m\n",
386+
"└── \u001b[01;fashion_mnist\u001b[0m\n",
387387
" └── \u001b[01;34m1.0.0\u001b[0m\n",
388388
" ├── dataset_info.json\n",
389389
" ├── fashion_mnist-default.array_record-00000-of-00001\n",
@@ -522,7 +522,7 @@
522522
" image = image.view(image.size()[0], -1).to(torch.float32)\n",
523523
" return self.classifier(image)\n",
524524
"\n",
525-
"shape = train[0][\"record_set_fashion_mnist/image\"].shape\n",
525+
"shape = train[0][\"fashion_mnist/image\"].shape\n",
526526
"num_classes = 10\n",
527527
"model = LinearClassifier(shape, num_classes)\n",
528528
"optimizer = torch.optim.Adam(model.parameters())\n",
@@ -531,8 +531,8 @@
531531
"print('Training...')\n",
532532
"model.train()\n",
533533
"for example in tqdm(train_loader):\n",
534-
" image = example['record_set_fashion_mnist/image']\n",
535-
" label = example['record_set_fashion_mnist/label']\n",
534+
" image = example['fashion_mnist/image']\n",
535+
" label = example['fashion_mnist/label']\n",
536536
" prediction = model(image)\n",
537537
" loss = loss_function(prediction, label)\n",
538538
" optimizer.zero_grad()\n",
@@ -544,8 +544,8 @@
544544
"num_examples = 0\n",
545545
"true_positives = 0\n",
546546
"for example in tqdm(test_loader):\n",
547-
" image = example['record_set_fashion_mnist/image']\n",
548-
" label = example['record_set_fashion_mnist/label']\n",
547+
" image = example['fashion_mnist/image']\n",
548+
" label = example['fashion_mnist/label']\n",
549549
" prediction = model(image)\n",
550550
" num_examples += image.shape[0]\n",
551551
" predicted_label = prediction.argmax(dim=1)\n",

0 commit comments

Comments
 (0)