|
19 | 19 | "import matplotlib.pyplot as plt\n", |
20 | 20 | "import torch\n", |
21 | 21 | "\n", |
22 | | - "from torchsom.core import TorchSOM\n", |
23 | | - "from torchsom.plotting import SOMVisualizer, VisualizationConfig\n", |
| 22 | + "from torchsom.core import SOM\n", |
| 23 | + "from torchsom.visualization import SOMVisualizer, VisualizationConfig\n", |
24 | 24 | "\n", |
25 | 25 | "from sklearn.preprocessing import StandardScaler\n", |
26 | 26 | "from sklearn.neural_network import MLPClassifier\n", |
|
99 | 99 | "outputs": [ |
100 | 100 | { |
101 | 101 | "data": { |
102 | | - "application/vnd.microsoft.datawrangler.viewer.v0+json": { |
103 | | - "columns": [ |
104 | | - { |
105 | | - "name": "index", |
106 | | - "rawType": "int64", |
107 | | - "type": "integer" |
108 | | - }, |
109 | | - { |
110 | | - "name": "Sepal Length", |
111 | | - "rawType": "float64", |
112 | | - "type": "float" |
113 | | - }, |
114 | | - { |
115 | | - "name": "Sepal Width", |
116 | | - "rawType": "float64", |
117 | | - "type": "float" |
118 | | - }, |
119 | | - { |
120 | | - "name": "Petal Length", |
121 | | - "rawType": "float64", |
122 | | - "type": "float" |
123 | | - }, |
124 | | - { |
125 | | - "name": "Petal Width", |
126 | | - "rawType": "float64", |
127 | | - "type": "float" |
128 | | - }, |
129 | | - { |
130 | | - "name": "Species", |
131 | | - "rawType": "int64", |
132 | | - "type": "integer" |
133 | | - } |
134 | | - ], |
135 | | - "conversionMethod": "pd.DataFrame", |
136 | | - "ref": "6214df72-d361-45b1-b0b4-73fbb4dbb243", |
137 | | - "rows": [ |
138 | | - [ |
139 | | - "0", |
140 | | - "-0.9006811702978088", |
141 | | - "1.0320572244889565", |
142 | | - "-1.3412724047598314", |
143 | | - "-1.3129767272601454", |
144 | | - "1" |
145 | | - ], |
146 | | - [ |
147 | | - "1", |
148 | | - "-1.1430169111851105", |
149 | | - "-0.12495760117130933", |
150 | | - "-1.3412724047598314", |
151 | | - "-1.3129767272601454", |
152 | | - "1" |
153 | | - ], |
154 | | - [ |
155 | | - "2", |
156 | | - "-1.3853526520724133", |
157 | | - "0.3378483290927974", |
158 | | - "-1.3981381087490836", |
159 | | - "-1.3129767272601454", |
160 | | - "1" |
161 | | - ], |
162 | | - [ |
163 | | - "3", |
164 | | - "-1.5065205225160652", |
165 | | - "0.10644536396074403", |
166 | | - "-1.284406700770579", |
167 | | - "-1.3129767272601454", |
168 | | - "1" |
169 | | - ], |
170 | | - [ |
171 | | - "4", |
172 | | - "-1.0218490407414595", |
173 | | - "1.2634601896210098", |
174 | | - "-1.3412724047598314", |
175 | | - "-1.3129767272601454", |
176 | | - "1" |
177 | | - ] |
178 | | - ], |
179 | | - "shape": { |
180 | | - "columns": 5, |
181 | | - "rows": 5 |
182 | | - } |
183 | | - }, |
184 | 102 | "text/html": [ |
185 | 103 | "<div>\n", |
186 | 104 | "<style scoped>\n", |
|
277 | 195 | "outputs": [ |
278 | 196 | { |
279 | 197 | "data": { |
280 | | - "application/vnd.microsoft.datawrangler.viewer.v0+json": { |
281 | | - "columns": [ |
282 | | - { |
283 | | - "name": "index", |
284 | | - "rawType": "object", |
285 | | - "type": "string" |
286 | | - }, |
287 | | - { |
288 | | - "name": "Sepal Length", |
289 | | - "rawType": "float64", |
290 | | - "type": "float" |
291 | | - }, |
292 | | - { |
293 | | - "name": "Sepal Width", |
294 | | - "rawType": "float64", |
295 | | - "type": "float" |
296 | | - }, |
297 | | - { |
298 | | - "name": "Petal Length", |
299 | | - "rawType": "float64", |
300 | | - "type": "float" |
301 | | - }, |
302 | | - { |
303 | | - "name": "Petal Width", |
304 | | - "rawType": "float64", |
305 | | - "type": "float" |
306 | | - }, |
307 | | - { |
308 | | - "name": "Species", |
309 | | - "rawType": "float64", |
310 | | - "type": "float" |
311 | | - } |
312 | | - ], |
313 | | - "conversionMethod": "pd.DataFrame", |
314 | | - "ref": "ef5af901-e677-4bb2-823f-dab750e668ad", |
315 | | - "rows": [ |
316 | | - [ |
317 | | - "count", |
318 | | - "150.0", |
319 | | - "150.0", |
320 | | - "150.0", |
321 | | - "150.0", |
322 | | - "150.0" |
323 | | - ], |
324 | | - [ |
325 | | - "mean", |
326 | | - "-4.736951571734001e-16", |
327 | | - "-6.631732200427602e-16", |
328 | | - "3.315866100213801e-16", |
329 | | - "-2.842170943040401e-16", |
330 | | - "2.0" |
331 | | - ], |
332 | | - [ |
333 | | - "std", |
334 | | - "1.0033500931359767", |
335 | | - "1.0033500931359767", |
336 | | - "1.0033500931359765", |
337 | | - "1.0033500931359767", |
338 | | - "0.8192319205190405" |
339 | | - ], |
340 | | - [ |
341 | | - "min", |
342 | | - "-1.870024133847019", |
343 | | - "-2.438987252491841", |
344 | | - "-1.5687352207168408", |
345 | | - "-1.4444496972795189", |
346 | | - "1.0" |
347 | | - ], |
348 | | - [ |
349 | | - "25%", |
350 | | - "-0.9006811702978088", |
351 | | - "-0.587763531435416", |
352 | | - "-1.2275409967813267", |
353 | | - "-1.1815037572407716", |
354 | | - "1.0" |
355 | | - ], |
356 | | - [ |
357 | | - "50%", |
358 | | - "-0.05250607719224957", |
359 | | - "-0.12495760117130933", |
360 | | - "0.33626586292311245", |
361 | | - "0.13322594295296525", |
362 | | - "2.0" |
363 | | - ], |
364 | | - [ |
365 | | - "75%", |
366 | | - "0.6745011454696588", |
367 | | - "0.5692512942248498", |
368 | | - "0.7627586428425047", |
369 | | - "0.7905907930498337", |
370 | | - "3.0" |
371 | | - ], |
372 | | - [ |
373 | | - "max", |
374 | | - "2.4920192021244283", |
375 | | - "3.1146839106774356", |
376 | | - "1.7863413146490472", |
377 | | - "1.7109015831854495", |
378 | | - "3.0" |
379 | | - ] |
380 | | - ], |
381 | | - "shape": { |
382 | | - "columns": 5, |
383 | | - "rows": 8 |
384 | | - } |
385 | | - }, |
386 | 198 | "text/html": [ |
387 | 199 | "<div>\n", |
388 | 200 | "<style scoped>\n", |
|
589 | 401 | "metadata": {}, |
590 | 402 | "outputs": [], |
591 | 403 | "source": [ |
592 | | - "som = TorchSOM(\n", |
| 404 | + "som = SOM(\n", |
593 | 405 | " x=25,\n", |
594 | 406 | " y=15,\n", |
595 | 407 | " sigma=1.45,\n", |
596 | 408 | " learning_rate=0.95,\n", |
597 | 409 | " neighborhood_order=3,\n", |
598 | | - " epochs=100,\n", |
| 410 | + " epochs=50,\n", |
599 | 411 | " batch_size=16,\n", |
600 | 412 | " topology=\"rectangular\",\n", |
601 | 413 | " distance_function=\"euclidean\",\n", |
|
617 | 429 | "source": [ |
618 | 430 | "som.initialize_weights(\n", |
619 | 431 | " data=train_features,\n", |
| 432 | + " mode=som.initialization_mode\n", |
620 | 433 | ")" |
621 | 434 | ] |
622 | 435 | }, |
|
629 | 442 | "name": "stderr", |
630 | 443 | "output_type": "stream", |
631 | 444 | "text": [ |
632 | | - "Training SOM: 100%|██████████| 100/100 [00:02<00:00, 39.84epoch/s]\n" |
| 445 | + "Training SOM: 100%|██████████| 50/50 [00:04<00:00, 11.38epoch/s]\n" |
633 | 446 | ] |
634 | 447 | } |
635 | 448 | ], |
|
734 | 547 | " query_sample=test_feature,\n", |
735 | 548 | " historical_samples=train_features,\n", |
736 | 549 | " historical_outputs=train_targets,\n", |
737 | | - " min_buffer_threshold=30, # Collect 20 historical samples to train a model\n", |
| 550 | + " min_buffer_threshold=35, # Collect 20 historical samples to train a model\n", |
738 | 551 | " bmus_idx_map=bmus_idx_map,\n", |
739 | 552 | " )\n", |
740 | 553 | " \n", |
|
841 | 654 | "plt.tight_layout()\n", |
842 | 655 | "plt.show()" |
843 | 656 | ] |
| 657 | + }, |
| 658 | + { |
| 659 | + "cell_type": "code", |
| 660 | + "execution_count": null, |
| 661 | + "metadata": {}, |
| 662 | + "outputs": [], |
| 663 | + "source": [] |
844 | 664 | } |
845 | 665 | ], |
846 | 666 | "metadata": { |
847 | 667 | "kernelspec": { |
848 | | - "display_name": "Python 3", |
| 668 | + "display_name": ".torchsom_env", |
849 | 669 | "language": "python", |
850 | 670 | "name": "python3" |
851 | 671 | }, |
|
0 commit comments