Skip to content

Commit e0ea1cd

Browse files
committed
chore: input normalization now load itself, making the NN load process cleaner and simpler
1 parent 6b1b2e2 commit e0ea1cd

File tree

5 files changed

+76
-48
lines changed

5 files changed

+76
-48
lines changed

megamek/src/megamek/ai/neuralnetwork/BrainRegistry.java

+1-2
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,8 @@
3030
/**
3131
* BrainRegistry is a record that holds the brain name, input axis length, and output axis length.
3232
* @param name The name of the brain
33-
* @param inputAxisLength The length of the input axis
3433
* @param outputAxisLength The length of the output axis
3534
* @author Luana Coppio
3635
*/
37-
public record BrainRegistry(String name, int inputAxisLength, int outputAxisLength) {
36+
public record BrainRegistry(String name, int outputAxisLength) {
3837
}

megamek/src/megamek/ai/neuralnetwork/InputNormalizationValues.java

+47
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,52 @@
2727
*/
2828
package megamek.ai.neuralnetwork;
2929

30+
import java.io.BufferedReader;
31+
import java.io.FileReader;
32+
import java.io.IOException;
33+
import java.nio.file.Path;
34+
import java.util.ArrayList;
35+
import java.util.List;
36+
3037
public record InputNormalizationValues(float[] minValues, float[] maxValues) {
38+
public static InputNormalizationValues loadFile(Path path) {
39+
List<Float> minValuesList = new ArrayList<>();
40+
List<Float> maxValuesList = new ArrayList<>();
41+
float[] minValuesTemp;
42+
float[] maxValuesTemp;
43+
44+
int inputSize = 0;
45+
// Initialize normalization values
46+
// the normalization values are on a file named min_max_feature_normalization.csv inside the model folder
47+
Path normalizationFilePath = Path.of(path.toString(), "min_max_feature_normalization.csv");
48+
try (var reader = new BufferedReader(new FileReader(normalizationFilePath.toFile()))) {
49+
String line;
50+
while ((line = reader.readLine()) != null) {
51+
if (line.startsWith("feature,")) {
52+
continue;
53+
}
54+
String[] values = line.split(",");
55+
if (values.length != 3) {
56+
// This probably means that it reached the end of the file, but we need to throw an exception here
57+
// to avoid using invalid values because otherwise it is impossible to run.
58+
throw new IllegalArgumentException("Invalid line in normalization file: " + line);
59+
}
60+
minValuesList.add(Float.parseFloat(values[1]));
61+
maxValuesList.add(Float.parseFloat(values[2]));
62+
inputSize++;
63+
}
64+
65+
minValuesTemp = new float[minValuesList.size()];
66+
maxValuesTemp = new float[maxValuesList.size()];
67+
68+
for (int i = 0; i < inputSize; i++) {
69+
minValuesTemp[i] = minValuesList.get(i);
70+
maxValuesTemp[i] = maxValuesList.get(i);
71+
}
72+
} catch (IOException e) {
73+
throw new RuntimeException("Failed to load TensorFlow model: " + e.getMessage(), e);
74+
}
75+
76+
return new InputNormalizationValues(minValuesTemp, maxValuesTemp);
77+
}
3178
}

megamek/src/megamek/ai/neuralnetwork/NeuralNetwork.java

+6-33
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,13 @@ public class NeuralNetwork {
6161
private final float[] inputNormalizationMaxValues;
6262
private final int outputAxisLength;
6363

64-
public NeuralNetwork(
65-
BrainRegistry brainRegistry,
64+
private NeuralNetwork(
65+
int outputAxisLength,
6666
SavedModelBundle model,
6767
InputNormalizationValues inputNormalizationValues)
6868
{
6969
this.model = model;
70-
this.outputAxisLength = brainRegistry.outputAxisLength();
70+
this.outputAxisLength = outputAxisLength;
7171
this.session = model.session();
7272

7373
SignatureDef sigDef = model.metaGraphDef().getSignatureDefMap().get("serving_default");
@@ -98,36 +98,9 @@ public NeuralNetwork(
9898
*/
9999
public static NeuralNetwork loadBrain(BrainRegistry brainRegistry) {
100100
Path path = Path.of("data", "ai","brains", brainRegistry.name());
101-
try {
102-
// Initialize normalization values
103-
// the normalization values are on a file named min_max_feature_normalization.csv inside the modelPath
104-
Path normalizationFilePath = Path.of("data", "ai","brains", brainRegistry.name(),
105-
"min_max_feature_normalization.csv");
106-
InputNormalizationValues inputNormalizationValues =
107-
new InputNormalizationValues(new float[brainRegistry.inputAxisLength()], new float[brainRegistry.inputAxisLength()]);
108-
try (var reader = new BufferedReader(new FileReader(normalizationFilePath.toFile()))) {
109-
String line;
110-
int index;
111-
while ((line = reader.readLine()) != null) {
112-
if (line.startsWith("feature,")) {
113-
continue;
114-
}
115-
String[] values = line.split(",");
116-
index = Integer.parseInt(values[0]);
117-
inputNormalizationValues.minValues()[index] = Float.parseFloat(values[1]);
118-
inputNormalizationValues.maxValues()[index] = Float.parseFloat(values[2]);
119-
}
120-
} catch (IOException e) {
121-
logger.warn("Normalization file not found: " + e.getMessage(), e);
122-
throw new RuntimeException("Failed to load TensorFlow model: " + e.getMessage(), e);
123-
}
124-
125-
return new NeuralNetwork(brainRegistry, SavedModelBundle.load(path.toString(), "serve"),
126-
inputNormalizationValues);
127-
} catch (Exception e) {
128-
logger.error("Failed to load model", e);
129-
throw new RuntimeException("Failed to load TensorFlow model: " + e.getMessage(), e);
130-
}
101+
SavedModelBundle model = SavedModelBundle.load(path.toString(), "serve");
102+
InputNormalizationValues inputNormalizationValues = InputNormalizationValues.loadFile(path);
103+
return new NeuralNetwork(brainRegistry.outputAxisLength(), model, inputNormalizationValues);
131104
}
132105

133106
/**

megamek/src/megamek/common/CubeCoords.java

+21-12
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,29 @@
11
/*
2-
* Copyright (c) 2025 - The MegaMek Team. All Rights Reserved.
2+
* Copyright (C) 2025 The MegaMek Team. All Rights Reserved.
33
*
4-
* This file is part of MegaMek.
4+
* This file is part of MegaMek.
55
*
6-
* MekHQ is free software: you can redistribute it and/or modify
7-
* it under the terms of the GNU General Public License as published by
8-
* the Free Software Foundation, either version 3 of the License, or
9-
* (at your option) any later version.
6+
* MegaMek is free software: you can redistribute it and/or modify
7+
* it under the terms of the GNU General Public License (GPL),
8+
* version 2 or (at your option) any later version,
9+
* as published by the Free Software Foundation.
1010
*
11-
* MekHQ is distributed in the hope that it will be useful,
12-
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13-
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14-
* GNU General Public License for more details.
11+
* MegaMek is distributed in the hope that it will be useful,
12+
* but WITHOUT ANY WARRANTY; without even the implied warranty
13+
* of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
14+
* See the GNU General Public License for more details.
1515
*
16-
* You should have received a copy of the GNU General Public License
17-
* along with MekHQ. If not, see <http://www.gnu.org/licenses/>.
16+
* A copy of the GPL should have been included with this project;
17+
* if not, see <https://www.gnu.org/licenses/>.
18+
*
19+
* NOTICE: The MegaMek organization is a non-profit group of volunteers
20+
* creating free software for the BattleTech community.
21+
*
22+
* MechWarrior, BattleMech, `Mech and AeroTech are registered trademarks
23+
* of The Topps Company, Inc. All Rights Reserved.
24+
*
25+
* Catalyst Game Labs and the Catalyst Game Labs logo are trademarks of
26+
* InMediaRes Productions, LLC.
1827
*/
1928
package megamek.common;
2029

megamek/src/megamek/utilities/CasparUtilities.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ public class CasparUtilities {
4141

4242
public static void main(String[] args) throws IOException {
4343
NeuralNetwork.testTensorFlow();
44-
NeuralNetwork neuralNetwork = NeuralNetwork.loadBrain(new BrainRegistry("default", 55, 3));
44+
NeuralNetwork neuralNetwork = NeuralNetwork.loadBrain(new BrainRegistry("default", 3));
4545
float[] x_test = new float[entry.length];
4646
for (int i = 0; i < entry.length; i++) {
4747
x_test[i] = (float) entry[i];

0 commit comments

Comments
 (0)