Skip to content

Commit c40b7cc

Browse files
AronNemeth and s3alfisc authored
Switch default solver to scipy.linalg.solve() [#846 issue] (#904)
* Solvers: switching to scipy.linalg.solve as default
* update test/solvers
* Update tests_solvers: new test matrix is symmetric and positive definite -> Cholesky works

---------

Co-authored-by: Alexander Fischer <[email protected]>
1 parent cfee1f2 commit c40b7cc

13 files changed

+110
-43
lines changed

pyfixest/estimation/FixestMulti_.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,11 @@ def _estimate_all_models(
201201
self,
202202
vcov: Union[str, dict[str, str], None],
203203
solver: Literal[
204-
"np.linalg.lstsq", "np.linalg.solve", "scipy.sparse.linalg.lsqr", "jax"
204+
"np.linalg.lstsq",
205+
"np.linalg.solve",
206+
"scipy.linalg.solve",
207+
"scipy.sparse.linalg.lsqr",
208+
"jax",
205209
],
206210
demeaner_backend: Literal["numba", "jax"] = "numba",
207211
collin_tol: float = 1e-6,

pyfixest/estimation/estimation.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def feols(
3434
store_data: bool = True,
3535
lean: bool = False,
3636
weights_type: WeightsTypeOptions = "aweights",
37-
solver: SolverOptions = "np.linalg.solve",
37+
solver: SolverOptions = "scipy.linalg.solve",
3838
demeaner_backend: DemeanerBackendOptions = "numba",
3939
use_compression: bool = False,
4040
reps: int = 100,
@@ -118,8 +118,9 @@ def feols(
118118
see this blog post: https://notstatschat.rbind.io/2020/08/04/weights-in-statistics/.
119119
120120
solver : SolverOptions, optional.
121-
The solver to use for the regression. Can be either "np.linalg.solve" or
122-
"np.linalg.lstsq". Defaults to "np.linalg.solve".
121+
The solver to use for the regression. Can be "np.linalg.lstsq",
122+
"np.linalg.solve", "scipy.linalg.solve", "scipy.sparse.linalg.lsqr" and "jax".
123+
Defaults to "scipy.linalg.solve".
123124
124125
demeaner_backend: DemeanerBackendOptions, optional
125126
The backend to use for demeaning. Can be either "numba" or "jax". Defaults to "numba".
@@ -510,7 +511,7 @@ def fepois(
510511
iwls_maxiter: int = 25,
511512
collin_tol: float = 1e-10,
512513
separation_check: Optional[list[str]] = None,
513-
solver: SolverOptions = "np.linalg.solve",
514+
solver: SolverOptions = "scipy.linalg.solve",
514515
demeaner_backend: DemeanerBackendOptions = "numba",
515516
drop_intercept: bool = False,
516517
i_ref1=None,
@@ -569,8 +570,9 @@ def fepois(
569570
Either "fe" or "ir". Executes "fe" by default (when None).
570571
571572
solver : SolverOptions, optional.
572-
The solver to use for the regression. Can be either "np.linalg.solve" or
573-
"np.linalg.lstsq". Defaults to "np.linalg.solve".
573+
The solver to use for the regression. Can be "np.linalg.lstsq",
574+
"np.linalg.solve", "scipy.linalg.solve", "scipy.sparse.linalg.lsqr" and "jax".
575+
Defaults to "scipy.linalg.solve".
574576
575577
demeaner_backend: DemeanerBackendOptions, optional
576578
The backend to use for demeaning. Can be either "numba" or "jax".
@@ -737,7 +739,7 @@ def feglm(
737739
iwls_maxiter: int = 25,
738740
collin_tol: float = 1e-10,
739741
separation_check: Optional[list[str]] = None,
740-
solver: SolverOptions = "np.linalg.solve",
742+
solver: SolverOptions = "scipy.linalg.solve",
741743
drop_intercept: bool = False,
742744
i_ref1=None,
743745
copy_data: bool = True,
@@ -799,8 +801,9 @@ def feglm(
799801
Either "fe" or "ir". Executes "fe" by default (when None).
800802
801803
solver : SolverOptions, optional.
802-
The solver to use for the regression. Can be either "np.linalg.solve" or
803-
"np.linalg.lstsq". Defaults to "np.linalg.solve".
804+
The solver to use for the regression. Can be "np.linalg.lstsq",
805+
"np.linalg.solve", "scipy.linalg.solve", "scipy.sparse.linalg.lsqr" and "jax".
806+
Defaults to "scipy.linalg.solve".
804807
805808
drop_intercept : bool, optional
806809
Whether to drop the intercept from the model, by default False.

pyfixest/estimation/fegaussian_.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@ def __init__(
2626
tol: float,
2727
maxiter: int,
2828
solver: Literal[
29-
"np.linalg.lstsq", "np.linalg.solve", "scipy.sparse.linalg.lsqr", "jax"
29+
"np.linalg.lstsq",
30+
"np.linalg.solve",
31+
"scipy.linalg.solve",
32+
"scipy.sparse.linalg.lsqr",
33+
"jax",
3034
],
3135
store_data: bool = True,
3236
copy_data: bool = True,

pyfixest/estimation/feglm_.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,11 @@ def __init__(
3333
tol: float,
3434
maxiter: int,
3535
solver: Literal[
36-
"np.linalg.lstsq", "np.linalg.solve", "scipy.sparse.linalg.lsqr", "jax"
36+
"np.linalg.lstsq",
37+
"np.linalg.solve",
38+
"scipy.linalg.solve",
39+
"scipy.sparse.linalg.lsqr",
40+
"jax",
3741
],
3842
store_data: bool = True,
3943
copy_data: bool = True,

pyfixest/estimation/feiv_.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,9 @@ class Feiv(Feols):
4040
Names of the coefficients of Z.
4141
collin_tol : float
4242
Tolerance for collinearity check.
43-
solver: Literal["np.linalg.lstsq", "np.linalg.solve", "scipy.sparse.linalg.lsqr", "jax"],
44-
default is 'np.linalg.solve'. Solver to use for the estimation.
43+
solver: Literal["np.linalg.lstsq", "np.linalg.solve", "scipy.linalg.solve",
44+
"scipy.sparse.linalg.lsqr", "jax"],
45+
default is "scipy.linalg.solve". Solver to use for the estimation.
4546
demeaner_backend: Literal["numba", "jax"]
4647
The backend used for demeaning.
4748
weights_name : Optional[str]
@@ -144,8 +145,12 @@ def __init__(
144145
fixef_tol: float,
145146
lookup_demeaned_data: dict[str, pd.DataFrame],
146147
solver: Literal[
147-
"np.linalg.lstsq", "np.linalg.solve", "scipy.sparse.linalg.lsqr", "jax"
148-
] = "np.linalg.solve",
148+
"np.linalg.lstsq",
149+
"np.linalg.solve",
150+
"scipy.linalg.solve",
151+
"scipy.sparse.linalg.lsqr",
152+
"jax",
153+
] = "scipy.linalg.solve",
149154
demeaner_backend: Literal["numba", "jax"] = "numba",
150155
store_data: bool = True,
151156
copy_data: bool = True,

pyfixest/estimation/felogit_.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@ def __init__(
2626
tol: float,
2727
maxiter: int,
2828
solver: Literal[
29-
"np.linalg.lstsq", "np.linalg.solve", "scipy.sparse.linalg.lsqr", "jax"
29+
"np.linalg.lstsq",
30+
"np.linalg.solve",
31+
"scipy.linalg.solve",
32+
"scipy.sparse.linalg.lsqr",
33+
"jax",
3034
],
3135
store_data: bool = True,
3236
copy_data: bool = True,

pyfixest/estimation/feols_.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,9 @@ class Feols:
9191
Type of the weights variable. Either "aweights" for analytic weights or
9292
"fweights" for frequency weights.
9393
solver : str, optional.
94-
The solver to use for the regression. Can be either "np.linalg.solve" or
95-
"np.linalg.lstsq". Defaults to "np.linalg.solve".
94+
The solver to use for the regression. Can be "np.linalg.lstsq",
95+
"np.linalg.solve", "scipy.linalg.solve", "scipy.sparse.linalg.lsqr" and "jax".
96+
Defaults to "scipy.linalg.solve".
9697
context : int or Mapping[str, Any]
9798
A dictionary containing additional context variables to be used by
9899
formulaic during the creation of the model matrix. This can include
@@ -204,8 +205,9 @@ class Feols:
204205
Adjusted R-squared value of the model.
205206
_adj_r2_within : float
206207
Adjusted R-squared value computed on demeaned dependent variable.
207-
_solver: Literal["np.linalg.lstsq", "np.linalg.solve", "scipy.sparse.linalg.lsqr", "jax"],
208-
default is 'np.linalg.solve'. Solver to use for the estimation.
208+
_solver: Literal["np.linalg.lstsq", "np.linalg.solve", "scipy.linalg.solve",
209+
"scipy.sparse.linalg.lsqr", "jax"],
210+
default is "scipy.linalg.solve". Solver to use for the estimation.
209211
_demeaner_backend: Literal["numba", "jax"]
210212
The backend used for demeaning.
211213
_data: pd.DataFrame
@@ -234,8 +236,12 @@ def __init__(
234236
fixef_tol: float,
235237
lookup_demeaned_data: dict[str, pd.DataFrame],
236238
solver: Literal[
237-
"np.linalg.lstsq", "np.linalg.solve", "scipy.sparse.linalg.lsqr", "jax"
238-
] = "np.linalg.solve",
239+
"np.linalg.lstsq",
240+
"np.linalg.solve",
241+
"scipy.linalg.solve",
242+
"scipy.sparse.linalg.lsqr",
243+
"jax",
244+
] = "scipy.linalg.solve",
239245
demeaner_backend: Literal["numba", "jax"] = "numba",
240246
store_data: bool = True,
241247
copy_data: bool = True,

pyfixest/estimation/feols_compressed_.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,11 @@ def __init__(
8686
fixef_tol: float,
8787
lookup_demeaned_data: dict[str, pd.DataFrame],
8888
solver: Literal[
89-
"np.linalg.lstsq", "np.linalg.solve", "scipy.sparse.linalg.lsqr", "jax"
89+
"np.linalg.lstsq",
90+
"np.linalg.solve",
91+
"scipy.linalg.solve",
92+
"scipy.sparse.linalg.lsqr",
93+
"jax",
9094
],
9195
demeaner_backend: Literal["numba", "jax"] = "numba",
9296
store_data: bool = True,

pyfixest/estimation/fepois_.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,10 @@ class Fepois(Feols):
5151
Maximum number of iterations for the IRLS algorithm.
5252
tol : Optional[float], default=1e-08
5353
Tolerance level for the convergence of the IRLS algorithm.
54-
solver: Literal["np.linalg.lstsq", "np.linalg.solve", "scipy.sparse.linalg.lsqr", "jax"],
55-
default is 'np.linalg.solve'. Solver to use for the estimation.
54+
solver : str, optional.
55+
The solver to use for the regression. Can be "np.linalg.lstsq",
56+
"np.linalg.solve", "scipy.linalg.solve", "scipy.sparse.linalg.lsqr" and "jax".
57+
Defaults to "scipy.linalg.solve".
5658
demeaner_backend: Literal["numba", "jax"]
5759
The backend used for demeaning.
5860
fixef_tol: float, default = 1e-08.
@@ -86,8 +88,12 @@ def __init__(
8688
tol: float,
8789
maxiter: int,
8890
solver: Literal[
89-
"np.linalg.lstsq", "np.linalg.solve", "scipy.sparse.linalg.lsqr", "jax"
90-
] = "np.linalg.solve",
91+
"np.linalg.lstsq",
92+
"np.linalg.solve",
93+
"scipy.linalg.solve",
94+
"scipy.sparse.linalg.lsqr",
95+
"jax",
96+
] = "scipy.linalg.solve",
9197
demeaner_backend: Literal["numba", "jax"] = "numba",
9298
context: Union[int, Mapping[str, Any]] = 0,
9399
store_data: bool = True,

pyfixest/estimation/feprobit_.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,11 @@ def __init__(
2828
tol: float,
2929
maxiter: int,
3030
solver: Literal[
31-
"np.linalg.lstsq", "np.linalg.solve", "scipy.sparse.linalg.lsqr", "jax"
31+
"np.linalg.lstsq",
32+
"np.linalg.solve",
33+
"scipy.linalg.solve",
34+
"scipy.sparse.linalg.lsqr",
35+
"jax",
3236
],
3337
store_data: bool = True,
3438
copy_data: bool = True,

pyfixest/estimation/literals.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55
WeightsTypeOptions = Literal["aweights", "fweights"]
66
FixedRmOptions = Literal["singleton", "none"]
77
SolverOptions = Literal[
8-
"np.linalg.lstsq", "np.linalg.solve", "scipy.sparse.linalg.lsqr", "jax"
8+
"np.linalg.lstsq",
9+
"np.linalg.solve",
10+
"scipy.linalg.solve",
11+
"scipy.sparse.linalg.lsqr",
12+
"jax",
913
]
1014
DemeanerBackendOptions = Literal["numba", "jax"]
1115
PredictionErrorOptions = Literal["prediction"]

pyfixest/estimation/solvers.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
from scipy.linalg import solve
23
from scipy.sparse.linalg import lsqr
34
from typing_extensions import Literal
45

@@ -7,7 +8,11 @@ def solve_ols(
78
tZX: np.ndarray,
89
tZY: np.ndarray,
910
solver: Literal[
10-
"np.linalg.lstsq", "np.linalg.solve", "scipy.sparse.linalg.lsqr", "jax"
11+
"np.linalg.lstsq",
12+
"np.linalg.solve",
13+
"scipy.linalg.solve",
14+
"scipy.sparse.linalg.lsqr",
15+
"jax",
1116
],
1217
) -> np.ndarray:
1318
"""
@@ -17,8 +22,8 @@ def solve_ols(
1722
----------
1823
tZX (array-like): Z'X.
1924
tZY (array-like): Z'Y.
20-
solver (str): The solver to use. Supported solvers are"np.linalg.lstsq",
21-
"np.linalg.solve", "scipy.sparse.linalg.lsqr" and "jax".
25+
solver (str): The solver to use. Supported solvers are "np.linalg.lstsq",
26+
"np.linalg.solve", "scipy.linalg.solve", "scipy.sparse.linalg.lsqr" and "jax".
2227
2328
Returns
2429
-------
@@ -32,6 +37,8 @@ def solve_ols(
3237
return np.linalg.lstsq(tZX, tZY, rcond=None)[0].flatten()
3338
elif solver == "np.linalg.solve":
3439
return np.linalg.solve(tZX, tZY).flatten()
40+
elif solver == "scipy.linalg.solve":
41+
return solve(tZX, tZY, assume_a="pos").flatten()
3542
elif solver == "scipy.sparse.linalg.lsqr":
3643
return lsqr(tZX, tZY)[0].flatten()
3744
elif solver == "jax":

tests/test_solvers.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66

77
def test_solve_ols_simple_2x2():
88
# Test case 1: Simple 2x2 system
9-
tZX = np.array([[1, 2], [3, 4]])
10-
tZY = np.array([5, 6])
11-
solver = "np.linalg.lstsq"
9+
tZX = np.array([[4, 2], [2, 3]])
10+
tZY = np.array([10, 8])
11+
solver = "scipy.linalg.solve"
1212
solution = solve_ols(tZX, tZY, solver)
13-
assert np.allclose(solution, np.array([-4.0, 4.5]))
13+
assert np.allclose(solution, np.array([1.75, 1.5]))
1414
# Verify solution satisfies the system
1515
assert np.allclose(tZX @ solution, tZY)
1616

@@ -19,21 +19,33 @@ def test_solve_ols_identity():
1919
# Test case 2: Identity matrix
2020
tZX = np.eye(2)
2121
tZY = np.array([1, 2])
22-
solver = "np.linalg.lstsq"
22+
solver = "scipy.linalg.solve"
2323
assert np.allclose(solve_ols(tZX, tZY, solver), tZY)
2424

2525

2626
@pytest.mark.parametrize(
2727
argnames="solver",
28-
argvalues=["np.linalg.lstsq", "np.linalg.solve", "scipy.sparse.linalg.lsqr", "jax"],
29-
ids=["np.linalg.lstsq", "np.linalg.solve", "scipy.sparse.linalg.lsqr", "jax"],
28+
argvalues=[
29+
"scipy.linalg.solve",
30+
"np.linalg.lstsq",
31+
"np.linalg.solve",
32+
"scipy.sparse.linalg.lsqr",
33+
"jax",
34+
],
35+
ids=[
36+
"scipy.linalg.solve",
37+
"np.linalg.lstsq",
38+
"np.linalg.solve",
39+
"scipy.sparse.linalg.lsqr",
40+
"jax",
41+
],
3042
)
3143
def test_solve_ols_different_solvers(solver):
3244
# Test case 3: Test different solvers give same result
33-
tZX = np.array([[1, 2], [3, 4]])
34-
tZY = np.array([5, 6])
45+
tZX = np.array([[4, 2], [2, 3]])
46+
tZY = np.array([10, 8])
3547
solution = solve_ols(tZX, tZY, solver)
36-
assert np.allclose(solution, np.array([-4.0, 4.5]))
48+
assert np.allclose(solution, np.array([1.75, 1.5]))
3749
# Verify solution satisfies the system
3850
assert np.allclose(tZX @ solution, tZY)
3951

0 commit comments

Comments
 (0)