Skip to content

Commit 1b1ea74

Browse files
committed
🐛 Fix xgboost error
1 parent 4fe8224 commit 1b1ea74

File tree

5 files changed

+9
-9
lines changed

5 files changed

+9
-9
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ ml = [
9393
"psycopg2-binary>=2.9.5",
9494
"gunicorn>=20.1.0",
9595
"lightgbm>=3.3.3",
96-
"xgboost>=1.7.1",
96+
"xgboost<=1.7.1",
9797
"catboost",
9898
"shap>=0.41.0",
9999
"lime>=0.2.0.1",

tests/classification/inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def inference(self):
3838

3939
for idx, batch in enumerate(tqdm(self.dataloader)):
4040
img_names = batch["img_names"]
41-
outputs = self.model.get_prediction(batch)
41+
outputs = self.model.predict_step(batch)
4242
preds = outputs["names"]
4343
probs = outputs["confidences"]
4444

tests/semantic/inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def inference(self):
4949
img_names = batch["img_names"]
5050
ori_sizes = batch["ori_sizes"]
5151

52-
outputs = self.model.get_prediction(batch)
52+
outputs = self.model.predict_step(batch)
5353
preds = outputs["masks"]
5454

5555
for (inpt, pred, filename, ori_size) in zip(

tests/tabular/test_tablr.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@ def test_train_tblr(override_config):
1010
train_pipeline.fit()
1111

1212

13-
# @pytest.mark.order(2)
14-
# def test_eval_tblr(override_config):
15-
# override_config["global"]["pretrained"] = "runs/pytest_tablr/checkpoints/last"
16-
# val_pipeline = MLPipeline(override_config)
17-
# val_pipeline.evaluate()
13+
@pytest.mark.order(2)
14+
def test_eval_tblr(override_config):
15+
override_config["global"]["pretrained"] = "runs/pytest_tablr/checkpoints/last"
16+
val_pipeline = MLPipeline(override_config)
17+
val_pipeline.evaluate()
1818

1919

2020
# @pytest.mark.order(2)

theseus/base/models/wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def validation_step(self, batch, batch_idx):
9494
self.log_dict(outputs["loss_dict"], prog_bar=True, on_step=True, on_epoch=False)
9595
return outputs
9696

97-
def predict_step(self, batch, batch_idx):
97+
def predict_step(self, batch, batch_idx=None):
9898
pred = self.model.get_prediction(batch)
9999
return pred
100100

0 commit comments

Comments
 (0)