Skip to content

Commit 3631a90

Browse files
committed
Merge remote-tracking branch 'upstream/main' into hadamardlikelihood
2 parents 66b0cb6 + b017b9c commit 3631a90

File tree

83 files changed

+2437
-1005
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

83 files changed

+2437
-1005
lines changed

.conda/meta.yaml

+13-7
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1-
{% set data = load_setup_py_data(setup_file="../setup.py", from_recipe_dir=True) %}
1+
{% set _version_match = load_file_regex(
2+
load_file="gpytorch/version.py",
3+
regex_pattern="__version__ = version = '(.+)'"
4+
) %}
5+
{% set version = _version_match[1] %}
26

37
package:
4-
name: {{ data.get("name")|lower }}
5-
version: {{ data.get("version") }}
8+
name: gpytorch
9+
version: {{ version }}
610

711
source:
812
path: ../
@@ -13,13 +17,15 @@ build:
1317

1418
requirements:
1519
host:
16-
- python>=3.8
20+
- python>=3.10
1721

1822
run:
19-
- python>=3.8
20-
- pytorch>=1.11
23+
- python>=3.10
24+
- jaxtyping
25+
- linear_operator>=0.6
26+
- mpmath>=0.19,<=1.3
27+
- pytorch>=2.0
2128
- scikit-learn
22-
- linear_operator>=0.5.2
2329

2430
test:
2531
imports:

.github/ISSUE_TEMPLATE/---documentation-examples.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,4 @@ assignees: ''
2121
** Think you know how to fix the docs? ** (If so, we'd love a pull request from you!)
2222

2323
- Link to [GPyTorch documentation](https://gpytorch.readthedocs.io)
24-
- Link to [GPyTorch examples](https://github.com/cornellius-gp/gpytorch/tree/master/examples)
24+
- Link to [GPyTorch examples](https://github.com/cornellius-gp/gpytorch/tree/main/examples)

.github/workflows/deploy.yml

+4-5
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ jobs:
1919
- name: Set up Python
2020
uses: actions/setup-python@v2
2121
with:
22-
python-version: "3.8"
22+
python-version: "3.10"
2323
- name: Install dependencies
2424
run: |
2525
python -m pip install --upgrade pip
@@ -41,7 +41,7 @@ jobs:
4141
- uses: conda-incubator/setup-miniconda@v2
4242
with:
4343
auto-update-conda: false
44-
python-version: "3.8"
44+
python-version: "3.10"
4545
- name: Install dependencies
4646
run: |
4747
conda install -y anaconda-client conda-build
@@ -52,9 +52,8 @@ jobs:
5252
conda config --set anaconda_upload yes
5353
conda config --append channels pytorch
5454
conda config --append channels gpytorch
55+
conda config --append channels conda-forge
5556
/usr/share/miniconda/bin/anaconda login --username ${{ secrets.CONDA_USERNAME }} --password ${{ secrets.CONDA_PASSWORD }}
5657
python -m setuptools_scm
57-
cd .conda
58-
conda build .
58+
conda build .conda
5959
/usr/share/miniconda/bin/anaconda logout
60-
cd ..

.github/workflows/run_test_suite.yml

+11-9
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ name: Run Test Suite
55

66
on:
77
push:
8-
branches: [ master ]
8+
branches: [ main, develop ]
99
pull_request:
10-
branches: [ master ]
10+
branches: [ main, develop ]
1111
workflow_call:
1212

1313
jobs:
@@ -18,7 +18,7 @@ jobs:
1818
- name: Set up Python
1919
uses: actions/setup-python@v2
2020
with:
21-
python-version: "3.8"
21+
python-version: "3.10"
2222
- name: Install dependencies
2323
run: |
2424
pip install flake8==4.0.1 flake8-print==4.0.0 pre-commit
@@ -37,20 +37,21 @@ jobs:
3737
runs-on: ubuntu-latest
3838
strategy:
3939
matrix:
40-
pytorch-version: ["master", "stable"]
40+
pytorch-version: ["main", "stable"]
4141
extras: ["with-extras", "no-extras"]
4242
steps:
4343
- uses: actions/checkout@v2
4444
- name: Set up Python
4545
uses: actions/setup-python@v2
4646
with:
47-
python-version: "3.8"
47+
python-version: "3.10"
4848
- name: Install dependencies
4949
run: |
50-
if [[ ${{ matrix.pytorch-version }} = "master" ]]; then
50+
if [[ ${{ matrix.pytorch-version }} = "main" ]]; then
5151
pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html;
5252
else
53-
pip install torch==1.11+cpu -f https://download.pytorch.org/whl/torch_stable.html;
53+
pip install "numpy<2" # Numpy 2.0 is not fully supported until PyTorch 2.2
54+
pip install torch==2.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
5455
fi
5556
pip install -e .
5657
if [[ ${{ matrix.extras }} == "with-extras" ]]; then
@@ -69,10 +70,11 @@ jobs:
6970
- name: Set up Python
7071
uses: actions/setup-python@v2
7172
with:
72-
python-version: "3.8"
73+
python-version: "3.10"
7374
- name: Install dependencies
7475
run: |
75-
pip install torch==1.11+cpu -f https://download.pytorch.org/whl/torch_stable.html
76+
pip install "numpy<2" # Numpy 2.0 is not fully supported until PyTorch 2.2
77+
pip install torch==2.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
7678
pip install pytest nbval jupyter tqdm matplotlib torchvision scipy
7779
pip install -e .
7880
pip install "pyro-ppl>=1.8";

.pre-commit-config.yaml

+4-4
Original file line numberDiff line numberDiff line change
@@ -16,23 +16,23 @@ repos:
1616
hooks:
1717
- id: flake8
1818
args: [--config=setup.cfg]
19-
exclude: ^(examples/*)|(docs/*)
19+
exclude: ^(examples/.*)|(docs/.*)
2020
- repo: https://github.com/omnilib/ufmt
2121
rev: v2.0.0
2222
hooks:
2323
- id: ufmt
2424
additional_dependencies:
2525
- black == 22.3.0
2626
- usort == 1.0.3
27-
exclude: ^(build/*)|(docs/*)|(examples/*)
27+
exclude: ^(build/.*)|(docs/.*)|(examples/.*)
2828
- repo: https://github.com/jumanjihouse/pre-commit-hooks
2929
rev: 2.1.6
3030
hooks:
3131
- id: require-ascii
32-
exclude: ^(examples/.*\.ipynb)|(.github/ISSUE_TEMPLATE/*)
32+
exclude: ^(examples/.*\.ipynb)|(.github/ISSUE_TEMPLATE/.*)
3333
- id: script-must-have-extension
3434
- id: forbid-binary
35-
exclude: ^(examples/*)|(test/examples/old_variational_strategy_model.pth)
35+
exclude: ^(examples/.*)|(test/examples/old_variational_strategy_model.pth)
3636
- repo: https://github.com/Lucas-C/pre-commit-hooks
3737
rev: v1.1.13
3838
hooks:

.readthedocs.yml

+4-3
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@ version: 2
88
build:
99
os: "ubuntu-22.04"
1010
tools:
11-
python: "3.8"
11+
python: "3.10"
1212
jobs:
13-
pre_install: # Lock version of torch at 1.11
14-
- pip install torch==1.11.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
13+
pre_install: # Lock version of torch at 2.0
14+
- pip install "numpy<2" # Numpy 2.0 is not fully supported until PyTorch 2.2
15+
- pip install torch==2.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
1516
pre_build:
1617
- python -m setuptools_scm # Get correct version number
1718

CONTRIBUTING.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ We use [standard sphinx docstrings](https://sphinx-rtd-tutorial.readthedocs.io/e
3636

3737
### Type Hints
3838

39-
GPyTorch aims to be fully typed using Python 3.8+
39+
GPyTorch aims to be fully typed using Python 3.10+
4040
[type hints](https://www.python.org/dev/peps/pep-0484/).
4141

4242
We recognize that we have a long way to go towards fully typing the library,

README.md

+4-4
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
[![Documentation Status](https://readthedocs.org/projects/gpytorch/badge/?version=latest)](https://gpytorch.readthedocs.io/en/latest/?badge=latest)
66
[![License](https://img.shields.io/badge/license-MIT-green.svg)](LICENSE)
77

8-
[![Python Version](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
8+
[![Python Version](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/)
99
[![Conda](https://img.shields.io/conda/v/gpytorch/gpytorch.svg)](https://anaconda.org/gpytorch/gpytorch)
1010
[![PyPI](https://img.shields.io/pypi/v/gpytorch.svg)](https://pypi.org/project/gpytorch)
1111

@@ -29,8 +29,8 @@ See our [**documentation, examples, tutorials**](https://gpytorch.readthedocs.io
2929
## Installation
3030

3131
**Requirements**:
32-
- Python >= 3.8
33-
- PyTorch >= 1.11
32+
- Python >= 3.10
33+
- PyTorch >= 2.0
3434

3535
Install GPyTorch using pip or conda:
3636

@@ -88,7 +88,7 @@ If you use GPyTorch, please cite the following papers:
8888

8989
## Contributing
9090

91-
See the contributing guidelines [CONTRIBUTING.md](https://github.com/cornellius-gp/gpytorch/blob/master/CONTRIBUTING.md)
91+
See the contributing guidelines [CONTRIBUTING.md](https://github.com/cornellius-gp/gpytorch/blob/main/CONTRIBUTING.md)
9292
for information on submitting issues and pull requests.
9393

9494

docs/source/conf.py

+69-46
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
import sys
2020
import sphinx_rtd_theme # noqa
2121
import warnings
22-
from typing import ForwardRef
22+
23+
import jaxtyping
2324

2425

2526
def read(*names, **kwargs):
@@ -112,7 +113,8 @@ def find_version(*file_paths):
112113
intersphinx_mapping = {
113114
"python": ("https://docs.python.org/3/", None),
114115
"torch": ("https://pytorch.org/docs/stable/", None),
115-
"linear_operator": ("https://linear-operator.readthedocs.io/en/stable/", None),
116+
"linear_operator": ("https://linear-operator.readthedocs.io/en/stable/", "linear_operator_objects.inv"),
117+
# The local mapping here is temporary until we get a new release of linear_operator
116118
}
117119

118120
# Disable docstring inheritance
@@ -237,41 +239,81 @@ def find_version(*file_paths):
237239
]
238240

239241

240-
# -- Function to format typehints ----------------------------------------------
242+
# -- Functions to format typehints ----------------------------------------------
241243
# Adapted from
242244
# https://github.com/cornellius-gp/linear_operator/blob/2b33b9f83b45f0cb8cb3490fc5f254cc59393c25/docs/source/conf.py
245+
246+
247+
# Helper function
248+
# Convert any class (i.e. torch.Tensor, LinearOperator, etc.) into appropriate strings
249+
# For external classes, the format will be e.g. "torch.Tensor"
250+
# For any internal class, the format will be e.g. "~linear_operator.operators.TriangularLinearOperator"
251+
def _convert_internal_and_external_class_to_strings(annotation):
252+
module = annotation.__module__ + "."
253+
if module.split(".")[0] == "gpytorch":
254+
module = "~" + module
255+
elif module == "torch.":
256+
module = "~torch."
257+
elif module == "linear_operator.operators._linear_operator.":
258+
module = "~linear_operator."
259+
elif module == "builtins.":
260+
module = ""
261+
res = f"{module}{annotation.__name__}"
262+
return res
263+
264+
265+
# Convert jaxtyping dimensions into strings
266+
def _dim_to_str(dim):
267+
if isinstance(dim, jaxtyping._array_types._NamedVariadicDim):
268+
return "..."
269+
elif isinstance(dim, jaxtyping._array_types._FixedDim):
270+
res = str(dim.size)
271+
if dim.broadcastable:
272+
res = "#" + res
273+
return res
274+
elif isinstance(dim, jaxtyping._array_types._SymbolicDim):
275+
expr = dim.elem
276+
return f"({expr})"
277+
elif "jaxtyping" not in str(dim.__class__): # Probably the case that we have an ellipsis
278+
return "..."
279+
else:
280+
res = str(dim.name)
281+
if dim.broadcastable:
282+
res = "#" + res
283+
return res
284+
285+
286+
# Function to format type hints
243287
def _process(annotation, config):
244288
"""
245289
A function to convert a type/rtype typehint annotation into a :type:/:rtype: string.
246290
This function is a bit hacky, and specific to the type annotations we use most frequently.
291+
247292
This function is recursive.
248293
"""
249294
# Simple/base case: any string annotation is ready to go
250295
if type(annotation) == str:
251296
return annotation
252297

298+
# Jaxtyping: shaped tensors or linear operator
299+
elif hasattr(annotation, "__module__") and "jaxtyping" == annotation.__module__:
300+
cls_annotation = _convert_internal_and_external_class_to_strings(annotation.array_type)
301+
shape = " x ".join([_dim_to_str(dim) for dim in annotation.dims])
302+
return f"{cls_annotation} ({shape})"
303+
253304
# Convert Ellipsis into "..."
254305
elif annotation == Ellipsis:
255306
return "..."
256307

257-
# Convert any class (i.e. torch.Tensor, LinearOperator, gpytorch, etc.) into appropriate strings
258-
# For external classes, the format will be e.g. "torch.Tensor"
259-
# For any linear_operator class, the format will be e.g. "~linear_operator.operators.TriangularLinearOperator"
260-
# For any internal class, the format will be e.g. "~gpytorch.kernels.RBFKernel"
308+
# Convert any class (i.e. torch.Tensor, LinearOperator, etc.) into appropriate strings
261309
elif hasattr(annotation, "__name__"):
262-
module = annotation.__module__ + "."
263-
if module.split(".")[0] == "linear_operator":
264-
if annotation.__name__.endswith("LinearOperator"):
265-
module = "~linear_operator."
266-
elif annotation.__name__.endswith("LinearOperator"):
267-
module = "~linear_operator.operators."
268-
else:
269-
module = "~" + module
270-
elif module.split(".")[0] == "gpytorch":
271-
module = "~" + module
272-
elif module == "builtins.":
273-
module = ""
274-
res = f"{module}{annotation.__name__}"
310+
res = _convert_internal_and_external_class_to_strings(annotation)
311+
312+
elif str(annotation).startswith("typing.Callable"):
313+
if len(annotation.__args__) == 2:
314+
res = f"Callable[{_process(annotation.__args__[0], config)} -> {_process(annotation.__args__[1], config)}]"
315+
else:
316+
res = "Callable"
275317

276318
# Convert any Union[*A*, *B*, *C*] into "*A* or *B* or *C*"
277319
# Also, convert any Optional[*A*] into "*A*, optional"
@@ -291,33 +333,14 @@ def _process(annotation, config):
291333
args = list(annotation.__args__)
292334
res = "(" + ", ".join(_process(arg, config) for arg in args) + ")"
293335

294-
# Convert any List[*A*] into "list(*A*)"
295-
elif str(annotation).startswith("typing.List"):
296-
arg = annotation.__args__[0]
297-
res = "list(" + _process(arg, config) + ")"
298-
299-
# Convert any List[*A*] into "list(*A*)"
300-
elif str(annotation).startswith("typing.Dict"):
301-
res = str(annotation)
302-
303-
# Convert any Iterable[*A*] into "iterable(*A*)"
304-
elif str(annotation).startswith("typing.Iterable"):
305-
arg = annotation.__args__[0]
306-
res = "iterable(" + _process(arg, config) + ")"
307-
308-
# Handle "Callable"
309-
elif str(annotation).startswith("typing.Callable"):
310-
res = "callable"
311-
312-
# Handle "Any"
313-
elif str(annotation).startswith("typing.Any"):
314-
res = ""
336+
# Convert any List[*A*] or Iterable[*A*] into "[*A*, ...]"
337+
elif str(annotation).startswith("typing.Iterable") or str(annotation).startswith("typing.List"):
338+
arg = list(annotation.__args__)[0]
339+
res = f"[{_process(arg, config)}, ...]"
315340

316-
# Special cases for forward references.
317-
# This is brittle, as it only contains case for a select few forward refs
318-
# All others that aren't caught by this are handled by the default case
319-
elif isinstance(annotation, ForwardRef):
320-
res = str(annotation.__forward_arg__)
341+
# Callable typing annotation
342+
elif str(annotation).startswith("typing."):
343+
return str(annotation)[7:]
321344

322345
# For everything we didn't catch: use the simplist string representation
323346
else:

docs/source/distributions.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ gpytorch.distributions
55
===================================
66

77
GPyTorch distribution objects are essentially the same as torch distribution objects.
8-
For the most part, GpyTorch relies on torch's distribution library.
8+
For the most part, GPyTorch relies on torch's distribution library.
99
However, we offer two custom distributions.
1010

1111
We implement a custom :obj:`~gpytorch.distributions.MultivariateNormal` that accepts

0 commit comments

Comments
 (0)