Skip to content

Commit 5f9da15

Browse files
authored
Merge branch 'main' into jpb/diagfromjac
2 parents caeb4e8 + b16b03e commit 5f9da15

File tree

4 files changed

+19
-4
lines changed

4 files changed

+19
-4
lines changed

.github/workflows/run_tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ jobs:
3232
3333
- name: Test with pytest
3434
run: |
35-
pytest
35+
python -m tests
3636
3737
- name: Check that documentation can be built.
3838
run: |

lineax/_solver/cg.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import warnings
1516
from collections.abc import Callable
1617
from typing import Any, TypeAlias
1718

@@ -267,5 +268,17 @@ def assume_full_rank(self):
267268
"""
268269

269270

270-
def NormalCG(*args):
271-
return Normal(CG(*args))
271+
def NormalCG(*args, **kwargs):
272+
"""Deprecated helper function. Use `lx.Normal(lx.CG(...))` instead.
273+
274+
!!! warning "Deprecated"
275+
`NormalCG(...)` is deprecated in favour of `lx.Normal(lx.CG(...))`.
276+
This will be removed in some future version of Lineax.
277+
"""
278+
warnings.warn(
279+
"`NormalCG(...)` is deprecated in favour of `lx.Normal(lx.CG(...))`. "
280+
"This will be removed in some future version of Lineax.",
281+
DeprecationWarning,
282+
stacklevel=2,
283+
)
284+
return Normal(CG(*args, **kwargs))

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ name = "lineax"
2828
readme = "README.md"
2929
requires-python = "~=3.10"
3030
urls = {repository = "https://github.com/google/lineax"}
31-
version = "0.0.8"
31+
version = "0.1.0"
3232

3333
[project.optional-dependencies]
3434
dev = [

tests/helpers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ def _construct_matrix_impl(getkey, cond_cutoff, tags, size, dtype, i):
5959
def construct_matrix(getkey, solver, tags, num=1, *, size=3, dtype=jnp.float64):
6060
if isinstance(solver, lx.Normal):
6161
cond_cutoff = math.sqrt(1000)
62+
elif isinstance(solver, lx.LSMR):
63+
cond_cutoff = 10 # it's not doing super well for some reason
6264
else:
6365
cond_cutoff = 1000
6466
return tuple(

0 commit comments

Comments
 (0)