@@ -368,6 +368,35 @@ def load_checkpoint(self):
                    individual_states[i][3 * self.num_params:], dtype=float)
                )
            self.particles.append(particle)
+
+        # restore batches
+        self.particles_batch = []
+        if self.num_batch == 1:
+            self.particles_batch.append(self.particles)
+            self.batch_size = len(self.particles)
+        else:
+            # Calculate the approximate batch size
+            self.batch_size = len(self.particles) // self.num_batch
+
+            # Check if the division leaves some elements unallocated
+            remaining_elements = len(self.particles) % self.num_batch
+
+            if remaining_elements > 0:
+                # Warn the user and suggest adjusting the number of particles or batches
+                warning_message = (
+                    f"{bcolors.WARNING}The specified number of batches ({self.num_batch}) does not evenly divide the number of particles ({len(self.particles)}).{bcolors.ENDC}"
+                )
+                warnings.warn(warning_message)
+
+            # Use a list comprehension to create the full-size batches
+            self.particles_batch = [self.particles[i:i + self.batch_size]
+                                    for i in range(0, self.num_batch * self.batch_size, self.batch_size)]
+
+            # If the division leaves some elements unallocated, add them to the last batch
+            if remaining_elements > 0:
+                last_batch = self.particles_batch.pop()
+                last_batch.extend(self.particles[self.num_batch * self.batch_size:])
+                self.particles_batch.append(last_batch)

        # restore pareto front
        self.pareto_front = []
@@ -381,18 +410,15 @@ def load_checkpoint(self):
                best_position=None,
                best_fitness=None)
            self.pareto_front.append(particle)
+

    def process_batch(self, worker_id, batch):
        # Launch a program for this batch using objective_function
        print(f"Worker ID {worker_id}")
        params = [particle.position for particle in batch]
        optimization_output = self.objective.evaluate(params, worker_id)
-        for p_id, output in enumerate(optimization_output[0]):
-            particle = batch[p_id]
-            if self.optimization_mode == 'individual':
-                particle.evaluate_fitness(self.objective_functions)
-            if self.optimization_mode == 'global':
-                particle.set_fitness(output)
+        for p_id, particle in enumerate(batch):
+            particle.set_fitness(optimization_output[:, p_id])
            batch[p_id] = particle
        return batch

@@ -413,23 +439,24 @@ def optimize(self, num_iterations = 100, max_iter_no_improv = None):
        for _ in range(num_iterations):
            with ProcessPoolExecutor(max_workers=self.num_batch) as executor:
                futures = [executor.submit(self.process_batch, worker_id, batch)
-                           for worker_id, batch in enumerate(self.particles_batch)]
+                           for worker_id, batch in enumerate(self.particles_batch)]

                new_batches = []
                for future in futures:
                    batch = future.result()
                    new_batches.append(batch)
                self.particles_batch = new_batches
            save_particles = []
+            updated_particles = []
            for batch in self.particles_batch:
                for particle in batch:
                    l = np.concatenate([particle.position, np.ravel(particle.fitness)])
-                    print(l)
                    save_particles.append(l)
+                    updated_particles.append(particle)
            FileManager.save_csv(save_particles,
                                 'history/iteration' + str(self.iteration) + '.csv')

-
+            self.particles = updated_particles
            self.update_pareto_front()

            for batch in self.particles_batch:
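
For reference, a minimal, self-contained sketch of the batch-splitting behaviour restored in load_checkpoint above: full-size batches are built first, and any leftover particles (when num_batch does not evenly divide the particle count, the case that triggers the warning) are merged into the last batch. It assumes a plain Python list with at least as many particles as batches; split_into_batches is a hypothetical helper name, not part of the repository's API.

# Minimal sketch (assumptions: plain list input, len(particles) >= num_batch;
# split_into_batches is a hypothetical name, not the repository's API).
def split_into_batches(particles, num_batch):
    if num_batch == 1:
        return [list(particles)]
    batch_size = len(particles) // num_batch
    # Full-size batches first: exactly num_batch slices of batch_size elements.
    batches = [list(particles[i:i + batch_size])
               for i in range(0, num_batch * batch_size, batch_size)]
    # Leftover particles are merged into the last batch.
    batches[-1].extend(particles[num_batch * batch_size:])
    return batches

# Example: 10 particles into 3 batches -> [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
print(split_into_batches(list(range(10)), 3))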