Skip to content

Commit aa3f5d2

Browse files
authored
Jaxlinop merge (#196)
Merge jaxlinop into main
1 parent 3388f16 commit aa3f5d2

38 files changed

+2802
-37
lines changed

gpjax/gaussian_distribution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# ==============================================================================
1515

1616
import jax.numpy as jnp
17-
from jaxlinop import LinearOperator, IdentityLinearOperator
17+
from .linops import LinearOperator, IdentityLinearOperator
1818

1919
from jaxtyping import Array, Float
2020
from jax import vmap

gpjax/gps.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
from jaxtyping import Array, Float
2222
from jax.random import KeyArray
2323

24-
from jaxlinop import identity
25-
from jaxkern.base import AbstractKernel
24+
from .linops import identity
25+
from .kernels.base import AbstractKernel
2626
from jaxutils import PyTree
2727

2828
from .config import get_global_config

gpjax/kernels/computations/base.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,14 @@
1717
from typing import Callable, Dict
1818

1919
from jax import vmap
20-
from jaxlinop import (
20+
from jaxtyping import Array, Float
21+
from jaxutils import PyTree
22+
23+
from ...linops import (
2124
DenseLinearOperator,
2225
DiagonalLinearOperator,
2326
LinearOperator,
2427
)
25-
from jaxtyping import Array, Float
26-
from jaxutils import PyTree
2728

2829

2930
class AbstractKernelComputation(PyTree):

gpjax/kernels/computations/basis_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import jax.numpy as jnp
44
from jaxtyping import Array, Float
55
from .base import AbstractKernelComputation
6-
from jaxlinop import DenseLinearOperator
6+
from ...linops import DenseLinearOperator
77

88

99
class BasisFunctionComputation(AbstractKernelComputation):

gpjax/kernels/computations/constant_diagonal.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@
1717
import jax.numpy as jnp
1818

1919
from jax import vmap
20-
from jaxlinop import (
20+
from jaxtyping import Array, Float
21+
from .base import AbstractKernelComputation
22+
23+
from ...linops import (
2124
ConstantDiagonalLinearOperator,
2225
DiagonalLinearOperator,
2326
)
24-
from jaxtyping import Array, Float
25-
from .base import AbstractKernelComputation
2627

2728

2829
class ConstantDiagonalKernelComputation(AbstractKernelComputation):

gpjax/kernels/computations/diagonal.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,12 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515

16-
from typing import Callable, Dict
17-
1816
from jax import vmap
19-
from jaxlinop import (
20-
DiagonalLinearOperator,
21-
)
17+
from typing import Callable, Dict
2218
from jaxtyping import Array, Float
23-
from .base import AbstractKernelComputation
2419

20+
from .base import AbstractKernelComputation
21+
from ...linops import DiagonalLinearOperator
2522

2623
class DiagonalKernelComputation(AbstractKernelComputation):
2724
def __init__(

gpjax/likelihoods.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import abc
1717
from typing import Any, Callable, Dict, Optional
18-
from jaxlinop.utils import to_dense
18+
from .linops.utils import to_dense
1919
from jaxutils import PyTree
2020

2121
import distrax as dx

gpjax/linops/README.md

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# LinOps
2+
3+
The `linops` submodule is a lightweight linear operator library written in [`jax`](https://github.com/google/jax).
4+
5+
# Overview
6+
Consider solving a diagonal matrix $A$ against a vector $b$.
7+
8+
```python
9+
import jax.numpy as jnp
10+
11+
n = 1000
12+
diag = jnp.linspace(1.0, 2.0, n)
13+
14+
A = jnp.diag(diag)
15+
b = jnp.linspace(3.0, 4.0, n)
16+
17+
# A⁻¹ b
18+
jnp.solve(A, b)
19+
```
20+
Doing so is costly in large problems. Storing the matrix gives rise to memory costs of $O(n^2)$, and inverting the matrix costs $O(n^3)$ in the number of data points $n$.
21+
22+
But hold on a second. Notice:
23+
24+
- We only have to store the diagonal entries to determine the matrix $A$. Doing so, would reduce memory costs from $O(n^2)$ to $O(n)$.
25+
- To invert $A$, we only need to take the reciprocal of the diagonal, reducing inversion costs from $O(n^3)$, to $O(n)$.
26+
27+
`JaxLinOp` is designed to exploit stucture of this kind.
28+
```python
29+
from gpjax import linops
30+
31+
A = linops.DiagonalLinearOperator(diag = diag)
32+
33+
# A⁻¹ b
34+
A.solve(b)
35+
```
36+
`linops` is designed to automatically reduce cost savings in matrix addition, multiplication, computing log-determinants and more, for other matrix stuctures too!
37+
38+
# Custom Linear Operator (details to come soon)
39+
40+
The flexible design of `linops` will allow users to impliment their own custom linear operators.
41+
42+
```python
43+
from gpjax.linops import LinearOperator
44+
45+
class MyLinearOperator(LinearOperator):
46+
47+
def __init__(self, ...)
48+
...
49+
50+
# There will be a minimal number methods that users need to impliment for their custom operator.
51+
# For optimal efficiency, we'll make it easy for the user to add optional methods to their operator,
52+
# if they give better performance than the defaults.
53+
```

gpjax/linops/__init__.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright 2022 The JaxLinOp Contributors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
from .linear_operator import LinearOperator
17+
from .dense_linear_operator import DenseLinearOperator
18+
from .diagonal_linear_operator import DiagonalLinearOperator
19+
from .constant_diagonal_linear_operator import ConstantDiagonalLinearOperator
20+
from .identity_linear_operator import IdentityLinearOperator
21+
from .zero_linear_operator import ZeroLinearOperator
22+
from .triangular_linear_operator import (
23+
LowerTriangularLinearOperator,
24+
UpperTriangularLinearOperator,
25+
)
26+
from .utils import (
27+
identity,
28+
to_dense,
29+
)
30+
31+
__all__ = [
32+
"LinearOperator",
33+
"DenseLinearOperator",
34+
"DiagonalLinearOperator",
35+
"ConstantDiagonalLinearOperator",
36+
"IdentityLinearOperator",
37+
"ZeroLinearOperator",
38+
"LowerTriangularLinearOperator",
39+
"UpperTriangularLinearOperator",
40+
"identity",
41+
"to_dense",
42+
]
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
# Copyright 2022 The JaxLinOp Contributors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
from __future__ import annotations
17+
18+
from typing import Any, Union
19+
20+
import jax.numpy as jnp
21+
from jaxtyping import Array, Float
22+
from simple_pytree import static_field
23+
from dataclasses import dataclass
24+
25+
from .linear_operator import LinearOperator
26+
from .diagonal_linear_operator import DiagonalLinearOperator
27+
28+
29+
def _check_args(value: Any, size: Any) -> None:
30+
31+
if not isinstance(size, int):
32+
raise ValueError(f"`length` must be an integer, but `length = {size}`.")
33+
34+
if value.ndim != 1:
35+
raise ValueError(
36+
f"`value` must be one dimensional scalar, but `value.shape = {value.shape}`."
37+
)
38+
39+
40+
@dataclass
41+
class ConstantDiagonalLinearOperator(DiagonalLinearOperator):
42+
value: Float[Array, "1"]
43+
size: int = static_field()
44+
45+
def __init__(
46+
self, value: Float[Array, "1"], size: int, dtype: jnp.dtype = None
47+
) -> None:
48+
"""Initialize the constant diagonal linear operator.
49+
50+
Args:
51+
value (Float[Array, "1"]): Constant value of the diagonal.
52+
size (int): Size of the diagonal.
53+
"""
54+
55+
_check_args(value, size)
56+
57+
if dtype is not None:
58+
value = value.astype(dtype)
59+
60+
self.value = value
61+
self.size = size
62+
self.shape = (size, size)
63+
self.dtype = value.dtype
64+
65+
def __add__(
66+
self, other: Union[Float[Array, "N N"], LinearOperator]
67+
) -> DiagonalLinearOperator:
68+
if isinstance(other, ConstantDiagonalLinearOperator):
69+
if other.size == self.size:
70+
return ConstantDiagonalLinearOperator(
71+
value=self.value + other.value, size=self.size
72+
)
73+
74+
raise ValueError(
75+
f"`length` must be the same, but `length = {self.size}` and `length = {other.size}`."
76+
)
77+
78+
else:
79+
return super().__add__(other)
80+
81+
def __mul__(self, other: float) -> LinearOperator:
82+
"""Multiply covariance operator by scalar.
83+
84+
Args:
85+
other (LinearOperator): Scalar.
86+
87+
Returns:
88+
LinearOperator: Covariance operator multiplied by a scalar.
89+
"""
90+
91+
return ConstantDiagonalLinearOperator(value=self.value * other, size=self.size)
92+
93+
def _add_diagonal(self, other: DiagonalLinearOperator) -> LinearOperator:
94+
"""Add diagonal to the covariance operator, useful for computing, Kxx + Iσ².
95+
96+
Args:
97+
other (DiagonalLinearOperator): Diagonal covariance operator to add to the covariance operator.
98+
99+
Returns:
100+
LinearOperator: Covariance operator with the diagonal added.
101+
"""
102+
103+
if isinstance(other, ConstantDiagonalLinearOperator):
104+
if other.size == self.size:
105+
return ConstantDiagonalLinearOperator(
106+
value=self.value + other.value, size=self.size
107+
)
108+
109+
raise ValueError(
110+
f"`length` must be the same, but `length = {self.size}` and `length = {other.size}`."
111+
)
112+
113+
else:
114+
return super()._add_diagonal(other)
115+
116+
def diagonal(self) -> Float[Array, "N"]:
117+
"""Diagonal of the covariance operator."""
118+
return self.value * jnp.ones(self.size)
119+
120+
def to_root(self) -> ConstantDiagonalLinearOperator:
121+
"""
122+
Lower triangular.
123+
124+
Returns:
125+
Float[Array, "N N"]: Lower triangular matrix.
126+
"""
127+
return ConstantDiagonalLinearOperator(
128+
value=jnp.sqrt(self.value), size=self.size
129+
)
130+
131+
def log_det(self) -> Float[Array, "1"]:
132+
"""Log determinant.
133+
134+
Returns:
135+
Float[Array, "1"]: Log determinant of the covariance matrix.
136+
"""
137+
return 2.0 * self.size * jnp.log(self.value)
138+
139+
def inverse(self) -> ConstantDiagonalLinearOperator:
140+
"""Inverse of the covariance operator.
141+
142+
Returns:
143+
DiagonalLinearOperator: Inverse of the covariance operator.
144+
"""
145+
return ConstantDiagonalLinearOperator(value=1.0 / self.value, size=self.size)
146+
147+
def solve(self, rhs: Float[Array, "N M"]) -> Float[Array, "N M"]:
148+
"""Solve linear system.
149+
150+
Args:
151+
rhs (Float[Array, "N M"]): Right hand side of the linear system.
152+
153+
Returns:
154+
Float[Array, "N M"]: Solution of the linear system.
155+
"""
156+
157+
return rhs / self.value
158+
159+
@classmethod
160+
def from_dense(cls, dense: Float[Array, "N N"]) -> ConstantDiagonalLinearOperator:
161+
"""Construct covariance operator from dense matrix.
162+
163+
Args:
164+
dense (Float[Array, "N N"]): Dense matrix.
165+
166+
Returns:
167+
DiagonalLinearOperator: Covariance operator.
168+
"""
169+
return ConstantDiagonalLinearOperator(
170+
value=jnp.atleast_1d(dense[0, 0]), size=dense.shape[0]
171+
)
172+
173+
@classmethod
174+
def from_root(
175+
cls, root: ConstantDiagonalLinearOperator
176+
) -> ConstantDiagonalLinearOperator:
177+
"""Construct covariance operator from root.
178+
179+
Args:
180+
root (ConstantDiagonalLinearOperator): Root of the covariance operator.
181+
182+
Returns:
183+
ConstantDiagonalLinearOperator: Covariance operator.
184+
"""
185+
return ConstantDiagonalLinearOperator(value=root.value**2, size=root.size)
186+
187+
188+
__all__ = [
189+
"ConstantDiagonalLinearOperator",
190+
]

0 commit comments

Comments
 (0)