Skip to content

Commit eec37c0

Browse files
committed
Checked everything works properly after refactoring TorchSOM package
1 parent 79da14a commit eec37c0

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

68 files changed

+66
-619
lines changed

notebooks/boston_housing.ipynb

Lines changed: 19 additions & 12 deletions
Large diffs are not rendered by default.

notebooks/energy_efficiency.ipynb

Lines changed: 19 additions & 12 deletions
Large diffs are not rendered by default.

notebooks/iris.ipynb

Lines changed: 15 additions & 195 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
"import matplotlib.pyplot as plt\n",
2020
"import torch\n",
2121
"\n",
22-
"from torchsom.core import TorchSOM\n",
23-
"from torchsom.plotting import SOMVisualizer, VisualizationConfig\n",
22+
"from torchsom.core import SOM\n",
23+
"from torchsom.visualization import SOMVisualizer, VisualizationConfig\n",
2424
"\n",
2525
"from sklearn.preprocessing import StandardScaler\n",
2626
"from sklearn.neural_network import MLPClassifier\n",
@@ -99,88 +99,6 @@
9999
"outputs": [
100100
{
101101
"data": {
102-
"application/vnd.microsoft.datawrangler.viewer.v0+json": {
103-
"columns": [
104-
{
105-
"name": "index",
106-
"rawType": "int64",
107-
"type": "integer"
108-
},
109-
{
110-
"name": "Sepal Length",
111-
"rawType": "float64",
112-
"type": "float"
113-
},
114-
{
115-
"name": "Sepal Width",
116-
"rawType": "float64",
117-
"type": "float"
118-
},
119-
{
120-
"name": "Petal Length",
121-
"rawType": "float64",
122-
"type": "float"
123-
},
124-
{
125-
"name": "Petal Width",
126-
"rawType": "float64",
127-
"type": "float"
128-
},
129-
{
130-
"name": "Species",
131-
"rawType": "int64",
132-
"type": "integer"
133-
}
134-
],
135-
"conversionMethod": "pd.DataFrame",
136-
"ref": "6214df72-d361-45b1-b0b4-73fbb4dbb243",
137-
"rows": [
138-
[
139-
"0",
140-
"-0.9006811702978088",
141-
"1.0320572244889565",
142-
"-1.3412724047598314",
143-
"-1.3129767272601454",
144-
"1"
145-
],
146-
[
147-
"1",
148-
"-1.1430169111851105",
149-
"-0.12495760117130933",
150-
"-1.3412724047598314",
151-
"-1.3129767272601454",
152-
"1"
153-
],
154-
[
155-
"2",
156-
"-1.3853526520724133",
157-
"0.3378483290927974",
158-
"-1.3981381087490836",
159-
"-1.3129767272601454",
160-
"1"
161-
],
162-
[
163-
"3",
164-
"-1.5065205225160652",
165-
"0.10644536396074403",
166-
"-1.284406700770579",
167-
"-1.3129767272601454",
168-
"1"
169-
],
170-
[
171-
"4",
172-
"-1.0218490407414595",
173-
"1.2634601896210098",
174-
"-1.3412724047598314",
175-
"-1.3129767272601454",
176-
"1"
177-
]
178-
],
179-
"shape": {
180-
"columns": 5,
181-
"rows": 5
182-
}
183-
},
184102
"text/html": [
185103
"<div>\n",
186104
"<style scoped>\n",
@@ -277,112 +195,6 @@
277195
"outputs": [
278196
{
279197
"data": {
280-
"application/vnd.microsoft.datawrangler.viewer.v0+json": {
281-
"columns": [
282-
{
283-
"name": "index",
284-
"rawType": "object",
285-
"type": "string"
286-
},
287-
{
288-
"name": "Sepal Length",
289-
"rawType": "float64",
290-
"type": "float"
291-
},
292-
{
293-
"name": "Sepal Width",
294-
"rawType": "float64",
295-
"type": "float"
296-
},
297-
{
298-
"name": "Petal Length",
299-
"rawType": "float64",
300-
"type": "float"
301-
},
302-
{
303-
"name": "Petal Width",
304-
"rawType": "float64",
305-
"type": "float"
306-
},
307-
{
308-
"name": "Species",
309-
"rawType": "float64",
310-
"type": "float"
311-
}
312-
],
313-
"conversionMethod": "pd.DataFrame",
314-
"ref": "ef5af901-e677-4bb2-823f-dab750e668ad",
315-
"rows": [
316-
[
317-
"count",
318-
"150.0",
319-
"150.0",
320-
"150.0",
321-
"150.0",
322-
"150.0"
323-
],
324-
[
325-
"mean",
326-
"-4.736951571734001e-16",
327-
"-6.631732200427602e-16",
328-
"3.315866100213801e-16",
329-
"-2.842170943040401e-16",
330-
"2.0"
331-
],
332-
[
333-
"std",
334-
"1.0033500931359767",
335-
"1.0033500931359767",
336-
"1.0033500931359765",
337-
"1.0033500931359767",
338-
"0.8192319205190405"
339-
],
340-
[
341-
"min",
342-
"-1.870024133847019",
343-
"-2.438987252491841",
344-
"-1.5687352207168408",
345-
"-1.4444496972795189",
346-
"1.0"
347-
],
348-
[
349-
"25%",
350-
"-0.9006811702978088",
351-
"-0.587763531435416",
352-
"-1.2275409967813267",
353-
"-1.1815037572407716",
354-
"1.0"
355-
],
356-
[
357-
"50%",
358-
"-0.05250607719224957",
359-
"-0.12495760117130933",
360-
"0.33626586292311245",
361-
"0.13322594295296525",
362-
"2.0"
363-
],
364-
[
365-
"75%",
366-
"0.6745011454696588",
367-
"0.5692512942248498",
368-
"0.7627586428425047",
369-
"0.7905907930498337",
370-
"3.0"
371-
],
372-
[
373-
"max",
374-
"2.4920192021244283",
375-
"3.1146839106774356",
376-
"1.7863413146490472",
377-
"1.7109015831854495",
378-
"3.0"
379-
]
380-
],
381-
"shape": {
382-
"columns": 5,
383-
"rows": 8
384-
}
385-
},
386198
"text/html": [
387199
"<div>\n",
388200
"<style scoped>\n",
@@ -589,13 +401,13 @@
589401
"metadata": {},
590402
"outputs": [],
591403
"source": [
592-
"som = TorchSOM(\n",
404+
"som = SOM(\n",
593405
" x=25,\n",
594406
" y=15,\n",
595407
" sigma=1.45,\n",
596408
" learning_rate=0.95,\n",
597409
" neighborhood_order=3,\n",
598-
" epochs=100,\n",
410+
" epochs=50,\n",
599411
" batch_size=16,\n",
600412
" topology=\"rectangular\",\n",
601413
" distance_function=\"euclidean\",\n",
@@ -617,6 +429,7 @@
617429
"source": [
618430
"som.initialize_weights(\n",
619431
" data=train_features,\n",
432+
" mode=som.initialization_mode\n",
620433
")"
621434
]
622435
},
@@ -629,7 +442,7 @@
629442
"name": "stderr",
630443
"output_type": "stream",
631444
"text": [
632-
"Training SOM: 100%|██████████| 100/100 [00:02<00:00, 39.84epoch/s]\n"
445+
"Training SOM: 100%|██████████| 50/50 [00:04<00:00, 11.38epoch/s]\n"
633446
]
634447
}
635448
],
@@ -734,7 +547,7 @@
734547
" query_sample=test_feature,\n",
735548
" historical_samples=train_features,\n",
736549
" historical_outputs=train_targets,\n",
737-
" min_buffer_threshold=30, # Collect 20 historical samples to train a model\n",
550+
" min_buffer_threshold=35, # Collect 20 historical samples to train a model\n",
738551
" bmus_idx_map=bmus_idx_map,\n",
739552
" )\n",
740553
" \n",
@@ -841,11 +654,18 @@
841654
"plt.tight_layout()\n",
842655
"plt.show()"
843656
]
657+
},
658+
{
659+
"cell_type": "code",
660+
"execution_count": null,
661+
"metadata": {},
662+
"outputs": [],
663+
"source": []
844664
}
845665
],
846666
"metadata": {
847667
"kernelspec": {
848-
"display_name": "Python 3",
668+
"display_name": ".torchsom_env",
849669
"language": "python",
850670
"name": "python3"
851671
},
13.6 KB
-596 Bytes
-1.61 KB
2.95 KB
1.19 KB
1.89 KB
-1.23 KB

0 commit comments

Comments
 (0)