@@ -61,13 +61,13 @@ public class NeuralNetwork {
61
61
private final float [] inputNormalizationMaxValues ;
62
62
private final int outputAxisLength ;
63
63
64
- public NeuralNetwork (
65
- BrainRegistry brainRegistry ,
64
+ private NeuralNetwork (
65
+ int outputAxisLength ,
66
66
SavedModelBundle model ,
67
67
InputNormalizationValues inputNormalizationValues )
68
68
{
69
69
this .model = model ;
70
- this .outputAxisLength = brainRegistry . outputAxisLength () ;
70
+ this .outputAxisLength = outputAxisLength ;
71
71
this .session = model .session ();
72
72
73
73
SignatureDef sigDef = model .metaGraphDef ().getSignatureDefMap ().get ("serving_default" );
@@ -98,36 +98,9 @@ public NeuralNetwork(
98
98
*/
99
99
public static NeuralNetwork loadBrain (BrainRegistry brainRegistry ) {
100
100
Path path = Path .of ("data" , "ai" ,"brains" , brainRegistry .name ());
101
- try {
102
- // Initialize normalization values
103
- // the normalization values are on a file named min_max_feature_normalization.csv inside the modelPath
104
- Path normalizationFilePath = Path .of ("data" , "ai" ,"brains" , brainRegistry .name (),
105
- "min_max_feature_normalization.csv" );
106
- InputNormalizationValues inputNormalizationValues =
107
- new InputNormalizationValues (new float [brainRegistry .inputAxisLength ()], new float [brainRegistry .inputAxisLength ()]);
108
- try (var reader = new BufferedReader (new FileReader (normalizationFilePath .toFile ()))) {
109
- String line ;
110
- int index ;
111
- while ((line = reader .readLine ()) != null ) {
112
- if (line .startsWith ("feature," )) {
113
- continue ;
114
- }
115
- String [] values = line .split ("," );
116
- index = Integer .parseInt (values [0 ]);
117
- inputNormalizationValues .minValues ()[index ] = Float .parseFloat (values [1 ]);
118
- inputNormalizationValues .maxValues ()[index ] = Float .parseFloat (values [2 ]);
119
- }
120
- } catch (IOException e ) {
121
- logger .warn ("Normalization file not found: " + e .getMessage (), e );
122
- throw new RuntimeException ("Failed to load TensorFlow model: " + e .getMessage (), e );
123
- }
124
-
125
- return new NeuralNetwork (brainRegistry , SavedModelBundle .load (path .toString (), "serve" ),
126
- inputNormalizationValues );
127
- } catch (Exception e ) {
128
- logger .error ("Failed to load model" , e );
129
- throw new RuntimeException ("Failed to load TensorFlow model: " + e .getMessage (), e );
130
- }
101
+ SavedModelBundle model = SavedModelBundle .load (path .toString (), "serve" );
102
+ InputNormalizationValues inputNormalizationValues = InputNormalizationValues .loadFile (path );
103
+ return new NeuralNetwork (brainRegistry .outputAxisLength (), model , inputNormalizationValues );
131
104
}
132
105
133
106
/**
0 commit comments