KNN Classifier saved and loaded models don't give the same results. #285

Open
@juanma9613

Description

I was testing the following Go snippet, which trains and tests a KNN model, then saves it, reloads it, and runs inference again on the same test data. The results are not the same: the confusion matrices of the original model and the loaded model differ.

To test the code, you can download "iris_headers.csv" with:
wget https://raw.githubusercontent.com/sjwhitworth/golearn/master/examples/datasets/iris_headers.csv

A saved and loaded model should give exactly the same results. How can I fix this behavior?

package main

import (
	"fmt"

	"github.com/sjwhitworth/golearn/base" // did go get github.com/sjwhitworth/golearn/[email protected]
	"github.com/sjwhitworth/golearn/evaluation"
	"github.com/sjwhitworth/golearn/knn"
)

func main() {
	save_path := "saved_knn_9"

	fmt.Println("Load our csv data")
	rawData, err := base.ParseCSVToInstances("iris_headers.csv", true)
	if err != nil {
		panic(err)
	}

	fmt.Println("Initialize a new KNN classifier")

	cls := knn.NewKnnClassifier("euclidean", "linear", 2)
	cls.AllowOptimisations = false

	trainData, testData := base.InstancesTrainTestSplit(rawData, 0.60)
	cls.Fit(trainData)

	predictions, err := cls.Predict(testData)
	if err != nil {
		panic(err)
	}

	fmt.Println("Print our summary metrics")
	confusionMat, err := evaluation.GetConfusionMatrix(testData, predictions)
	if err != nil {
		panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
	}
	fmt.Println(evaluation.GetSummary(confusionMat))

	// saving the model to a file
	cls.Save(save_path)
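	// Load the model back from disk.
	another_cls, err := knn.ReloadKNNClassifier(save_path)
	if err != nil {
		panic(err)
	}
	// Apply the same setting to the loaded classifier
	// (this line originally set it on cls a second time).
	another_cls.AllowOptimisations = false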

	fmt.Println("##### USING LOADED MODEL")

	predictions2, err := another_cls.Predict(testData)
	if err != nil {
		panic(err)
	}

	fmt.Println("Print our summary metrics of loaded model")
	confusionMat2, err := evaluation.GetConfusionMatrix(testData, predictions2)
	if err != nil {
		panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
	}
	fmt.Println(evaluation.GetSummary(confusionMat2))

}

This is the terminal output I got for one of my runs:

(base) asdads@ASDF-02Y611MJHC8-2 serve_first_mlmodel % go run main.go
Load our csv data
Initialize a new KNN classifier
Print our summary metrics
Reference Class True Positives  False Positives True Negatives  Precision       Recall  F1 Score
--------------- --------------  --------------- --------------  ---------       ------  --------
Iris-virginica  31              1               66              0.9688          0.9688  0.9688
Iris-setosa     35              0               64              1.0000          1.0000  1.0000
Iris-versicolor 31              1               66              0.9688          0.9688  0.9688
Overall accuracy: 0.9798

writer: &{0xc00012c630 0xc0000145f8 0xc00019c380 <nil> }##### USING LOADED MODEL
Print our summary metrics of loaded model
Reference Class True Positives  False Positives True Negatives  Precision       Recall  F1 Score
--------------- --------------  --------------- --------------  ---------       ------  --------
Iris-setosa     35              0               64              1.0000          1.0000  1.0000
Iris-versicolor 29              1               66              0.9667          0.9062  0.9355
Iris-virginica  31              3               64              0.9118          0.9688  0.9394
Overall accuracy: 0.9596
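
The two summaries only show aggregate counts (97 of 99 correct before saving, 95 of 99 after loading), so they do not reveal which test rows changed. A minimal sketch of a row-by-row comparison is given below, assuming the predicted class labels of both models have already been extracted into plain string slices. `DiffPredictions` and `Mismatch` are hypothetical helpers written for this illustration and are not part of golearn; the labels in `main` are made up.

```go
package main

import "fmt"

// Mismatch records one row on which the two models disagree.
// Hypothetical type for illustration; not part of golearn.
type Mismatch struct {
	Row    int
	First  string // label predicted by the original model
	Second string // label predicted by the loaded model
}

// DiffPredictions returns every index at which the two label slices differ.
// It returns an error if the slices do not have the same length.
func DiffPredictions(first, second []string) ([]Mismatch, error) {
	if len(first) != len(second) {
		return nil, fmt.Errorf("length mismatch: %d vs %d", len(first), len(second))
	}
	var out []Mismatch
	for i := range first {
		if first[i] != second[i] {
			out = append(out, Mismatch{Row: i, First: first[i], Second: second[i]})
		}
	}
	return out, nil
}

func main() {
	// Made-up labels, standing in for the predictions of the two models.
	original := []string{"Iris-setosa", "Iris-versicolor", "Iris-virginica"}
	loaded := []string{"Iris-setosa", "Iris-virginica", "Iris-virginica"}

	diffs, err := DiffPredictions(original, loaded)
	if err != nil {
		panic(err)
	}
	for _, d := range diffs {
		fmt.Printf("row %d: original=%s loaded=%s\n", d.Row, d.First, d.Second)
	}
}
```

Listing the disagreeing rows makes it possible to inspect those test instances and their nearest neighbours directly, instead of inferring the difference from the confusion matrices.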
