I was testing the following Go snippet, in which I train and test a KNN model, save it, reload it, and run inference again on the same test data. The results are not the same: the confusion matrices of the original model and the reloaded model differ.
To test the code, you can download "iris_headers.csv" with:
wget https://raw.githubusercontent.com/sjwhitworth/golearn/master/examples/datasets/iris_headers.csv
A saved and reloaded model should give exactly the same results as the original. How can I fix this behavior?
package main

import (
	"fmt"

	"github.com/sjwhitworth/golearn/base" // did go get github.com/sjwhitworth/golearn/[email protected]
	"github.com/sjwhitworth/golearn/evaluation"
	"github.com/sjwhitworth/golearn/knn"
)

func main() {
	save_path := "saved_knn_9"

	fmt.Println("Load our csv data")
	rawData, err := base.ParseCSVToInstances("iris_headers.csv", true)
	if err != nil {
		panic(err)
	}

	fmt.Println("Initialize a new KNN classifier")
	cls := knn.NewKnnClassifier("euclidean", "linear", 2)
	cls.AllowOptimisations = false

	// Random 60/40 split; the same testData is used for both models below.
	trainData, testData := base.InstancesTrainTestSplit(rawData, 0.60)
	cls.Fit(trainData)

	predictions, err := cls.Predict(testData)
	if err != nil {
		panic(err)
	}

	fmt.Println("Print our summary metrics")
	confusionMat, err := evaluation.GetConfusionMatrix(testData, predictions)
	if err != nil {
		panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
	}
	fmt.Println(evaluation.GetSummary(confusionMat))

	// Save the model to a file.
	cls.Save(save_path)

	// Try to load the model again.
	another_cls, err := knn.ReloadKNNClassifier(save_path)
	// Note: this assigns to cls (the original classifier), not to another_cls.
	cls.AllowOptimisations = false
	if err != nil {
		panic(err)
	}

	fmt.Println("##### USING LOADED MODEL")
	predictions2, err := another_cls.Predict(testData)
	if err != nil {
		panic(err)
	}

	fmt.Println("Print our summary metrics of loaded model")
	confusionMat2, err := evaluation.GetConfusionMatrix(testData, predictions2)
	if err != nil {
		panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
	}
	fmt.Println(evaluation.GetSummary(confusionMat2))
}
This is the terminal output I got for one of my runs:
(base) asdads@ASDF-02Y611MJHC8-2 serve_first_mlmodel % go run main.go
Load our csv data
Initialize a new KNN classifier
Print our summary metrics
Reference Class  True Positives  False Positives  True Negatives  Precision  Recall  F1 Score
---------------  --------------  ---------------  --------------  ---------  ------  --------
Iris-virginica   31              1                66              0.9688     0.9688  0.9688
Iris-setosa      35              0                64              1.0000     1.0000  1.0000
Iris-versicolor  31              1                66              0.9688     0.9688  0.9688
Overall accuracy: 0.9798
writer: &{0xc00012c630 0xc0000145f8 0xc00019c380 <nil> }##### USING LOADED MODEL
Print our summary metrics of loaded model
Reference Class  True Positives  False Positives  True Negatives  Precision  Recall  F1 Score
---------------  --------------  ---------------  --------------  ---------  ------  --------
Iris-setosa      35              0                64              1.0000     1.0000  1.0000
Iris-versicolor  29              1                66              0.9667     0.9062  0.9355
Iris-virginica   31              3                64              0.9118     0.9688  0.9394
Overall accuracy: 0.9596
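To narrow down which test rows change between the two models, it may help to compare the predicted labels row by row instead of only comparing the summaries. The following is a minimal sketch using only the standard library. The helper name diffLabels and the sample labels are hypothetical; in the real program the two label slices would have to be filled from predictions and predictions2 (golearn's base.GetClass(grid, row) looks like the accessor for that, but treat that as an assumption to verify).

```go
package main

import "fmt"

// diffLabels returns the indices at which two label slices disagree.
// It returns an error if the slices have different lengths.
// (Hypothetical helper, not part of golearn.)
func diffLabels(a, b []string) ([]int, error) {
	if len(a) != len(b) {
		return nil, fmt.Errorf("length mismatch: %d vs %d", len(a), len(b))
	}
	var diff []int
	for i := range a {
		if a[i] != b[i] {
			diff = append(diff, i)
		}
	}
	return diff, nil
}

func main() {
	// Made-up labels standing in for the original and reloaded predictions.
	before := []string{"Iris-setosa", "Iris-versicolor", "Iris-virginica"}
	after := []string{"Iris-setosa", "Iris-virginica", "Iris-virginica"}

	diff, err := diffLabels(before, after)
	if err != nil {
		panic(err)
	}
	for _, i := range diff {
		fmt.Printf("row %d: %s -> %s\n", i, before[i], after[i])
	}
}
```

Knowing the differing rows makes it possible to inspect their nearest neighbours in both classifiers, which may show whether the reloaded model sees the same training data and settings.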