|
166 | 166 | } |
167 | 167 | ], |
168 | 168 | "conversionMethod": "pd.DataFrame", |
169 | | - "ref": "6429f9e5-3223-4fcb-9626-09775f5a173a", |
| 169 | + "ref": "b1c02221-fa87-4177-a697-a36be67e0a6a", |
170 | 170 | "rows": [ |
171 | 171 | [ |
172 | 172 | "0", |
|
502 | 502 | } |
503 | 503 | ], |
504 | 504 | "conversionMethod": "pd.DataFrame", |
505 | | - "ref": "29f23741-6c1b-4531-877f-93dd3454ef74", |
| 505 | + "ref": "cebf9ebc-1ab1-4f91-9ac9-329a1c4c71b7", |
506 | 506 | "rows": [ |
507 | 507 | [ |
508 | 508 | "count", |
|
1015 | 1015 | "name": "stderr", |
1016 | 1016 | "output_type": "stream", |
1017 | 1017 | "text": [ |
1018 | | - "Training SOM: 100%|██████████| 100/100 [00:02<00:00, 38.53epoch/s]\n" |
| 1018 | + "Training SOM: 5%|▌ | 5/100 [00:00<00:05, 17.79epoch/s]" |
| 1019 | + ] |
| 1020 | + }, |
| 1021 | + { |
| 1022 | + "name": "stderr", |
| 1023 | + "output_type": "stream", |
| 1024 | + "text": [ |
| 1025 | + "Training SOM: 100%|██████████| 100/100 [00:06<00:00, 16.15epoch/s]\n" |
1019 | 1026 | ] |
1020 | 1027 | } |
1021 | 1028 | ], |
|
1110 | 1117 | "outputs": [], |
1111 | 1118 | "source": [ |
1112 | 1119 | "predictions = []\n", |
| 1120 | + "bmus_idx_map = som.build_bmus_data_map(\n", |
| 1121 | + " data=train_features,\n", |
| 1122 | + " return_indices=True, # False means we want the features of each sample and not the indices\n", |
| 1123 | + ")\n", |
1113 | 1124 | "for idx, (test_feature, test_target) in enumerate(zip(test_features, test_targets)):\n", |
1114 | 1125 | " \n", |
1115 | 1126 | " collected_features, collected_targets = som.collect_samples(\n", |
1116 | 1127 | " query_sample=test_feature,\n", |
1117 | 1128 | " historical_samples=train_features,\n", |
1118 | 1129 | " historical_outputs=train_targets,\n", |
1119 | | - " min_buffer_threshold=30 # Collect 30 historical samples to train a model\n", |
| 1130 | + " min_buffer_threshold=30, # Collect 30 historical samples to train a model\n", |
| 1131 | + " bmus_idx_map=bmus_idx_map,\n", |
1120 | 1132 | " )\n", |
1121 | 1133 | " \n", |
1122 | 1134 | " X = collected_features.numpy()\n", |
|
1226 | 1238 | ], |
1227 | 1239 | "metadata": { |
1228 | 1240 | "kernelspec": { |
1229 | | - "display_name": ".venv", |
| 1241 | + "display_name": ".venv_sensing", |
1230 | 1242 | "language": "python", |
1231 | 1243 | "name": "python3" |
1232 | 1244 | }, |
|
0 commit comments