Skip to content

Commit b6dbd5c

Browse files
committed
upgrade API to torchlambertw=0.0.3; 10x faster MLE
1 parent a24d00d commit b6dbd5c

File tree

7 files changed

+29
-13
lines changed

7 files changed

+29
-13
lines changed

pylambertw/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
"""Init for module"""
2+
from ._version import __version__
3+
4+
__all__ = ["__version__"]

pylambertw/_version.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
"""Version."""
2+
3+
__version__ = "0.0.2"

pylambertw/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,12 +158,12 @@ def __repr__(self):
158158
def tau(self):
159159
"""Converts Theta (distribution dependent) to Tau (transformation only)."""
160160

161-
distr_constr = lwd.get_distribution_constructor(self.distribution_name)
161+
distr_constr = lwd.utils.get_distribution_constructor(self.distribution_name)
162162
distr = distr_constr(**self.beta)
163163

164164
return Tau(
165165
loc=distr.mean.numpy()
166-
if lwd.is_location_family(self.distribution_name)
166+
if lwd.utils.is_location_family(self.distribution_name)
167167
else 0.0,
168168
scale=distr.stddev.numpy(),
169169
lambertw_params=self.lambertw_params,

pylambertw/mle.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def __init__(
5050
self.distribution_name = distribution_name
5151
self.distribution_constructor = (
5252
distribution_constructor
53-
or lwd.get_distribution_constructor(self.distribution_name)
53+
or lwd.utils.get_distribution_constructor(self.distribution_name)
5454
)
5555
self.lambertw_type = base.LambertWType(lambertw_type)
5656
self.max_iter = max_iter
@@ -73,14 +73,14 @@ def _initialize_params(self, data: np.ndarray):
7373
if self.lambertw_type == base.LambertWType.H:
7474
self.igmm = igmm.IGMM(
7575
lambertw_type=self.lambertw_type,
76-
location_family=lwd.is_location_family(self.distribution_name),
76+
location_family=lwd.utils.is_location_family(self.distribution_name),
7777
)
7878
self.igmm.fit(data)
7979
x_init = self.igmm.transform(data)
8080

8181
lambertw_params_init = self.igmm.tau.lambertw_params
8282
else:
83-
if lwd.is_location_family(self.distribution_name):
83+
if lwd.utils.is_location_family(self.distribution_name):
8484
# Default to Normal distriubtion for location family.
8585
params_data = ud.estimate_params(data, "Normal")
8686
loc_init = params_data["loc"]
@@ -92,7 +92,7 @@ def _initialize_params(self, data: np.ndarray):
9292
scale_init = 1.0 / params_data["rate"]
9393

9494
z_init = (data - loc_init) / scale_init
95-
if lwd.is_location_family(self.distribution_name):
95+
if lwd.utils.is_location_family(self.distribution_name):
9696
gamma_init = igmm.gamma_taylor(z_init)
9797
else:
9898
gamma_init = 0.01

pylambertw/tests/test_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,6 @@ def test_estimate_params(dist_name):
2525
rng = np.random.RandomState(42)
2626
x = rng.normal(100)
2727
params = ud.estimate_params(x, dist_name)
28-
constr = lwd.get_distribution_constructor(dist_name)
29-
param_names = lwd.get_distribution_args(constr)
28+
constr = lwd.utils.get_distribution_constructor(dist_name)
29+
param_names = lwd.utils.get_distribution_args(constr)
3030
assert set(params.keys()) == set(param_names)

pylambertw/utils/distributions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ def torch_sigmoid(x: torch.tensor) -> torch.tensor:
3636
def get_params_activations(distribution_name: str) -> Dict[str, Callable]:
3737
"""Get activation functions for each distribution parameters."""
3838
assert isinstance(distribution_name, str)
39-
distr_constr = lwd.get_distribution_constructor(distribution_name)
40-
param_names = lwd.get_distribution_args(distr_constr)
39+
distr_constr = lwd.utils.get_distribution_constructor(distribution_name)
40+
param_names = lwd.utils.get_distribution_args(distr_constr)
4141

4242
act_fns = {p: (torch_linear, linear_inverse) for p in param_names}
4343

setup.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,23 @@
11
from setuptools import find_packages, setup
2+
import re
3+
4+
_VERSION_FILE = "pylambertw/_version.py"
5+
verstrline = open(_VERSION_FILE, "rt").read()
6+
_VERSION = r"^__version__ = ['\"]([^'\"]*)['\"]"
7+
mo = re.search(_VERSION, verstrline, re.M)
8+
if mo:
9+
verstr = mo.group(1)
10+
else:
11+
raise RuntimeError("Unable to find version string in %s." % (_VERSION_FILE,))
212

313
pkg_descr = """
414
Python implementation of the Lambert W x F framework for analyzing skewed, heavy-tailed distribution
515
with an sklearn interface and torch based maximum likelihood estimation (MLE).
616
"""
717

8-
918
setup(
1019
name="pylambertw",
11-
version="0.0.1",
20+
version=verstr,
1221
url="https://github.com/gmgeorg/pylambertw.git",
1322
author="Georg M. Goerg",
1423
author_email="[email protected]",
@@ -25,6 +34,6 @@
2534
"tqdm>=4.46.1",
2635
"dataclasses>=0.6",
2736
"scikit-learn>=1.0.1",
28-
"torchlambertw @ git+ssh://[email protected]/gmgeorg/torchlambertw.git#egg=torchlambertw-0.0.1",
37+
"torchlambertw @ git+ssh://[email protected]/gmgeorg/torchlambertw.git#egg=torchlambertw-0.0.3",
2938
],
3039
)

0 commit comments

Comments
 (0)