scripts/test_model.py (63 additions, 0 deletions)
@@ -0,0 +1,63 @@
import json
import click
from collections import Counter
from sklearn.metrics import classification_report
from sklearn.externals import joblib
from sklearn.metrics import f1_score, recall_score, precision_score
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from pkspace.utils.loaders import PKSpaceLoader, PKLotLoader # noqa


@click.command()
@click.option('--loader', '-l', type=click.Choice(['PKLot', 'PKSpace']),
default='PKSpace', help='Loader used to load dataset')
@click.argument('dataset_dir',
type=click.Path(exists=True, file_okay=False, dir_okay=True,
resolve_path=True))
@click.argument('model_file',
type=click.Path(exists=True, file_okay=True, dir_okay=False,
resolve_path=True))
@click.option("--machine_friendly", '-f', is_flag=True,
Review comment from a Contributor:

@rstevanak I don't really think --machine_friendly is the right name to use here. It would be much better as --output or --format, so the user could choose from multiple formats, e.g. yaml or json; a rough sketch of what I mean is below.

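A minimal sketch of that suggestion (untested; the 'report'/'json' choices and the output_format parameter name are my own assumptions, not part of this PR, and a 'yaml' choice could be added the same way). Only the option and the final branch change; the rest of the command stays as in this diff:

    @click.option('--format', '-f', 'output_format',
                  type=click.Choice(['report', 'json']),
                  default='report',
                  help='Format in which the results are printed')
    def test_model(loader, dataset_dir, model_file, output_format):
        # ... dataset loading and metric computation stay exactly as in this diff ...
        if output_format == 'json':
            print(json.dumps(answer))
        else:
            print(classification_report(ground_answers, model_answers))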
help='prints serialized dictionary of results')
def test_model(loader, dataset_dir, model_file, machine_friendly):
if loader == 'PKSpace':
loader = PKSpaceLoader()
elif loader == 'PKLot':
loader = PKLotLoader()

model = joblib.load(model_file)
spaces, ground_answers = loader.load(dataset_dir)
model_answers = model.predict(spaces)
if machine_friendly:
answer = {'avg': {}, 0: {}, 1: {}}
metrics = [precision_score, recall_score, f1_score]
classes_counter = Counter(ground_answers)

for i in [0, 1]:
for func in metrics:
score = func(ground_answers, model_answers, pos_label=i)
answer[i][func.__name__] = score
class_support = classes_counter[i]

# summing total support
answer[i]['support'] = class_support
old_sum_support = answer['avg'].get('support', 0)
answer['avg']['support'] = old_sum_support + class_support

# calculating weighted average for all functions
for column in [x.__name__ for x in metrics]:
col_sum = 0
for ans_class in [0, 1]:
row = answer[ans_class]
col_sum += row[column] * row['support']
answer['avg'][column] = col_sum / answer['avg']['support']
Review comment from a Contributor:

@rstevanak How about using sklearn.metrics.precision_recall_fscore_support (http://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_fscore_support.html) here instead? I think it would be nicer; it should return the things we need, as was also mentioned here: https://stackoverflow.com/a/42467096. A rough sketch is below.

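A minimal sketch of that alternative (untested, written against the names already used in this diff; the float()/int() casts are only there so json.dumps does not choke on numpy scalars):

    from sklearn.metrics import precision_recall_fscore_support

    # per-class precision / recall / f1 / support for labels 0 and 1
    precision, recall, f1, support = precision_recall_fscore_support(
        ground_answers, model_answers, labels=[0, 1])
    answer = {
        label: {'precision_score': float(precision[i]),
                'recall_score': float(recall[i]),
                'f1_score': float(f1[i]),
                'support': int(support[i])}
        for i, label in enumerate([0, 1])
    }

    # support-weighted averages, replacing the manual summing above
    w_precision, w_recall, w_f1, _ = precision_recall_fscore_support(
        ground_answers, model_answers, labels=[0, 1], average='weighted')
    answer['avg'] = {'precision_score': float(w_precision),
                     'recall_score': float(w_recall),
                     'f1_score': float(w_f1),
                     'support': int(support.sum())}
    print(json.dumps(answer))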
print(json.dumps(answer))

else:
print(classification_report(ground_answers, model_answers))


if __name__ == '__main__':
test_model()
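For reference, one way this script could be invoked once merged (the paths below are hypothetical; -f is the machine-friendly flag discussed in the review comments above):

    python scripts/test_model.py --loader PKLot -f path/to/dataset_dir path/to/model.pkl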