Skip to content

Commit 6f7db6d

Browse files
authored
Merge pull request #300 from JaxGaussianProcesses/citations
Add cite functionality
2 parents 6408c68 + edd8aae commit 6f7db6d

File tree

10 files changed

+436
-7
lines changed

10 files changed

+436
-7
lines changed

docs/examples/collapsed_vi.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,22 @@
105105
# <strong data-cite="titsias2009">Titsias (2009)</strong>.
106106

107107
# %%
108-
elbo = jit(gpx.CollapsedELBO(negative=True))
108+
elbo = gpx.CollapsedELBO(negative=True)
109+
110+
# %% [markdown]
111+
# For researchers, GPJax has the capacity to print the bibtex citation for objects such
112+
# as the ELBO through the `cite()` function.
113+
114+
# %%
115+
print(gpx.cite(elbo))
116+
117+
# %% [markdown]
118+
# JIT-compiling expensive-to-compute functions such as the ELBO is
119+
# advisable. This can be achieved by wrapping the function in `jax.jit()`.
120+
121+
# %%
122+
123+
elbo = jit(elbo)
109124

110125
# %% [markdown]
111126
# We now train our model akin to a Gaussian process regression model via the `fit`

docs/examples/graph_kernels.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,23 @@
134134

135135
# %%
136136
likelihood = gpx.Gaussian(num_datapoints=D.n)
137-
prior = gpx.Prior(mean_function=gpx.Zero(), kernel=gpx.GraphKernel(laplacian=L))
137+
kernel = gpx.GraphKernel(laplacian=L)
138+
prior = gpx.Prior(mean_function=gpx.Zero(), kernel=kernel)
138139
posterior = prior * likelihood
139140

141+
# %% [markdown]
142+
#
143+
# For researchers and the curious reader, GPJax provides the ability to print the
144+
# bibtex citation for objects such as the graph kernel through the `cite()` function.
145+
146+
# %%
147+
print(gpx.cite(kernel))
148+
149+
# %% [markdown]
150+
#
151+
# With a posterior defined, we can now optimise the model's hyperparameters.
152+
153+
# %%
140154
opt_posterior, training_history = gpx.fit(
141155
model=posterior,
142156
objective=jit(gpx.ConjugateMLL(negative=True)),

docs/examples/regression.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@
179179
# these parameters by optimising the marginal log-likelihood (MLL).
180180

181181
# %%
182-
negative_mll = jit(gpx.objectives.ConjugateMLL(negative=True))
182+
negative_mll = gpx.objectives.ConjugateMLL(negative=True)
183183
negative_mll(posterior, train_data=D)
184184

185185

@@ -188,6 +188,20 @@
188188
# ox.adam(learning_rate=0.01),
189189
# ox.masked(ox.set_to_zero(), static_tree)
190190
# )
191+
# %% [markdown]
192+
# For researchers, GPJax has the capacity to print the bibtex citation for objects such
193+
# as the marginal log-likelihood through the `cite()` function.
194+
195+
# %%
196+
print(gpx.cite(negative_mll))
197+
198+
# %% [markdown]
199+
# JIT-compiling expensive-to-compute functions such as the marginal log-likelihood is
200+
# advisable. This can be achieved by wrapping the function in `jax.jit()`.
201+
202+
# %%
203+
negative_mll = jit(negative_mll)
204+
191205
# %% [markdown]
192206
# Since most optimisers (including here) minimise a given function, we have realised
193207
# the negative marginal log-likelihood and just-in-time (JIT) compiled this to

docs/examples/uncollapsed_vi.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,22 @@
227227
# its negative.
228228

229229
# %%
230-
negative_elbo = jit(gpx.ELBO(negative=True))
230+
negative_elbo = gpx.ELBO(negative=True)
231+
232+
# %% [markdown]
233+
# For researchers, GPJax has the capacity to print the bibtex citation for objects such
234+
# as the ELBO through the `cite()` function.
235+
236+
# %%
237+
print(gpx.cite(negative_elbo))
238+
239+
# %% [markdown]
240+
# JIT-compiling expensive-to-compute functions such as the ELBO is
241+
# advisable. This can be achieved by wrapping the function in `jax.jit()`.
242+
243+
# %%
244+
245+
negative_elbo = jit(negative_elbo)
231246

232247
# %% [markdown]
233248
# ### Mini-batching

docs/refs.bib

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ @book{rasmussen2006gaussian
2525
}
2626

2727
@article{hensman2013gaussian,
28-
title = {Gaussian processes for big data},
28+
title = {{G}aussian processes for big data},
2929
author = {Hensman, James and Fusi, Nicolo and Lawrence, Neil D},
30-
journal = {arXiv preprint arXiv:1309.6835},
30+
journal = {Artificial intelligence and statistics},
3131
year = {2013}
3232
}
3333

gpjax/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
Module,
1717
param_field,
1818
)
19+
from gpjax.citation import cite
1920
from gpjax.dataset import Dataset
2021
from gpjax.fit import fit
2122
from gpjax.gps import (
@@ -77,6 +78,7 @@
7778
__all__ = [
7879
"Module",
7980
"param_field",
81+
"cite",
8082
"kernels",
8183
"fit",
8284
"Prior",

gpjax/citation.py

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
from dataclasses import (
2+
dataclass,
3+
fields,
4+
)
5+
6+
from beartype.typing import (
7+
Dict,
8+
Union,
9+
)
10+
from jaxlib.xla_extension import PjitFunction
11+
from plum import dispatch
12+
13+
from gpjax.kernels import (
14+
RFF,
15+
ArcCosine,
16+
GraphKernel,
17+
Matern12,
18+
Matern32,
19+
Matern52,
20+
)
21+
from gpjax.objectives import (
22+
ELBO,
23+
CollapsedELBO,
24+
ConjugateMLL,
25+
LogPosteriorDensity,
26+
NonConjugateMLL,
27+
)
28+
29+
MaternKernels = Union[Matern12, Matern32, Matern52]
30+
MLLs = Union[ConjugateMLL, NonConjugateMLL, LogPosteriorDensity]
31+
CitationType = Union[str, Dict[str, str]]
32+
33+
34+
@dataclass(repr=False)
35+
class AbstractCitation:
36+
citation_key: str = None
37+
authors: str = None
38+
title: str = None
39+
year: str = None
40+
41+
def as_str(self) -> str:
42+
citation_str = f"@{self.citation_type}{{{self.citation_key},"
43+
for field in fields(self):
44+
fn = field.name
45+
if fn not in ["citation_type", "citation_key", "notes"]:
46+
citation_str += f"\n{fn} = {{{getattr(self, fn)}}},"
47+
return citation_str + "\n}"
48+
49+
def __repr__(self) -> str:
50+
return repr(self.as_str())
51+
52+
def __str__(self) -> str:
53+
return self.as_str()
54+
55+
56+
class NullCitation(AbstractCitation):
57+
def __str__(self) -> str:
58+
return (
59+
"No citation available. If you think this is an error, please open a pull"
60+
" request."
61+
)
62+
63+
64+
class JittedFnCitation(AbstractCitation):
65+
def __str__(self) -> str:
66+
return "Citation not available for jitted objects."
67+
68+
69+
@dataclass
70+
class PhDThesisCitation(AbstractCitation):
71+
school: str = None
72+
institution: str = None
73+
citation_type: str = "phdthesis"
74+
75+
76+
@dataclass
77+
class PaperCitation(AbstractCitation):
78+
booktitle: str = None
79+
citation_type: str = "inproceedings"
80+
81+
82+
@dataclass
83+
class BookCitation(AbstractCitation):
84+
publisher: str = None
85+
volume: str = None
86+
citation_type: str = "book"
87+
88+
89+
####################
90+
# Default citation
91+
####################
92+
@dispatch
93+
def cite(tree) -> NullCitation:
94+
return NullCitation()
95+
96+
97+
####################
98+
# Default citation
99+
####################
100+
@dispatch
101+
def cite(tree: PjitFunction) -> JittedFnCitation:
102+
return JittedFnCitation()
103+
104+
105+
####################
106+
# Kernel citations
107+
####################
108+
@dispatch
109+
def cite(tree: MaternKernels) -> PhDThesisCitation:
110+
citation = PhDThesisCitation(
111+
citation_key="matern1960SpatialV",
112+
authors="Bertil Matérn",
113+
title=(
114+
"Spatial variation : Stochastic models and their application to some"
115+
" problems in forest surveys and other sampling investigations"
116+
),
117+
year="1960",
118+
school="Stockholm University",
119+
institution="Stockholm University",
120+
)
121+
return citation
122+
123+
124+
@dispatch
125+
def cite(tree: ArcCosine) -> PaperCitation:
126+
return PaperCitation(
127+
citation_key="cho2009kernel",
128+
authors="Cho, Youngmin and Saul, Lawrence",
129+
title="Kernel Methods for Deep Learning",
130+
year="2009",
131+
booktitle="Advances in Neural Information Processing Systems",
132+
)
133+
134+
135+
@dispatch
136+
def cite(tree: GraphKernel) -> PaperCitation:
137+
return PaperCitation(
138+
citation_key="borovitskiy2021matern",
139+
title="Matérn Gaussian Processes on Graphs",
140+
authors=(
141+
"Borovitskiy, Viacheslav and Azangulov, Iskander and Terenin, Alexander and"
142+
" Mostowsky, Peter and Deisenroth, Marc and Durrande, Nicolas"
143+
),
144+
booktitle="International Conference on Artificial Intelligence and Statistics",
145+
year="2021",
146+
)
147+
148+
149+
@dispatch
150+
def cite(tree: RFF) -> PaperCitation:
151+
return PaperCitation(
152+
citation_key="rahimi2007random",
153+
authors="Rahimi, Ali and Recht, Benjamin",
154+
title="Random features for large-scale kernel machines",
155+
year="2007",
156+
booktitle="Advances in neural information processing systems",
157+
citation_type="article",
158+
)
159+
160+
161+
####################
162+
# Objective citations
163+
####################
164+
@dispatch
165+
def cite(tree: MLLs) -> BookCitation:
166+
return BookCitation(
167+
citation_key="rasmussen2006gaussian",
168+
title="Gaussian Processes for Machine Learning",
169+
authors="Rasmussen, Carl Edward and Williams, Christopher K",
170+
year="2006",
171+
publisher="MIT press Cambridge, MA",
172+
volume="2",
173+
)
174+
175+
176+
@dispatch
177+
def cite(tree: CollapsedELBO) -> PaperCitation:
178+
return PaperCitation(
179+
citation_key="titsias2009variational",
180+
title="Variational learning of inducing variables in sparse Gaussian processes",
181+
authors="Titsias, Michalis",
182+
year="2009",
183+
booktitle="International Conference on Artificial Intelligence and Statistics",
184+
)
185+
186+
187+
@dispatch
188+
def cite(tree: ELBO) -> PaperCitation:
189+
return PaperCitation(
190+
citation_key="hensman2013gaussian",
191+
title="Gaussian Processes for Big Data",
192+
authors="Hensman, James and Fusi, Nicolo and Lawrence, Neil D",
193+
year="2013",
194+
booktitle="Uncertainty in Artificial Intelligence",
195+
citation_type="article",
196+
)

poetry.lock

Lines changed: 19 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ tensorflow-probability = "^0.19.0"
2525
orbax-checkpoint = "^0.2.0"
2626
beartype = "^0.13.1"
2727
jaxlib = "0.4.7" # Temporary fix: https://github.com/google/jax/issues/15951
28+
plum-dispatch = "^2.1.0"
2829

2930
[tool.poetry.group.test.dependencies]
3031
pytest = "^7.2.2"
@@ -160,11 +161,13 @@ convention = "numpy"
160161
"gpjax/__init__.py" = ['I', 'F401', 'E402', 'D104']
161162
"gpjax/progress_bar.py" = ["TCH004"]
162163
"gpjax/scan.py" = ["PLR0913"]
164+
"gpjax/citation.py" = ["F811"]
163165
"tests/test_base/test_module.py" = ["PLR0915"]
164166
"tests/test_linops/test_linear_operator.py" = ["PLR0913"]
165167
"tests/test_objectives.py" = ["PLR0913"]
166168
"docs/examples/barycentres.py" = ["PLR0913"]
167169

170+
168171
[tool.isort]
169172
profile = "black"
170173
combine_as_imports = true

0 commit comments

Comments
 (0)