|
| 1 | +#!/usr/bin/env python |
| 2 | +# -*- coding: UTF-8 -*- |
| 3 | +import pytest |
| 4 | + |
def test_train():
    """Train a new DeezyMatch model named ``test001``.

    The test makes no assertions: it passes as long as the training call
    returns without raising.
    """
    from DeezyMatch import train as run_training

    # Collect the keyword arguments first so the call itself stays short.
    training_args = {
        "input_file_path": "./inputs/input_dfm_pytest_001.yaml",
        "dataset_path": "./dataset/dataset-string-similarity_test.txt",
        "model_name": "test001",
    }
    run_training(**training_args)
| 11 | + |
def test_finetune():
    """Fine-tune a pretrained model and store it as ``finetuned_test001``.

    The pretrained model and vocabulary are read from ``./models/test001/``
    (presumably written by ``test_train`` -- confirm the tests run in file
    order). No assertions are made; the test passes if the call does not
    raise.
    """
    from DeezyMatch import finetune as run_finetune

    pretrained_model = "./models/test001/test001.model"
    pretrained_vocab = "./models/test001/test001.vocab"

    run_finetune(
        input_file_path="./inputs/input_dfm_pytest_001.yaml",
        dataset_path="./dataset/dataset-string-similarity_test.txt",
        model_name="finetuned_test001",
        pretrained_model_path=pretrained_model,
        pretrained_vocab_path=pretrained_vocab,
    )
| 20 | + |
def test_inference():
    """Run model inference with the ``finetuned_test001`` model and vocab.

    No assertions are made; the test passes if the inference call does not
    raise.
    """
    from DeezyMatch import inference as run_inference

    model_file = "./models/finetuned_test001/finetuned_test001.model"
    vocab_file = "./models/finetuned_test001/finetuned_test001.vocab"

    run_inference(
        input_file_path="./inputs/input_dfm_pytest_001.yaml",
        dataset_path="./dataset/dataset-string-similarity_test.txt",
        pretrained_model_path=model_file,
        pretrained_vocab_path=vocab_file,
    )
| 29 | + |
def test_generate_query_vecs():
    """Generate vectors for the queries listed in the test dataset.

    Inference is run in ``vect`` mode under the ``queries/test`` scenario,
    using the ``finetuned_test001`` model and vocabulary. No assertions are
    made; the test passes if the call does not raise.
    """
    from DeezyMatch import inference as run_inference

    vectorise_args = {
        "input_file_path": "./inputs/input_dfm_pytest_001.yaml",
        "dataset_path": "./dataset/dataset-string-similarity_test.txt",
        "pretrained_model_path": "./models/finetuned_test001/finetuned_test001.model",
        "pretrained_vocab_path": "./models/finetuned_test001/finetuned_test001.vocab",
        "inference_mode": "vect",
        "scenario": "queries/test",
    }
    run_inference(**vectorise_args)
| 41 | + |
def test_generate_candidate_vecs():
    """Generate vectors for the candidates listed in the test dataset.

    Inference is run in ``vect`` mode under the ``candidates/test``
    scenario, using the ``finetuned_test001`` model and vocabulary. No
    assertions are made; the test passes if the call does not raise.
    """
    from DeezyMatch import inference as run_inference

    vectorise_args = {
        "input_file_path": "./inputs/input_dfm_pytest_001.yaml",
        "dataset_path": "./dataset/dataset-string-similarity_test.txt",
        "pretrained_model_path": "./models/finetuned_test001/finetuned_test001.model",
        "pretrained_vocab_path": "./models/finetuned_test001/finetuned_test001.vocab",
        "inference_mode": "vect",
        "scenario": "candidates/test",
    }
    run_inference(**vectorise_args)
| 53 | + |
def test_assemble_queries():
    """Combine the vectors in ``queries/test`` into ``combined/queries_test``.

    No assertions are made; the test passes if the call does not raise.
    """
    from DeezyMatch import combine_vecs as assemble

    passes = ['fwd', 'bwd']
    assemble(
        rnn_passes=passes,
        input_scenario='queries/test',
        output_scenario='combined/queries_test',
        print_every=10,
    )
| 62 | + |
def test_assemble_candidates():
    """Combine the vectors in ``candidates/test`` into ``combined/candidates_test``.

    No assertions are made; the test passes if the call does not raise.
    """
    from DeezyMatch import combine_vecs as assemble

    passes = ['fwd', 'bwd']
    assemble(
        rnn_passes=passes,
        input_scenario='candidates/test',
        output_scenario='combined/candidates_test',
        print_every=10,
    )
| 71 | + |
def test_candidate_ranker():
    """Rank candidates for the test queries using the ``faiss`` metric.

    Candidates are looked up in ``./combined/candidates_test`` for the
    queries in ``./combined/queries_test``. The ranker's return value is
    bound to a local but not inspected, so the test passes as long as the
    call does not raise.
    """
    from DeezyMatch import candidate_ranker as rank_candidates

    # Select candidates based on L2-norm distance (aka faiss distance).
    ranker_args = {
        "query_scenario": "./combined/queries_test",
        "candidate_scenario": "./combined/candidates_test",
        "ranking_metric": "faiss",
        "selection_threshold": 5.0,
        "num_candidates": 2,
        "search_size": 10,
        "output_path": "ranker_results/test_candidates_deezymatch",
        "pretrained_model_path": "./models/finetuned_test001/finetuned_test001.model",
        "pretrained_vocab_path": "./models/finetuned_test001/finetuned_test001.vocab",
        "number_test_rows": 5,
    }
    candidates_pd = rank_candidates(**ranker_args)
0 commit comments