-
-
Notifications
You must be signed in to change notification settings - Fork 46
[WIP] Allow pure numpy array (not dask array) as inputs #90
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -44,9 +44,14 @@ def test_pr_init(solver): | |
|
||
|
||
@pytest.mark.parametrize('fit_intercept', [True, False]) | ||
@pytest.mark.parametrize('is_sparse', [True, False]) | ||
def test_fit(fit_intercept, is_sparse): | ||
@pytest.mark.parametrize('is_sparse,is_numpy', [ | ||
(True, False), | ||
(False, False), | ||
(False, True)]) | ||
def test_fit(fit_intercept, is_sparse, is_numpy): | ||
X, y = make_classification(n_samples=100, n_features=5, chunksize=10, is_sparse=is_sparse) | ||
if is_numpy: | ||
X, y = dask.compute(X, y) | ||
lr = LogisticRegression(fit_intercept=fit_intercept) | ||
lr.fit(X, y) | ||
lr.predict(X) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think I understand this test. When is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's exactly what I tried to do, where both There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think that's a feature we need to support explicitly, I believe anybody using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you! I'll prioritize #89 then. |
||
|
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -23,7 +23,7 @@ def normalize_inputs(X, y, *args, **kwargs): | |||||||||||||
raise ValueError('Multiple constant columns detected!') | ||||||||||||||
mean[intercept_idx] = 0 | ||||||||||||||
std[intercept_idx] = 1 | ||||||||||||||
mean = mean if len(intercept_idx[0]) else np.zeros_like(X._meta, shape=mean.shape) | ||||||||||||||
mean = mean if len(intercept_idx[0]) else safe_zeros_like(X, shape=mean.shape) | ||||||||||||||
Xn = (X - mean) / std | ||||||||||||||
out = algo(Xn, y, *args, **kwargs).copy() | ||||||||||||||
i_adj = np.sum(out * mean / std) | ||||||||||||||
|
@@ -41,7 +41,7 @@ def sigmoid(x): | |||||||||||||
|
||||||||||||||
@dispatch(object) | ||||||||||||||
def exp(A): | ||||||||||||||
return A.exp() | ||||||||||||||
return np.exp(A) | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
@dispatch(float) | ||||||||||||||
|
@@ -91,7 +91,7 @@ def sign(A): | |||||||||||||
|
||||||||||||||
@dispatch(object) | ||||||||||||||
def log1p(A): | ||||||||||||||
return A.log1p() | ||||||||||||||
return np.log1p(A) | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
@dispatch(np.ndarray) | ||||||||||||||
|
@@ -121,7 +121,7 @@ def is_dask_array_sparse(X): | |||||||||||||
""" | ||||||||||||||
Check using _meta if a dask array contains sparse arrays | ||||||||||||||
""" | ||||||||||||||
return isinstance(X._meta, sparse.SparseArray) | ||||||||||||||
return isinstance(X, da.Array) and isinstance(X._meta, sparse.SparseArray) | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
@dispatch(np.ndarray) | ||||||||||||||
|
@@ -149,6 +149,11 @@ def add_intercept(X): | |||||||||||||
return X_i | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
@dispatch(object) | ||||||||||||||
def add_intercept(X): | ||||||||||||||
return np.concatenate([X, np.ones_like(X, shape=(X.shape[0], 1))], axis=1) | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also needs |
||||||||||||||
|
||||||||||||||
|
||||||||||||||
def make_y(X, beta=np.array([1.5, -3]), chunks=2): | ||||||||||||||
n, p = X.shape | ||||||||||||||
z0 = X.dot(beta) | ||||||||||||||
|
@@ -205,3 +210,9 @@ def get_distributed_client(): | |||||||||||||
return get_client() | ||||||||||||||
except ValueError: | ||||||||||||||
return None | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
def safe_zeros_like(X, shape): | ||||||||||||||
if isinstance(X, da.Array): | ||||||||||||||
return np.zeros_like(X._meta, shape=shape) | ||||||||||||||
return np.zeros_like(X, shape=shape) | ||||||||||||||
Comment on lines
+216
to
+218
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You'll also need to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry for the late reply, I think I might misunderstand our other conversion. #89 (comment) This PR intends to enable dask-glm/dask_glm/algorithms.py Lines 100 to 101 in 7b2f85f
Let's say the input Let me know if this clears things up. Thank you! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
That's exactly what
That isn't necessarily true, it will only return a Dask array if the reference array is a Dask array. Because we're getting the underlying chunk type with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Aha, that works! I will make the changes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Where is
safe_zeros_like
coming from? I suppose you wanted tofrom dask.array.utils import zeros_like_safe
instead, from https://github.com/dask/dask/blob/48a4d4a5c5769f6b78881adeb1b3973a950e5f43/dask/array/utils.py#L350