@@ -111,275 +111,6 @@ void AdaptiveMomentEstimation::set_maximum_time(const type new_maximum_time)
111111 maximum_time = new_maximum_time;
112112}
113113
114- /*
// Legacy (commented-out) training entry point for the Adam optimizer.
//
// Runs up to maximum_epochs + 1 epochs (epoch index 0..maximum_epochs inclusive). Each epoch:
//   1. requests the training batches from the dataset,
//   2. for every batch: fills the batch, forward-propagates, back-propagates through the loss
//      and calls update_parameters(),
//   3. averages the accumulated training error (and accuracy for CrossEntropyError3d losses),
//   4. if the dataset has validation samples, evaluates the averaged validation error,
//   5. checks the stopping conditions and, when one is met, finalises `results` and breaks.
//
// Returns: a TrainingResults object; a default-constructed one if `loss` is null or reports
//          that it has no neural network or no dataset.
// Throws:  runtime_error("Dataset is null.") if loss->get_dataset() returns a null pointer.
115- TrainingResults AdaptiveMomentEstimation::train()
116- {
117- if(!loss || !loss->has_neural_network() || !loss->has_dataset())
118- return TrainingResults();
119-
    // Histories are sized for every epoch the loop below can visit (0..maximum_epochs).
120- TrainingResults results(maximum_epochs + 1);
121-
122- check();
123-
124- if(display) cout << "Training with adaptive moment estimation \"Adam\" ..." << endl;
125-
126- // Dataset
127-
128- Dataset* dataset = loss->get_dataset();
129-
130- if(!dataset)
131- throw runtime_error("Dataset is null.");
132-
133- const bool has_validation = dataset->has_validation();
134-
    // Accuracy is only tracked when the loss object is a CrossEntropyError3d.
135- const bool is_text_classification_model = is_instance_of<CrossEntropyError3d>(loss);
136-
137- const vector<Index> input_feature_indices = dataset->get_feature_indices("Input");
138- const vector<Index> target_feature_indices = dataset->get_feature_indices("Target");
139- const vector<Index> decoder_feature_indices = dataset->get_feature_indices("Decoder");
140-
141- const vector<Index> training_sample_indices = dataset->get_sample_indices("Training");
142- const vector<Index> validation_sample_indices = dataset->get_sample_indices("Validation");
143-
144- const Index training_samples_number = dataset->get_samples_number("Training");
145- const Index validation_samples_number = dataset->get_samples_number("Validation");
146-
    // Batch sizes are capped by the number of available samples.
147- const Index training_batch_size = min(training_samples_number, batch_size);
148-
149- const Index validation_batch_size = (validation_samples_number != 0)
150- ? min(validation_samples_number, batch_size)
151- : 0;
152-
    // Integer division: only full batches are counted here; the ternaries guard against a
    // zero batch size. NOTE(review): how get_batches() treats a trailing partial batch is not
    // visible in this block - confirm it agrees with these counts.
153- const Index training_batches_number = (training_batch_size != 0)
154- ? training_samples_number / training_batch_size
155- : 0;
156-
157- const Index validation_batches_number = (validation_batch_size != 0)
158- ? validation_samples_number / validation_batch_size
159- : 0;
160-
161- vector<vector<Index>> training_batches(training_batches_number);
162- vector<vector<Index>> validation_batches(validation_batches_number);
163-
164- // Neural network
165-
166- NeuralNetwork* neural_network = loss->get_neural_network();
167-
168- set_names();
169- set_scaling();
170-
    // Training workspaces live on the stack; the validation counterparts are created lazily
    // below and stay null when the dataset has no validation samples.
171- Batch training_batch(training_batch_size, dataset);
172- unique_ptr<Batch> validation_batch;
173-
174- ForwardPropagation training_forward_propagation(training_batch_size, neural_network);
175- unique_ptr<ForwardPropagation> validation_forward_propagation;
176-
177- // Loss index
178-
179- loss->set_normalization_coefficient();
180-
181- BackPropagation training_back_propagation(training_batch_size, loss);
182- unique_ptr<BackPropagation> validation_back_propagation;
183-
184- if (has_validation)
185- {
186- validation_batch = make_unique<Batch>(validation_batch_size, dataset);
187- validation_forward_propagation = make_unique<ForwardPropagation>(validation_batch_size, neural_network);
188- validation_back_propagation = make_unique<BackPropagation>(validation_batch_size, loss);
189- }
190-
191- type training_error = type(0);
192- type training_accuracy = type(0);
193-
194- type validation_error = type(0);
195- type validation_accuracy = type(0);
196-
    // Counts epochs whose validation error exceeded the previous epoch's; never reset below.
197- Index validation_failures = 0;
198-
199- // Optimization algorithm
200-
201- AdaptiveMomentEstimationData optimization_data(this);
202-
203- bool stop_training = false;
    // NOTE(review): this flag is never changed, so the validation forward pass below also
    // runs with is_training == true - confirm that is intended.
204- bool is_training = true;
205-
206- time_t beginning_time;
207- time(&beginning_time);
208-
209- type elapsed_time = type(0);
210-
    // Batches are shuffled unless the network reports a "Recurrent" component.
211- bool shuffle = true;
212-
213- if(neural_network->has("Recurrent"))
214- shuffle = false;
215-
216- // Main loop
217- optimization_data.iteration = 1;
218-
219- for(Index epoch = 0; epoch <= maximum_epochs; epoch++)
220- {
221- if(display && epoch%display_period == 0) cout << "Epoch: " << epoch << endl;
222-
223- training_batches = dataset->get_batches(training_sample_indices, training_batch_size, shuffle);
224-
    // Per-epoch accumulators are reset before the batch loop.
225- training_error = type(0);
226-
227- if(is_text_classification_model) training_accuracy = type(0);
228-
229- for(Index iteration = 0; iteration < training_batches_number; iteration++)
230- {
    // Clear the gradient buffer before back-propagating this batch.
231- training_back_propagation.neural_network.gradient.setZero();
232-
233- // Dataset
234-
235- training_batch.fill(training_batches[iteration],
236- input_feature_indices,
237- decoder_feature_indices,
238- target_feature_indices);
239-
240- // Neural network
241-
242- neural_network->forward_propagate(training_batch.get_inputs(),
243- training_forward_propagation,
244- is_training);
245-
246- // Loss index
247-
248- loss->back_propagate(training_batch,
249- training_forward_propagation,
250- training_back_propagation);
251-
252- training_error += training_back_propagation.error;
253-
254- if(is_text_classification_model) training_accuracy += training_back_propagation.accuracy(0);
255-
256- update_parameters(training_back_propagation, optimization_data);
257- }
258-
259- // Loss
260-
    // Turn the accumulated sums into per-batch averages.
    // NOTE(review): training_batches_number can be 0 (see the ternary above), in which case
    // this divides by zero - presumably `type` is floating point, giving NaN; confirm.
261- training_error /= type(training_batches_number);
262- if(is_text_classification_model)
263- training_accuracy /= type(training_batches_number);
264-
265- results.training_error_history(epoch) = training_error;
266-
267- if(has_validation)
268- {
269- validation_batches = dataset->get_batches(validation_sample_indices, validation_batch_size, shuffle);
270-
271- validation_error = type(0);
272-
273- if(is_text_classification_model)
274- validation_accuracy = type(0);
275-
276- for(Index iteration = 0; iteration < validation_batches_number; iteration++)
277- {
278- // Dataset
279-
280- validation_batch->fill(validation_batches[iteration],
281- input_feature_indices,
282- decoder_feature_indices,
283- target_feature_indices);
284-
285- // Neural network
286-
287- neural_network->forward_propagate(validation_batch->get_inputs(),
288- *validation_forward_propagation,
289- is_training);
290-
291- // Loss
292-
    // Validation only evaluates the error (calculate_error); update_parameters() is not called.
293- loss->calculate_error(*validation_batch,
294- *validation_forward_propagation,
295- *validation_back_propagation);
296-
297- validation_error += validation_back_propagation->error;
298-
299- if(is_text_classification_model)
300- validation_accuracy += validation_back_propagation->accuracy(0);
301- }
302-
    // NOTE(review): same zero-divisor concern as for the training averages above.
303- validation_error /= type(validation_batches_number);
304- if(is_text_classification_model) validation_accuracy /= type(validation_batches_number);
305-
306- results.validation_error_history(epoch) = validation_error;
307-
    // A "failure" is an epoch whose validation error is strictly greater than the previous epoch's.
308- if(epoch != 0 && results.validation_error_history(epoch) > results.validation_error_history(epoch-1)) validation_failures++;
309- }
310-
311- // Elapsed time
312-
313- elapsed_time = get_elapsed_time(beginning_time);
314-
315- if(display && epoch%display_period == 0)
316- {
317- cout << "Training error: " << training_error << endl;
318- if(is_text_classification_model) cout << "Training accuracy: " << training_accuracy << endl;
319- if(has_validation) cout << "Validation error: " << validation_error << endl;
320- if(has_validation && is_text_classification_model) cout << "Validation accuracy: " << validation_accuracy << endl;
321- cout << "Elapsed time: " << write_time(elapsed_time) << endl;
322- }
323-
    // Stopping criteria: assume we stop, then clear the flag in the final else when no
    // condition matched. The chain is ordered, so only the first matching condition is recorded.
324- stop_training = true;
325-
326- if(epoch == maximum_epochs)
327- {
328- if(display) cout << "Epoch " << epoch << "\nMaximum epochs number reached: " << epoch << endl;
329- results.stopping_condition = StoppingCondition::MaximumEpochsNumber;
330- }
331- else if(elapsed_time >= maximum_time)
332- {
333- if(display) cout << "Epoch " << epoch << "\nMaximum training time reached: " << write_time(elapsed_time) << endl;
334- results.stopping_condition = StoppingCondition::MaximumTime;
335- }
336- else if(results.training_error_history(epoch) < training_loss_goal)
337- {
338- results.stopping_condition = StoppingCondition::LossGoal;
339- if(display) cout << "Epoch " << epoch << "\nLoss goal reached: " << results.training_error_history(epoch) << endl;
340- }
    // NOTE(review): training_accuracy is only updated for CrossEntropyError3d losses and is 0
    // otherwise, yet this comparison runs for every model. The branch also records LossGoal
    // even though the message reports an accuracy goal - confirm both are intended.
341- else if(training_accuracy >= training_accuracy_goal)
342- {
343- results.stopping_condition = StoppingCondition::LossGoal;
344- if(display) cout << "Epoch " << epoch << "\nAccuracy goal reached: " << training_accuracy << endl;
345- }
346- else if(validation_failures >= maximum_validation_failures)
347- {
348- if(display) cout << "Epoch " << epoch << "\nMaximum selection failures reached: " << validation_failures << endl;
349- results.stopping_condition = StoppingCondition::MaximumSelectionErrorIncreases;
350- }
351- else
352- {
353- stop_training = false;
354- }
355-
356- if(stop_training)
357- {
358- results.loss = training_back_propagation.loss_value;
359-
360- results.validation_failures = validation_failures;
361-
    // Trim the histories to the epochs actually run (none for validation when it was not used).
362- results.resize_training_error_history(epoch+1);
363-
364- results.resize_validation_error_history(has_validation ? epoch + 1 : 0);
365-
366- results.elapsed_time = write_time(elapsed_time);
367-
368- break;
369- }
370-
    // Periodic checkpoint of the network every save_period epochs, skipping epoch 0.
    // Not reached on the stopping epoch because of the break above.
371- if(epoch != 0 && epoch % save_period == 0) neural_network->save(neural_network_file_name);
372- }
373-
374- set_unscaling();
375-
376- if(display) results.print();
377-
378- return results;
379- }
380-
381- */
382-
383114
384115TrainingResults AdaptiveMomentEstimation::train ()
385116{
0 commit comments