|
| 1 | +# Copyright 2022 The JaxLinOp Contributors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# ============================================================================== |
| 15 | + |
| 16 | +from __future__ import annotations |
| 17 | + |
| 18 | +from typing import Any, Union |
| 19 | + |
| 20 | +import jax.numpy as jnp |
| 21 | +from jaxtyping import Array, Float |
| 22 | +from simple_pytree import static_field |
| 23 | +from dataclasses import dataclass |
| 24 | + |
| 25 | +from .linear_operator import LinearOperator |
| 26 | +from .diagonal_linear_operator import DiagonalLinearOperator |
| 27 | + |
| 28 | + |
| 29 | +def _check_args(value: Any, size: Any) -> None: |
| 30 | + |
| 31 | + if not isinstance(size, int): |
| 32 | + raise ValueError(f"`length` must be an integer, but `length = {size}`.") |
| 33 | + |
| 34 | + if value.ndim != 1: |
| 35 | + raise ValueError( |
| 36 | + f"`value` must be one dimensional scalar, but `value.shape = {value.shape}`." |
| 37 | + ) |
| 38 | + |
| 39 | + |
| 40 | +@dataclass |
| 41 | +class ConstantDiagonalLinearOperator(DiagonalLinearOperator): |
| 42 | + value: Float[Array, "1"] |
| 43 | + size: int = static_field() |
| 44 | + |
| 45 | + def __init__( |
| 46 | + self, value: Float[Array, "1"], size: int, dtype: jnp.dtype = None |
| 47 | + ) -> None: |
| 48 | + """Initialize the constant diagonal linear operator. |
| 49 | +
|
| 50 | + Args: |
| 51 | + value (Float[Array, "1"]): Constant value of the diagonal. |
| 52 | + size (int): Size of the diagonal. |
| 53 | + """ |
| 54 | + |
| 55 | + _check_args(value, size) |
| 56 | + |
| 57 | + if dtype is not None: |
| 58 | + value = value.astype(dtype) |
| 59 | + |
| 60 | + self.value = value |
| 61 | + self.size = size |
| 62 | + self.shape = (size, size) |
| 63 | + self.dtype = value.dtype |
| 64 | + |
| 65 | + def __add__( |
| 66 | + self, other: Union[Float[Array, "N N"], LinearOperator] |
| 67 | + ) -> DiagonalLinearOperator: |
| 68 | + if isinstance(other, ConstantDiagonalLinearOperator): |
| 69 | + if other.size == self.size: |
| 70 | + return ConstantDiagonalLinearOperator( |
| 71 | + value=self.value + other.value, size=self.size |
| 72 | + ) |
| 73 | + |
| 74 | + raise ValueError( |
| 75 | + f"`length` must be the same, but `length = {self.size}` and `length = {other.size}`." |
| 76 | + ) |
| 77 | + |
| 78 | + else: |
| 79 | + return super().__add__(other) |
| 80 | + |
| 81 | + def __mul__(self, other: float) -> LinearOperator: |
| 82 | + """Multiply covariance operator by scalar. |
| 83 | +
|
| 84 | + Args: |
| 85 | + other (LinearOperator): Scalar. |
| 86 | +
|
| 87 | + Returns: |
| 88 | + LinearOperator: Covariance operator multiplied by a scalar. |
| 89 | + """ |
| 90 | + |
| 91 | + return ConstantDiagonalLinearOperator(value=self.value * other, size=self.size) |
| 92 | + |
| 93 | + def _add_diagonal(self, other: DiagonalLinearOperator) -> LinearOperator: |
| 94 | + """Add diagonal to the covariance operator, useful for computing, Kxx + Iσ². |
| 95 | +
|
| 96 | + Args: |
| 97 | + other (DiagonalLinearOperator): Diagonal covariance operator to add to the covariance operator. |
| 98 | +
|
| 99 | + Returns: |
| 100 | + LinearOperator: Covariance operator with the diagonal added. |
| 101 | + """ |
| 102 | + |
| 103 | + if isinstance(other, ConstantDiagonalLinearOperator): |
| 104 | + if other.size == self.size: |
| 105 | + return ConstantDiagonalLinearOperator( |
| 106 | + value=self.value + other.value, size=self.size |
| 107 | + ) |
| 108 | + |
| 109 | + raise ValueError( |
| 110 | + f"`length` must be the same, but `length = {self.size}` and `length = {other.size}`." |
| 111 | + ) |
| 112 | + |
| 113 | + else: |
| 114 | + return super()._add_diagonal(other) |
| 115 | + |
| 116 | + def diagonal(self) -> Float[Array, "N"]: |
| 117 | + """Diagonal of the covariance operator.""" |
| 118 | + return self.value * jnp.ones(self.size) |
| 119 | + |
| 120 | + def to_root(self) -> ConstantDiagonalLinearOperator: |
| 121 | + """ |
| 122 | + Lower triangular. |
| 123 | +
|
| 124 | + Returns: |
| 125 | + Float[Array, "N N"]: Lower triangular matrix. |
| 126 | + """ |
| 127 | + return ConstantDiagonalLinearOperator( |
| 128 | + value=jnp.sqrt(self.value), size=self.size |
| 129 | + ) |
| 130 | + |
| 131 | + def log_det(self) -> Float[Array, "1"]: |
| 132 | + """Log determinant. |
| 133 | +
|
| 134 | + Returns: |
| 135 | + Float[Array, "1"]: Log determinant of the covariance matrix. |
| 136 | + """ |
| 137 | + return 2.0 * self.size * jnp.log(self.value) |
| 138 | + |
| 139 | + def inverse(self) -> ConstantDiagonalLinearOperator: |
| 140 | + """Inverse of the covariance operator. |
| 141 | +
|
| 142 | + Returns: |
| 143 | + DiagonalLinearOperator: Inverse of the covariance operator. |
| 144 | + """ |
| 145 | + return ConstantDiagonalLinearOperator(value=1.0 / self.value, size=self.size) |
| 146 | + |
| 147 | + def solve(self, rhs: Float[Array, "N M"]) -> Float[Array, "N M"]: |
| 148 | + """Solve linear system. |
| 149 | +
|
| 150 | + Args: |
| 151 | + rhs (Float[Array, "N M"]): Right hand side of the linear system. |
| 152 | +
|
| 153 | + Returns: |
| 154 | + Float[Array, "N M"]: Solution of the linear system. |
| 155 | + """ |
| 156 | + |
| 157 | + return rhs / self.value |
| 158 | + |
| 159 | + @classmethod |
| 160 | + def from_dense(cls, dense: Float[Array, "N N"]) -> ConstantDiagonalLinearOperator: |
| 161 | + """Construct covariance operator from dense matrix. |
| 162 | +
|
| 163 | + Args: |
| 164 | + dense (Float[Array, "N N"]): Dense matrix. |
| 165 | +
|
| 166 | + Returns: |
| 167 | + DiagonalLinearOperator: Covariance operator. |
| 168 | + """ |
| 169 | + return ConstantDiagonalLinearOperator( |
| 170 | + value=jnp.atleast_1d(dense[0, 0]), size=dense.shape[0] |
| 171 | + ) |
| 172 | + |
| 173 | + @classmethod |
| 174 | + def from_root( |
| 175 | + cls, root: ConstantDiagonalLinearOperator |
| 176 | + ) -> ConstantDiagonalLinearOperator: |
| 177 | + """Construct covariance operator from root. |
| 178 | +
|
| 179 | + Args: |
| 180 | + root (ConstantDiagonalLinearOperator): Root of the covariance operator. |
| 181 | +
|
| 182 | + Returns: |
| 183 | + ConstantDiagonalLinearOperator: Covariance operator. |
| 184 | + """ |
| 185 | + return ConstantDiagonalLinearOperator(value=root.value**2, size=root.size) |
| 186 | + |
| 187 | + |
| 188 | +__all__ = [ |
| 189 | + "ConstantDiagonalLinearOperator", |
| 190 | +] |
0 commit comments