Prediction shape differs between tl2cgen and treelite_runtime #48

@rachaelesler

Description

The output of tl2cgen's Predictor.predict() method has a different shape from the output of treelite_runtime's Predictor.predict(), even when using equivalent code and models.

In an example case:

  • With tl2cgen 1.0.0 and treelite 4.4.1, the output prediction has shape (500, 1, 1)
  • With treelite_runtime 3.4.0 and treelite 3.4.0, the output prediction has shape (500,)

I would expect the output shape to be the same in both libraries, or for the documentation to mention that it has changed from treelite_runtime.

In testing, the predicted values (and hence the RMSE) were identical; only the shape of the prediction differed.
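To make "same values, different shape" concrete, here is a minimal sketch using only NumPy. It takes the first ten predictions printed in the Output sections of this report and compares them in both layouts. `same_values` is a hypothetical helper name, not part of either library, and the tolerance is an arbitrary choice.

```python
import numpy as np

# First ten predictions, copied from the outputs in this report.
first_ten = [0.62621477, 0.5082804, 0.37533158, 0.46197493, 0.37773395,
             0.35045096, 0.41781655, 0.40869795, 0.68264813, 0.4702126]

old_style = np.array(first_ten)                    # shape (10,), as from treelite_runtime
new_style = np.array(first_ten).reshape(-1, 1, 1)  # shape (10, 1, 1), as from tl2cgen


def same_values(a, b, rtol=1e-6):
    """Compare two prediction arrays element-wise, ignoring their shapes."""
    a = np.asarray(a).ravel()
    b = np.asarray(b).ravel()
    return a.shape == b.shape and bool(np.allclose(a, b, rtol=rtol))


print(old_style.shape, new_style.shape)
print(same_values(old_style, new_style))
```

A check of this kind passes for the two predictions in this report, which is why the RMSE is identical in both runs.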

I searched the forums and the treelite and tl2cgen issue trackers and could not find any mention of this possible bug, so I am raising an issue here.

I discovered this when performing a package migration for a project.

Steps to Reproduce

The following sections contain the conda environments, scripts, and outputs I used to test this issue.

The regression folder referenced in the scripts below is an example from the LightGBM repository.

Prediction using treelite_runtime

Conda environment:

channels:
  - conda-forge
dependencies:
  - lightgbm
  - numpy=2.2
  - pandas
  - python
  - treelite=3.4.0
  - ipykernel
  - scikit-learn

Python script:

from pathlib import Path

import pandas as pd

import lightgbm as lgb

from sklearn.metrics import mean_squared_error
import treelite_runtime
import treelite
import pathlib
import numpy as np

# load or create your dataset
regression_example_dir = Path("../regression")
df_train = pd.read_csv(
    str(regression_example_dir / "regression.train"), header=None, sep="\t"
)
df_test = pd.read_csv(
    str(regression_example_dir / "regression.test"), header=None, sep="\t"
)

y_train = df_train[0]
y_test = df_test[0]
X_train = df_train.drop(0, axis=1)
X_test = df_test.drop(0, axis=1)

# specify your configurations as a dict
params = {
    "boosting_type": "gbdt",
    "objective": "regression",
    "metric": {"l2", "l1"},
    "num_leaves": 31,
    "learning_rate": 0.05,
    "feature_fraction": 0.9,
    "bagging_fraction": 0.8,
    "bagging_freq": 5,
    "verbose": 0,
}


print(f"LightGBM version: {lgb.__version__}")
print(f"Using treelite_runtime version {treelite_runtime.__version__}")
print(f"Using treelite version {treelite.__version__}")

# create dataset for lightgbm
lgb_train = lgb.Dataset(X_train, y_train)
lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)


print("Starting training...")
# train
gbm = lgb.train(
    params,
    lgb_train,
    num_boost_round=20,
    valid_sets=lgb_eval,
    callbacks=[lgb.early_stopping(stopping_rounds=5)],
)

print("Saving lightgbm model...")
# save model to file
model_path = pathlib.Path("mid_models/model.txt")
model_path.parent.mkdir(parents=True, exist_ok=True)
gbm.save_model(str(model_path))

# Compile with treelite
print("Saving treelite model...")
model = treelite.Model.from_lightgbm(gbm)
export_model_path = pathlib.Path("./mid_models/model.so")
export_model_path.parent.mkdir(parents=True, exist_ok=True)
model.export_lib(
    toolchain="gcc",
    libpath=str(export_model_path),
    verbose=False,
    params={"parallel_comp": 8, "quantize": 1},
)


print("Starting predicting...")
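# Predict with the compiled model via treelite_runtime.
# Point at the compiled shared library itself rather than its parent directory.
treelite_predictor = treelite_runtime.Predictor(libpath=str(export_model_path))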
dmat = treelite_runtime.DMatrix(np.asarray(X_test))

prediction = treelite_predictor.predict(dmat)

print(f"The shape of prediction is: {prediction.shape}")
print(f"The type of prediction is: {type(prediction)}")

print("\nHere are the first few rows of the prediction:")
print(prediction[:10])

# evaluate
rmse_test = mean_squared_error(y_test, prediction.flatten()) ** 0.5
print(f"\nThe RMSE of prediction is: {rmse_test}")

Output:

LightGBM version: 4.6.0
Using treelite_runtime version 3.4.0
Using treelite version 3.4.0

The shape of prediction is: (500,)

Here are the first few rows of the prediction:
[0.62621477 0.5082804  0.37533158 0.46197493 0.37773395 0.35045096
 0.41781655 0.40869795 0.68264813 0.4702126 ]

The RMSE of prediction is: 0.4450426449744025

Prediction using tl2cgen

Conda environment:

channels:
  - conda-forge
dependencies:
  - lightgbm=4.6.0
  - numpy=2.2
  - pandas
  - python=3.12
  - tl2cgen>=1.0.0
  - treelite>=4.0.0
  - ipykernel
  - scikit-learn

Python script:

from pathlib import Path

import pandas as pd

import numpy as np

import lightgbm as lgb

from sklearn.metrics import mean_squared_error
import tl2cgen
import treelite
import pathlib


# load or create your dataset
regression_example_dir = Path("../regression")
df_train = pd.read_csv(
    str(regression_example_dir / "regression.train"), header=None, sep="\t"
)
df_test = pd.read_csv(
    str(regression_example_dir / "regression.test"), header=None, sep="\t"
)

y_train = df_train[0]
y_test = df_test[0]
X_train = df_train.drop(0, axis=1)
X_test = df_test.drop(0, axis=1)

# specify your configurations as a dict
params = {
    "boosting_type": "gbdt",
    "objective": "regression",
    "metric": {"l2", "l1"},
    "num_leaves": 31,
    "learning_rate": 0.05,
    "feature_fraction": 0.9,
    "bagging_fraction": 0.8,
    "bagging_freq": 5,
    "verbose": 0,
}

print(f"LightGBM version: {lgb.__version__}")
print(f"Using tl2cgen version {tl2cgen.__version__}")
print(f"Using treelite version {treelite.__version__}")

# create dataset for lightgbm
lgb_train = lgb.Dataset(X_train, y_train)
lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)


print("Starting training...")
# train
gbm = lgb.train(
    params,
    lgb_train,
    num_boost_round=20,
    valid_sets=lgb_eval,
    callbacks=[lgb.early_stopping(stopping_rounds=5)],
)

# Save model to file
print("Saving lightgbm model...")
model_path = pathlib.Path("new_models/model.txt")
model_path.parent.mkdir(parents=True, exist_ok=True)
gbm.save_model(model_path)

# Compile with treelite
print("Saving treelite model...")
model = treelite.frontend.from_lightgbm(gbm)
export_model_path = pathlib.Path("./new_models/model.so")
export_model_path.parent.mkdir(parents=True, exist_ok=True)
tl2cgen.export_lib(
    model,
    toolchain="gcc",
    libpath=str(export_model_path),
    verbose=False,
    params={"parallel_comp": 8, "quantize": 1},
)


print("Starting predicting...")
# Predict with the compiled model via tl2cgen.
# Point at the compiled shared library itself rather than its parent directory.
treelite_predictor = tl2cgen.Predictor(libpath=str(export_model_path))
dmat = tl2cgen.DMatrix(np.asarray(X_test))

prediction = treelite_predictor.predict(dmat)
print(f"The shape of prediction is: {prediction.shape}")
print(f"The type of prediction is: {type(prediction)}")

print("\n\nHere are the first few rows of the prediction:")
print(prediction[:10])

# eval
rmse_test = mean_squared_error(y_test, prediction.flatten()) ** 0.5
print(f"\nThe RMSE of prediction is: {rmse_test}")

Output:

LightGBM version: 4.6.0
Using tl2cgen version 1.0.0
Using treelite version 4.4.1

The shape of prediction is: (500, 1, 1)

Here are the first few rows of the prediction:
[[[0.62621477]]

 [[0.5082804 ]]

 [[0.37533158]]

 [[0.46197493]]

 [[0.37773395]]

 [[0.35045096]]

 [[0.41781655]]

 [[0.40869795]]

 [[0.68264813]]

 [[0.4702126 ]]]

The RMSE of prediction is: 0.4450426449744025
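Until the shapes are reconciled or documented, one possible workaround during a migration is to drop the trailing singleton axes. This is a minimal sketch using only NumPy. `flatten_single_output` is a hypothetical helper name, and it assumes a single-output model whose trailing axes all have length 1, as in the output above.

```python
import numpy as np


def flatten_single_output(prediction):
    """Collapse trailing length-1 axes so that (n, 1, 1) becomes (n,).

    Arrays that are already 1-D, or that have a trailing axis longer
    than 1 (e.g. multi-class output), are returned unchanged.
    """
    arr = np.asarray(prediction)
    if arr.ndim > 1 and all(dim == 1 for dim in arr.shape[1:]):
        return arr.reshape(arr.shape[0])
    return arr


# Stand-ins for the two prediction layouts seen in this report.
new_style = np.zeros((500, 1, 1), dtype=np.float32)
old_style = np.zeros((500,), dtype=np.float32)

print(flatten_single_output(new_style).shape)
print(flatten_single_output(old_style).shape)
```

Applied to the tl2cgen result in the script above, `flatten_single_output(prediction)` would give an array of shape (500,), matching what treelite_runtime returned.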
