
KNN Classifier saved and loaded models don't give the same results. #285

Open
@juanma9613


I was testing the following Go snippet, in which I train and test a KNN model, save it, reload it, and run inference again on the same test data. The results are not the same: the confusion matrices of the original and the reloaded model differ.

To test the code, you can download "iris_headers.csv" with:
wget https://raw.githubusercontent.com/sjwhitworth/golearn/master/examples/datasets/iris_headers.csv

A saved and reloaded model should give exactly the same results. How can I fix this behavior?

package main

import (
	"fmt"

	"github.com/sjwhitworth/golearn/base" // installed with: go get github.com/sjwhitworth/golearn/…
	"github.com/sjwhitworth/golearn/evaluation"
	"github.com/sjwhitworth/golearn/knn"
)

func main() {
	save_path := "saved_knn_9"

	fmt.Println("Load our csv data")
	rawData, err := base.ParseCSVToInstances("iris_headers.csv", true)
	if err != nil {
		panic(err)
	}

	fmt.Println("Initialize a new KNN classifier")

	cls := knn.NewKnnClassifier("euclidean", "linear", 2)
	cls.AllowOptimisations = false

	trainData, testData := base.InstancesTrainTestSplit(rawData, 0.60)
	cls.Fit(trainData)

	predictions, err := cls.Predict(testData)
	if err != nil {
		panic(err)
	}

	fmt.Println("Print our summary metrics")
	confusionMat, err := evaluation.GetConfusionMatrix(testData, predictions)
	if err != nil {
		panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
	}
	fmt.Println(evaluation.GetSummary(confusionMat))

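	// Save the model to a file, checking the returned error.
	if err := cls.Save(save_path); err != nil {
		panic(err)
	}
	// Reload the saved model.
	another_cls, err := knn.ReloadKNNClassifier(save_path)
	if err != nil {
		panic(err)
	}
	// Match the original model's setting: disable optimisations on the reloaded classifier.
	another_cls.AllowOptimisations = false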

	fmt.Println("##### USING LOADED MODEL")

	predictions2, err := another_cls.Predict(testData)
	if err != nil {
		panic(err)
	}

	fmt.Println("Print our summary metrics of loaded model")
	confusionMat2, err := evaluation.GetConfusionMatrix(testData, predictions2)
	if err != nil {
		panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
	}
	fmt.Println(evaluation.GetSummary(confusionMat2))

}
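To see exactly which test rows the two models disagree on, rather than only comparing aggregate confusion matrices, the predicted labels could be compared row by row. The sketch below is stdlib-only and assumes the labels have already been extracted into string slices (for example by looping over the rows of `predictions` and `predictions2` with golearn's `base.GetClass`). `diffPredictions` is a hypothetical helper, not part of golearn, and the labels in `main` are made up for illustration.

```go
package main

import "fmt"

// diffPredictions is a hypothetical helper (not part of golearn) that compares
// two equally long slices of predicted class labels and returns the row
// indices at which they disagree.
func diffPredictions(a, b []string) ([]int, error) {
	if len(a) != len(b) {
		return nil, fmt.Errorf("length mismatch: %d vs %d", len(a), len(b))
	}
	var diffs []int
	for i := range a {
		if a[i] != b[i] {
			diffs = append(diffs, i)
		}
	}
	return diffs, nil
}

func main() {
	// Illustrative labels only; in the program above they would come from
	// predictions and predictions2.
	original := []string{"Iris-setosa", "Iris-versicolor", "Iris-virginica", "Iris-versicolor"}
	reloaded := []string{"Iris-setosa", "Iris-virginica", "Iris-virginica", "Iris-versicolor"}

	diffs, err := diffPredictions(original, reloaded)
	if err != nil {
		panic(err)
	}
	fmt.Println("rows that differ:", diffs) // prints "rows that differ: [1]"
}
```

Running the same comparison between two `cls.Predict(testData)` calls on the *original* model (no save/reload involved) would show whether the mismatch is caused by serialization at all.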

This is the terminal output I got for one of my runs:

(base) asdads@ASDF-02Y611MJHC8-2 serve_first_mlmodel % go run main.go
Load our csv data
Initialize a new KNN classifier
Print our summary metrics
Reference Class True Positives  False Positives True Negatives  Precision       Recall  F1 Score
--------------- --------------  --------------- --------------  ---------       ------  --------
Iris-virginica  31              1               66              0.9688          0.9688  0.9688
Iris-setosa     35              0               64              1.0000          1.0000  1.0000
Iris-versicolor 31              1               66              0.9688          0.9688  0.9688
Overall accuracy: 0.9798

writer: &{0xc00012c630 0xc0000145f8 0xc00019c380 <nil> }##### USING LOADED MODEL
Print our summary metrics of loaded model
Reference Class True Positives  False Positives True Negatives  Precision       Recall  F1 Score
--------------- --------------  --------------- --------------  ---------       ------  --------
Iris-setosa     35              0               64              1.0000          1.0000  1.0000
Iris-versicolor 29              1               66              0.9667          0.9062  0.9355
Iris-virginica  31              3               64              0.9118          0.9688  0.9394
Overall accuracy: 0.9596
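One hypothesis worth checking, not confirmed against golearn's source: with `k = 2` (the third argument to `knn.NewKnnClassifier`), two neighbours of different classes produce a 1-1 vote. If such ties are resolved by iterating over a Go map, whose iteration order is deliberately randomised, the same model can return different labels on different calls, independently of saving and reloading. The differing class order in the two summaries above (Iris-virginica listed first, then Iris-setosa) at least suggests map ordering is involved somewhere. The stdlib-only sketch below contrasts a map-order tie-break with a deterministic one; `naiveVote` and `majorityVote` are hypothetical functions, not golearn's implementation.

```go
package main

import (
	"fmt"
	"sort"
)

// naiveVote picks the most frequent label, but on ties the winner depends on
// Go's randomised map iteration order.
func naiveVote(neighbourLabels []string) string {
	counts := make(map[string]int)
	for _, l := range neighbourLabels {
		counts[l]++
	}
	best, bestCount := "", -1
	for l, c := range counts { // iteration order is randomised by the runtime
		if c > bestCount {
			best, bestCount = l, c
		}
	}
	return best
}

// majorityVote picks the most frequent label and breaks ties deterministically
// by choosing the lexicographically smallest tied label.
func majorityVote(neighbourLabels []string) string {
	counts := make(map[string]int)
	for _, l := range neighbourLabels {
		counts[l]++
	}
	// Sort the distinct labels so the iteration order is fixed.
	labels := make([]string, 0, len(counts))
	for l := range counts {
		labels = append(labels, l)
	}
	sort.Strings(labels)

	best, bestCount := "", -1
	for _, l := range labels {
		if counts[l] > bestCount { // strict '>' keeps the first (smallest) label on ties
			best, bestCount = l, counts[l]
		}
	}
	return best
}

func main() {
	// With k = 2, two neighbours of different classes produce a 1-1 tie.
	tie := []string{"Iris-virginica", "Iris-versicolor"}

	seen := make(map[string]bool)
	for i := 0; i < 100; i++ {
		seen[naiveVote(tie)] = true
	}
	fmt.Println("distinct naive answers over 100 calls:", len(seen))
	fmt.Println("deterministic answer:", majorityVote(tie)) // always Iris-versicolor
}
```

If this is the cause, choosing an odd `k` (for example 3) avoids two-way ties between two classes, though with three classes an odd `k` can still tie, so a deterministic tie-break is the more robust fix.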
