Skip to content
This repository was archived by the owner on Aug 8, 2024. It is now read-only.

Commit e54cd6e

Browse files
committed
correct arg name
Signed-off-by: Nathaniel Starkman (@nstarman) <[email protected]>
1 parent 690a3f0 commit e54cd6e

File tree

3 files changed

+45
-41
lines changed

3 files changed

+45
-41
lines changed

sample_scf/base.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,15 @@ def __init__(
4545
self,
4646
potential: SCFPotential,
4747
momtype: int = 1,
48-
a: float = None,
49-
b: float = None,
48+
a: T.Optional[float] = None,
49+
b: T.Optional[float] = None,
5050
xtol: float = 1e-14,
5151
badvalue: T.Optional[float] = None,
52-
name: str = None,
53-
longname: str = None,
54-
shapes: tuple = None,
55-
extradoc: str = None,
56-
seed: int = None,
52+
name: T.Optional[str] = None,
53+
longname: T.Optional[str] = None,
54+
shapes: T.Optional[T.Tuple[int, ...]] = None,
55+
extradoc: T.Optional[str] = None,
56+
seed: T.Optional[int] = None,
5757
):
5858
super().__init__(
5959
momtype=momtype,

sample_scf/core.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -96,25 +96,25 @@ class SCFSampler(SCFSamplerBase): # metaclass=SCFSamplerSwitch
9696

9797
def __init__(
9898
self,
99-
pot: SCFPotential,
99+
potential: SCFPotential,
100100
method: T.Union[T.Literal["interp", "exact"], MethodsMapping],
101101
**kwargs: T.Any
102102
) -> None:
103-
super().__init__(pot)
103+
super().__init__(potential)
104104

105105
if isinstance(method, Mapping):
106106
sampler = None
107-
rsampler = method["r"](pot, **kwargs)
108-
thetasampler = method["theta"](pot, **kwargs)
109-
phisampler = method["phi"](pot, **kwargs)
107+
rsampler = method["r"](potential, **kwargs)
108+
thetasampler = method["theta"](potential, **kwargs)
109+
phisampler = method["phi"](potential, **kwargs)
110110
else:
111-
sampler_cls: rv_potential
111+
sampler_cls: T.Type[SCFSamplerBase]
112112
if method == "interp":
113113
sampler_cls = SCFSamplerIntrp
114114
elif method == "exact":
115115
sampler_cls = SCFSamplerExact
116116

117-
sampler = sampler_cls(pot, **kwargs)
117+
sampler = sampler_cls(potential, **kwargs)
118118
rsampler = sampler.rsampler
119119
thetasampler = sampler.thetasampler
120120
phisampler = sampler.phisampler

sample_scf/sample_exact.py

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,11 @@ class SCFSampler(SCFSamplerBase):
5555
5656
"""
5757

58-
def __init__(self, pot: SCFPotential, **kw: T.Any) -> None:
59-
self._rsampler = SCFRSampler(pot)
58+
def __init__(self, potential: SCFPotential, **kw: T.Any) -> None:
59+
self._rsampler = SCFRSampler(potential)
6060
# not fixed r, theta. slower!
61-
self._thetasampler = SCFThetaSampler_of_r(pot, r=None)
62-
self._phisampler = SCFPhiSampler_of_rtheta(pot, r=None, theta=None)
61+
self._thetasampler = SCFThetaSampler_of_r(potential) # r=None
62+
self._phisampler = SCFPhiSampler_of_rtheta(potential) # r=None, theta=None
6363

6464
# /def
6565

@@ -89,7 +89,7 @@ def __init__(self, potential: SCFPotential, **kw: T.Any) -> None:
8989
# /def
9090

9191
def _cdf(self, r: npt.ArrayLike, *args: T.Any, **kw: T.Any) -> NDArray64:
92-
mass: NDArray64 = self._pot._mass(r)
92+
mass: NDArray64 = self._potential._mass(r)
9393
# (self._scfmass(zeta) - self._mi) / (self._mf - self._mi)
9494
# TODO! is this normalization even necessary?
9595
return mass
@@ -139,7 +139,7 @@ def Qls(self, r: float) -> NDArray64:
139139
Ql : ndarray
140140
141141
"""
142-
Qls: NDArray64 = thetaQls(self._pot, r)
142+
Qls: NDArray64 = thetaQls(self._potential, r)
143143
return Qls
144144

145145
# /def
@@ -212,9 +212,10 @@ def _cdf(self, theta: npt.ArrayLike, *args: T.Any) -> NDArray64:
212212

213213

214214
class SCFThetaSampler_of_r(SCFThetaSamplerBase):
215-
def _cdf(self, theta: NDArray64, *args: T.Any, r: float) -> NDArray64:
215+
216+
def _cdf(self, theta: NDArray64, *args: T.Any, r: T.Optional[float] = None) -> NDArray64:
216217
x = x_of_theta(theta)
217-
Qlsatr = self.Qls(r)
218+
Qlsatr = self.Qls(T.cast(float, r))
218219

219220
# l = 0
220221
term0 = (1.0 + x) / 2.0
@@ -276,7 +277,26 @@ def __init__(self, potential: SCFPotential, **kw: T.Any) -> None:
276277

277278
# @functools.lru_cache()
278279
def RSms(self, r: float, theta: float) -> T.Tuple[NDArray64, NDArray64]:
279-
return phiRSms(self._pot, r, theta)
280+
return phiRSms(self._potential, r, theta)
281+
282+
# /def
283+
284+
def _cdf(self, phi: NDArray64, *args: T.Any, **kw: T.Any) -> NDArray64:
285+
Rm = self._Rm
286+
Sm = self._Sm
287+
288+
# l = 0
289+
term0: NDArray64 = phi / (2 * np.pi)
290+
291+
# l = 1+
292+
factor = 1 / Rm[0] # R0
293+
ms = np.arange(1, Rm.shape[1] if len(Rm.shape) > 1 else 2)
294+
term1p = np.sum(
295+
(Rm[1:] * np.sin(ms * phi) + Sm[1:] * (1 - np.cos(ms * phi))) / (2 * np.pi * ms),
296+
)
297+
298+
cdf: NDArray64 = term0 + factor * term1p
299+
return cdf
280300

281301
# /def
282302

@@ -302,22 +322,6 @@ def __init__(self, potential: SCFPotential, r: float, theta: float, **kw: T.Any)
302322

303323
# /def
304324

305-
def _cdf(self, phi: NDArray64, *args: T.Any, **kw: T.Any) -> NDArray64:
306-
Rm, Sm = self._Rm, self._Sm
307-
308-
# l = 0
309-
term0: NDArray64 = phi / (2 * np.pi)
310-
311-
# l = 1+
312-
factor = 1 / Rm[0] # R0
313-
ms = np.arange(1, Rm.shape[1])
314-
term1p = np.sum(
315-
(Rm[1:] * np.sin(ms * phi) + Sm[1:] * (1 - np.cos(ms * phi))) / (2 * np.pi * ms),
316-
)
317-
318-
cdf: NDArray64 = term0 + factor * term1p
319-
return cdf
320-
321325
def cdf(self, phi: NDArray64, *args: T.Any, **kw: T.Any) -> NDArray64:
322326
return self._cdf(phi, *args, **kw)
323327

@@ -331,8 +335,8 @@ class SCFPhiSampler_of_rtheta(SCFPhiSamplerBase):
331335
_Rm: T.Optional[NDArray64]
332336
_Sm: T.Optional[NDArray64]
333337

334-
def _cdf(self, phi: npt.ArrayLike, *args: T.Any, r: float, theta: float) -> NDArray64:
335-
self._Rm, self._Sm = self.RSms(float(r), float(theta))
338+
def _cdf(self, phi: npt.ArrayLike, *args: T.Any, r: T.Optional[float]=None, theta: T.Optional[float]=None) -> NDArray64:
339+
self._Rm, self._Sm = self.RSms(T.cast(float, r), T.cast(float, theta))
336340
cdf: NDArray64 = super()._cdf(phi, *args)
337341
self._Rm, self._Sm = None, None
338342
return cdf

0 commit comments

Comments
 (0)