From 031e0a90c5dba8e3a9e4ef9691fad0c2cedcd754 Mon Sep 17 00:00:00 2001 From: Mainak Jas Date: Mon, 27 Aug 2018 16:49:18 -0700 Subject: [PATCH] FIX: check label dtypes --- pyglmnet/pyglmnet.py | 4 +++- pyglmnet/utils.py | 7 +++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/pyglmnet/pyglmnet.py b/pyglmnet/pyglmnet.py index 4eac175a..c058f1c1 100644 --- a/pyglmnet/pyglmnet.py +++ b/pyglmnet/pyglmnet.py @@ -5,7 +5,7 @@ import numpy as np from scipy.special import expit from scipy.stats import norm -from .utils import logger, set_log_level +from .utils import logger, set_log_level, _check_type from .base import BaseEstimator, is_classifier, check_version @@ -603,6 +603,8 @@ def fit(self, X, y): self : instance of GLM The fitted model. """ + _check_type(self.distr, ['boolean', 'probit'], np.bool, y) + _check_type(self.distr, ['poisson'], np.int, y) np.random.RandomState(self.random_state) # checks for group diff --git a/pyglmnet/utils.py b/pyglmnet/utils.py index 14e618f0..cec23b95 100644 --- a/pyglmnet/utils.py +++ b/pyglmnet/utils.py @@ -90,6 +90,13 @@ def tikhonov_from_prior(prior_cov, n_samples, threshold=0.0001): return Tau +def _check_type(this_distr, distr, allowed_type, y): + if this_distr in distr: + if not isinstance(y, allowed_type): + raise ValueError('Expected y to be of type %s. Got %s' + % (allowed_type, type(y))) + + def set_log_level(verbose): """Convenience function for setting the log level.