@@ -407,11 +407,11 @@ def generate_OOV_parameterising_model(trained_qnlp_model, train_vocab_embeddings
407407 hist = OOV_NN_model .fit (NN_train_X , np .array (NN_train_Y ), validation_split = 0.2 , verbose = 1 , epochs = epochs_model3_oov_model ,callbacks = [callback ])
408408 print (hist .history .keys ())
409409 print (f'OOV NN model final epoch loss: { (hist .history ["loss" ][- 1 ], hist .history ["val_loss" ][- 1 ])} ' )
410- plt .plot (hist .history ['loss' ], label = 'loss' )
411- plt .plot (hist .history ['val_loss' ], label = 'val_loss' )
412- plt .xlabel ('Epoch' )
413- plt .ylabel ('Error' )
414- plt .legend ()
410+ # plt.plot(hist.history['loss'], label='loss')
411+ # plt.plot(hist.history['val_loss'], label='val_loss')
412+ # plt.xlabel('Epoch')
413+ # plt.ylabel('Error')
414+ # plt.legend()
415415 # plt.show() #code is expecting user closing the picture manually. commenting this temporarily since that was preventing the smooth run/debugging of code
416416
417417 best_model = OOV_NN_model
@@ -528,13 +528,18 @@ def read_glue_data(dataset_downloaded,split,lines_to_read=0):
528528
529529
530530
531- def read_data (filename ):
531+ def read_data (filename , lines_to_read ):
532532 labels , sentences = [], []
533+ line_counter = 0
533534 with open (filename ) as f :
534535 for line in f :
535536 t = float (line [0 ])
536537 labels .append ([t , 1 - t ])
537538 sentences .append (line [1 :].strip ())
539+ line_counter += 1
540+ if (line_counter > lines_to_read ):
541+ break
542+
538543 return labels , sentences
539544
540545
@@ -723,24 +728,24 @@ def run_experiment(train_diagrams, train_labels, val_diagrams, val_labels,test_d
723728 """if there are no OOV words, we dont need the model 2 through model 4.
724729 just use model 1 to evaluate and exit"""
725730 if oov_word_count == 0 :
726- import matplotlib . pyplot as plt
731+
727732 import numpy as np
728733
729- fig1 , ((ax_tl , ax_tr ), (ax_bl , ax_br )) = plt .subplots (2 , 2 , sharey = 'row' , figsize = (10 , 6 ))
734+ # fig1, ((ax_tl, ax_tr), (ax_bl, ax_br)) = plt.subplots(2, 2, sharey='row', figsize=(10, 6))
730735
731- ax_tl .set_title ('Training set' )
732- ax_tr .set_title ('Development set' )
733- ax_bl .set_xlabel ('Epochs' )
734- ax_br .set_xlabel ('Epochs' )
735- ax_bl .set_ylabel ('Accuracy' )
736- ax_tl .set_ylabel ('Loss' )
736+ # ax_tl.set_title('Training set')
737+ # ax_tr.set_title('Development set')
738+ # ax_bl.set_xlabel('Epochs')
739+ # ax_br.set_xlabel('Epochs')
740+ # ax_bl.set_ylabel('Accuracy')
741+ # ax_tl.set_ylabel('Loss')
737742
738- colours = iter (plt .rcParams ['axes.prop_cycle' ].by_key ()['color' ])
739- range_ = np .arange (1 , trainer_class_to_use .epochs + 1 )
740- ax_tl .plot (range_ , trainer_class_to_use .train_epoch_costs , color = next (colours ))
741- ax_bl .plot (range_ , trainer_class_to_use .train_eval_results ['acc' ], color = next (colours ))
742- ax_tr .plot (range_ , trainer_class_to_use .val_costs , color = next (colours ))
743- ax_br .plot (range_ , trainer_class_to_use .val_eval_results ['acc' ], color = next (colours ))
743+ # colours = iter(plt.rcParams['axes.prop_cycle'].by_key()['color'])
744+ # range_ = np.arange(1, trainer_class_to_use.epochs+1)
745+ # ax_tl.plot(range_, trainer_class_to_use.train_epoch_costs, color=next(colours))
746+ # ax_bl.plot(range_, trainer_class_to_use.train_eval_results['acc'], color=next(colours))
747+ # ax_tr.plot(range_, trainer_class_to_use.val_costs, color=next(colours))
748+ # ax_br.plot(range_, trainer_class_to_use.val_eval_results['acc'], color=next(colours))
744749
745750
746751 val_preds = model1_obj .get_diagram_output (val_circuits )
@@ -918,9 +923,9 @@ def perform_task(args):
918923
919924 else :
920925 #read the base data, i.e plain text english.
921- train_labels , train_data = read_data (os .path .join (args .data_base_folder ,TRAIN ))
922- val_labels , val_data = read_data (os .path .join (args .data_base_folder ,DEV ))
923- test_labels , test_data = read_data (os .path .join (args .data_base_folder ,TEST ))
926+ train_labels , train_data = read_data (os .path .join (args .data_base_folder ,TRAIN ), lines_to_read = args . no_of_training_data_points_to_use )
927+ val_labels , val_data = read_data (os .path .join (args .data_base_folder ,DEV ), lines_to_read = args . no_of_training_data_points_to_use )
928+ test_labels , test_data = read_data (os .path .join (args .data_base_folder ,TEST ), lines_to_read = args . no_of_training_data_points_to_use )
924929
925930
926931
0 commit comments