Skip to content

Commit 55d8bbb

Browse files
authored
Merge pull request #52 from ua-datalab/create_pytests_uspantek
"added pytests for uspantek quantum 1 x2. gives key error. have captured that in pytest. "
2 parents 140e458 + dae0c76 commit 55d8bbb

File tree

2 files changed

+211
-102
lines changed

2 files changed

+211
-102
lines changed

classify.py

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -407,11 +407,11 @@ def generate_OOV_parameterising_model(trained_qnlp_model, train_vocab_embeddings
407407
hist = OOV_NN_model.fit(NN_train_X, np.array(NN_train_Y), validation_split=0.2, verbose=1, epochs=epochs_model3_oov_model,callbacks=[callback])
408408
print(hist.history.keys())
409409
print(f'OOV NN model final epoch loss: {(hist.history["loss"][-1], hist.history["val_loss"][-1])}')
410-
plt.plot(hist.history['loss'], label='loss')
411-
plt.plot(hist.history['val_loss'], label='val_loss')
412-
plt.xlabel('Epoch')
413-
plt.ylabel('Error')
414-
plt.legend()
410+
# plt.plot(hist.history['loss'], label='loss')
411+
# plt.plot(hist.history['val_loss'], label='val_loss')
412+
# plt.xlabel('Epoch')
413+
# plt.ylabel('Error')
414+
# plt.legend()
415415
# plt.show() #code is expecting user closing the picture manually. commenting this temporarily since that was preventing the smooth run/debugging of code
416416

417417
best_model= OOV_NN_model
@@ -528,13 +528,18 @@ def read_glue_data(dataset_downloaded,split,lines_to_read=0):
528528

529529

530530

531-
def read_data(filename):
531+
def read_data(filename,lines_to_read):
532532
labels, sentences = [], []
533+
line_counter=0
533534
with open(filename) as f:
534535
for line in f:
535536
t = float(line[0])
536537
labels.append([t, 1-t])
537538
sentences.append(line[1:].strip())
539+
line_counter+=1
540+
if (line_counter> lines_to_read):
541+
break
542+
538543
return labels, sentences
539544

540545

@@ -723,24 +728,24 @@ def run_experiment(train_diagrams, train_labels, val_diagrams, val_labels,test_d
723728
"""if there are no OOV words, we dont need the model 2 through model 4.
724729
just use model 1 to evaluate and exit"""
725730
if oov_word_count==0:
726-
import matplotlib.pyplot as plt
731+
727732
import numpy as np
728733

729-
fig1, ((ax_tl, ax_tr), (ax_bl, ax_br)) = plt.subplots(2, 2, sharey='row', figsize=(10, 6))
734+
# fig1, ((ax_tl, ax_tr), (ax_bl, ax_br)) = plt.subplots(2, 2, sharey='row', figsize=(10, 6))
730735

731-
ax_tl.set_title('Training set')
732-
ax_tr.set_title('Development set')
733-
ax_bl.set_xlabel('Epochs')
734-
ax_br.set_xlabel('Epochs')
735-
ax_bl.set_ylabel('Accuracy')
736-
ax_tl.set_ylabel('Loss')
736+
# ax_tl.set_title('Training set')
737+
# ax_tr.set_title('Development set')
738+
# ax_bl.set_xlabel('Epochs')
739+
# ax_br.set_xlabel('Epochs')
740+
# ax_bl.set_ylabel('Accuracy')
741+
# ax_tl.set_ylabel('Loss')
737742

738-
colours = iter(plt.rcParams['axes.prop_cycle'].by_key()['color'])
739-
range_ = np.arange(1, trainer_class_to_use.epochs+1)
740-
ax_tl.plot(range_, trainer_class_to_use.train_epoch_costs, color=next(colours))
741-
ax_bl.plot(range_, trainer_class_to_use.train_eval_results['acc'], color=next(colours))
742-
ax_tr.plot(range_, trainer_class_to_use.val_costs, color=next(colours))
743-
ax_br.plot(range_, trainer_class_to_use.val_eval_results['acc'], color=next(colours))
743+
# colours = iter(plt.rcParams['axes.prop_cycle'].by_key()['color'])
744+
# range_ = np.arange(1, trainer_class_to_use.epochs+1)
745+
# ax_tl.plot(range_, trainer_class_to_use.train_epoch_costs, color=next(colours))
746+
# ax_bl.plot(range_, trainer_class_to_use.train_eval_results['acc'], color=next(colours))
747+
# ax_tr.plot(range_, trainer_class_to_use.val_costs, color=next(colours))
748+
# ax_br.plot(range_, trainer_class_to_use.val_eval_results['acc'], color=next(colours))
744749

745750

746751
val_preds = model1_obj.get_diagram_output(val_circuits)
@@ -918,9 +923,9 @@ def perform_task(args):
918923

919924
else:
920925
#read the base data, i.e plain text english.
921-
train_labels, train_data = read_data(os.path.join(args.data_base_folder,TRAIN))
922-
val_labels, val_data = read_data(os.path.join(args.data_base_folder,DEV))
923-
test_labels, test_data = read_data(os.path.join(args.data_base_folder,TEST))
926+
train_labels, train_data = read_data(os.path.join(args.data_base_folder,TRAIN),lines_to_read= args.no_of_training_data_points_to_use)
927+
val_labels, val_data = read_data(os.path.join(args.data_base_folder,DEV),lines_to_read= args.no_of_training_data_points_to_use)
928+
test_labels, test_data = read_data(os.path.join(args.data_base_folder,TEST),lines_to_read= args.no_of_training_data_points_to_use)
924929

925930

926931

0 commit comments

Comments
 (0)