@@ -40,6 +40,8 @@ def __init__(self, fname):
40
40
raise FnetoutError ('Invalid running mode specification.' )
41
41
42
42
output = fnetoutfile ['fnetout' ]['output' ]
43
+
44
+ # read number of datapoints
43
45
self ._ndatapoints = output .attrs .get ('ndatapoints' )
44
46
if len (self ._ndatapoints ) == 1 :
45
47
# number of datapoints stored in array of size 1
@@ -48,6 +50,28 @@ def __init__(self, fname):
48
50
msg = "Error while reading fnetout file '" + self ._fname + \
49
51
"'. Unrecognized number of datapoints obtained."
50
52
raise FnetoutError (msg )
53
+
54
+ # read number of system-wide targets
55
+ self ._nglobaltargets = output .attrs .get ('nglobaltargets' )
56
+ if len (self ._nglobaltargets ) == 1 :
57
+ # number of system-wide targets stored in array of size 1
58
+ self ._nglobaltargets = self ._nglobaltargets [0 ]
59
+ else :
60
+ msg = "Error while reading fnetout file '" + self ._fname + \
61
+ "'. Unrecognized number of global targets obtained."
62
+ raise FnetoutError (msg )
63
+
64
+ # read number of atomic targets
65
+ self ._natomictargets = output .attrs .get ('natomictargets' )
66
+ if len (self ._natomictargets ) == 1 :
67
+ # number of atomic targets stored in array of size 1
68
+ self ._natomictargets = self ._natomictargets [0 ]
69
+ else :
70
+ msg = "Error while reading fnetout file '" + self ._fname + \
71
+ "'. Unrecognized number of atomic targets obtained."
72
+ raise FnetoutError (msg )
73
+
74
+ # read force specification
51
75
self ._tforces = output .attrs .get ('tforces' )
52
76
# account for legacy files where no force entry is present
53
77
if self ._tforces is None :
@@ -60,18 +84,6 @@ def __init__(self, fname):
60
84
"'. Unrecognized force specification obtained."
61
85
raise FnetoutError (msg )
62
86
63
- self ._targettype = \
64
- output .attrs .get ('targettype' ).decode ('UTF-8' ).strip ()
65
- if not self ._targettype in ('atomic' , 'global' ):
66
- raise FnetoutError ('Invalid running mode obtained.' )
67
-
68
- # get number of atomic or global predictions/targets
69
- self ._npredictions = np .shape (
70
- np .array (output ['datapoint1' ]['output' ]))[1 ]
71
-
72
- if self ._mode == 'validate' :
73
- self ._npredictions = int (self ._npredictions / 2 )
74
-
75
87
76
88
@property
77
89
def mode (self ):
@@ -100,16 +112,29 @@ def ndatapoints(self):
100
112
101
113
102
114
@property
103
- def targettype (self ):
104
- '''Defines property, providing the target type.
115
+ def nglobaltargets (self ):
116
+ '''Defines property, providing the number of system-wide targets.
117
+
118
+ Returns:
119
+
120
+ nglobaltargets (int): number of global targets per datapoint
121
+
122
+ '''
123
+
124
+ return self ._nglobaltargets
125
+
126
+
127
+ @property
128
+ def natomictargets (self ):
129
+ '''Defines property, providing the number of atomic targets.
105
130
106
131
Returns:
107
132
108
- targettype (str ): type of targets the network was trained on
133
+ natomictargets (int ): number of atomic targets per datapoint
109
134
110
135
'''
111
136
112
- return self ._targettype
137
+ return self ._natomictargets
113
138
114
139
115
140
@property
@@ -126,75 +151,134 @@ def tforces(self):
126
151
127
152
128
153
@property
129
- def predictions (self ):
130
- '''Defines property, providing the predictions of Fortnet.
154
+ def globalpredictions (self ):
155
+ '''Defines property, providing the system-wide predictions of Fortnet.
131
156
132
157
Returns:
133
158
134
- predictions (list or 2darray): predictions of the network
159
+ predictions (2darray): predictions of the network
135
160
136
161
'''
137
162
163
+ if not self ._nglobaltargets > 0 :
164
+ return None
165
+
166
+ predictions = np .empty ((self ._ndatapoints , self ._nglobaltargets ),
167
+ dtype = float )
168
+
138
169
with h5py .File (self ._fname , 'r' ) as fnetoutfile :
139
170
output = fnetoutfile ['fnetout' ]['output' ]
140
- if self ._targettype == 'atomic' :
141
- predictions = []
142
- for idata in range (self ._ndatapoints ):
143
- dataname = 'datapoint' + str (idata + 1 )
144
- if self ._mode == 'validate' :
145
- predictions .append (
146
- np .array (output [dataname ]['output' ],
147
- dtype = float )[:, :self ._npredictions ])
148
- else :
149
- predictions .append (
150
- np .array (output [dataname ]['output' ], dtype = float ))
151
- else :
152
- predictions = np .empty (
153
- (self ._ndatapoints , self ._npredictions ), dtype = float )
154
- for idata in range (self ._ndatapoints ):
155
- dataname = 'datapoint' + str (idata + 1 )
156
- if self ._mode == 'validate' :
157
- predictions [idata , :] = \
158
- np .array (output [dataname ]['output' ],
159
- dtype = float )[0 , :self ._npredictions ]
160
- else :
161
- predictions [idata , :] = \
162
- np .array (output [dataname ]['output' ],
163
- dtype = float )[0 , :]
171
+ for idata in range (self ._ndatapoints ):
172
+ dataname = 'datapoint' + str (idata + 1 )
173
+ predictions [idata , :] = np .array (
174
+ output [dataname ]['globalpredictions' ],
175
+ dtype = float )
164
176
165
177
return predictions
166
178
167
179
168
180
@property
169
- def targets (self ):
170
- '''Defines property, providing the targets during training.
181
+ def globalpredictions_atomic (self ):
182
+ '''Defines property, providing the (atom-resolved) system-wide
183
+ predictions of Fortnet.
171
184
172
185
Returns:
173
186
174
- targets (list or 2darray ): targets during training
187
+ predictions (list): predictions of the network
175
188
176
189
'''
177
190
178
- if self ._mode == 'predict' :
191
+ if not self ._nglobaltargets > 0 :
179
192
return None
180
193
194
+ predictions = []
195
+
181
196
with h5py .File (self ._fname , 'r' ) as fnetoutfile :
182
197
output = fnetoutfile ['fnetout' ]['output' ]
183
- if self ._targettype == 'atomic' :
184
- targets = []
185
- for idata in range (self ._ndatapoints ):
186
- dataname = 'datapoint' + str (idata + 1 )
187
- targets .append (
188
- np .array (output [dataname ]['output' ],
189
- dtype = float )[:, self ._npredictions :])
190
- else :
191
- targets = np .empty (
192
- (self ._ndatapoints , self ._npredictions ), dtype = float )
193
- for idata in range (self ._ndatapoints ):
194
- dataname = 'datapoint' + str (idata + 1 )
195
- targets [idata , :] = \
196
- np .array (output [dataname ]['output' ],
197
- dtype = float )[0 , self ._npredictions :]
198
+ for idata in range (self ._ndatapoints ):
199
+ dataname = 'datapoint' + str (idata + 1 )
200
+ predictions .append (
201
+ np .array (output [dataname ]['rawpredictions' ],
202
+ dtype = float )[:, 0 :self ._nglobaltargets ])
203
+
204
+ return predictions
205
+
206
+
207
+ @property
208
+ def atomicpredictions (self ):
209
+ '''Defines property, providing the atomic predictions of Fortnet.
210
+
211
+ Returns:
212
+
213
+ predictions (list): predictions of the network
214
+
215
+ '''
216
+
217
+ if not self ._natomictargets > 0 :
218
+ return None
219
+
220
+ predictions = []
221
+
222
+ with h5py .File (self ._fname , 'r' ) as fnetoutfile :
223
+ output = fnetoutfile ['fnetout' ]['output' ]
224
+ for idata in range (self ._ndatapoints ):
225
+ dataname = 'datapoint' + str (idata + 1 )
226
+ predictions .append (
227
+ np .array (output [dataname ]
228
+ ['rawpredictions' ], dtype = float )
229
+ [:, self ._nglobaltargets :])
230
+
231
+ return predictions
232
+
233
+
234
+ @property
235
+ def globaltargets (self ):
236
+ '''Defines property, providing the system-wide targets during training.
237
+
238
+ Returns:
239
+
240
+ targets (2darray): system-wide targets during training
241
+
242
+ '''
243
+
244
+ if self ._mode == 'predict' or self ._nglobaltargets == 0 :
245
+ return None
246
+
247
+ targets = np .empty ((self ._ndatapoints , self ._nglobaltargets ),
248
+ dtype = float )
249
+
250
+ with h5py .File (self ._fname , 'r' ) as fnetoutfile :
251
+ output = fnetoutfile ['fnetout' ]['output' ]
252
+ for idata in range (self ._ndatapoints ):
253
+ dataname = 'datapoint' + str (idata + 1 )
254
+ targets [idata , :] = np .array (
255
+ output [dataname ]['globaltargets' ],
256
+ dtype = float )
257
+
258
+ return targets
259
+
260
+
261
+ @property
262
+ def atomictargets (self ):
263
+ '''Defines property, providing the atomic targets during training.
264
+
265
+ Returns:
266
+
267
+ targets (list): atomic targets during training
268
+
269
+ '''
270
+
271
+ if self ._mode == 'predict' or self ._natomictargets == 0 :
272
+ return None
273
+
274
+ targets = []
275
+
276
+ with h5py .File (self ._fname , 'r' ) as fnetoutfile :
277
+ output = fnetoutfile ['fnetout' ]['output' ]
278
+ for idata in range (self ._ndatapoints ):
279
+ dataname = 'datapoint' + str (idata + 1 )
280
+ targets .append (np .array (output [dataname ]
281
+ ['atomictargets' ], dtype = float ))
198
282
199
283
return targets
200
284
@@ -214,10 +298,10 @@ def forces(self):
214
298
215
299
tmp1 = []
216
300
217
- if self ._targettype == 'atomic' :
301
+ if self ._natomictargets > 0 :
218
302
msg = "Error while extracting forces from fnetout file '" \
219
303
+ self ._fname + \
220
- "'. Forces only supplied for global property targets."
304
+ "'. Forces supplied for global property targets only ."
221
305
raise FnetoutError (msg )
222
306
223
307
with h5py .File (self ._fname , 'r' ) as fnetoutfile :
0 commit comments