Skip to content

Commit 7b2f85f

Browse files
authored
Allow cupy inputs in algorithms "newton", "gradient_descent" and "proximal_grad" (#87)
1 parent 69a947e commit 7b2f85f

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

Diff for: dask_glm/algorithms.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def gradient_descent(X, y, max_iter=100, tol=1e-14, family=Logistic, **kwargs):
9797
stepSize = 1.0
9898
recalcRate = 10
9999
backtrackMult = firstBacktrackMult
100-
beta = np.zeros(p)
100+
beta = np.zeros_like(X._meta, shape=p)
101101

102102
for k in range(max_iter):
103103
# how necessary is this recalculation?
@@ -161,7 +161,7 @@ def newton(X, y, max_iter=50, tol=1e-8, family=Logistic, **kwargs):
161161
"""
162162
gradient, hessian = family.gradient, family.hessian
163163
n, p = X.shape
164-
beta = np.zeros(p) # always init to zeros?
164+
beta = np.zeros_like(X._meta, shape=p)
165165
Xbeta = dot(X, beta)
166166

167167
iter_count = 0
@@ -387,7 +387,7 @@ def proximal_grad(X, y, regularizer='l1', lamduh=0.1, family=Logistic,
387387
stepSize = 1.0
388388
recalcRate = 10
389389
backtrackMult = firstBacktrackMult
390-
beta = np.zeros(p)
390+
beta = np.zeros_like(X._meta, shape=p)
391391
regularizer = Regularizer.get(regularizer)
392392

393393
for k in range(max_iter):
@@ -406,8 +406,8 @@ def proximal_grad(X, y, regularizer='l1', lamduh=0.1, family=Logistic,
406406
# Compute the step size
407407
lf = func
408408
for ii in range(100):
409-
beta = regularizer.proximal_operator(obeta - stepSize * gradient, stepSize * lamduh)
410-
step = obeta - beta
409+
beta = regularizer.proximal_operator(- stepSize * gradient + obeta, stepSize * lamduh)
410+
step = - beta + obeta
411411
Xbeta = X.dot(beta)
412412

413413
Xbeta, beta = persist(Xbeta, beta)

Diff for: dask_glm/utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def normalize_inputs(X, y, *args, **kwargs):
2323
raise ValueError('Multiple constant columns detected!')
2424
mean[intercept_idx] = 0
2525
std[intercept_idx] = 1
26-
mean = mean if len(intercept_idx[0]) else np.zeros(mean.shape)
26+
mean = mean if len(intercept_idx[0]) else np.zeros_like(X._meta, shape=mean.shape)
2727
Xn = (X - mean) / std
2828
out = algo(Xn, y, *args, **kwargs).copy()
2929
i_adj = np.sum(out * mean / std)
@@ -140,7 +140,7 @@ def add_intercept(X):
140140
raise NotImplementedError("Can not add intercept to array with "
141141
"unknown chunk shape")
142142
j, k = X.chunks
143-
o = da.ones((X.shape[0], 1), chunks=(j, 1))
143+
o = da.ones_like(X, shape=(X.shape[0], 1), chunks=(j, 1))
144144
if is_dask_array_sparse(X):
145145
o = o.map_blocks(sparse.COO)
146146
# TODO: Needed this `.rechunk` for the solver to work

0 commit comments

Comments
 (0)