Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions src/gwkokab/analysis/multisource/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,33 +156,33 @@ def model_parameters(self) -> list[str]:
all_params_names = []
if ct == "spl":
all_params_names.extend([
"alpha_",
"beta_",
"delta_m1_",
"delta_m2_",
"m1min_",
"m2min_",
"mmax_",
"m1_alpha_",
"m1_delta_",
"m1_high_",
"m1_low_",
"m2_delta_",
"m2_low_",
])
if ct == "bpl":
all_params_names.extend([
"alpha1_",
"alpha2_",
"beta_",
"delta_m1_",
"delta_m2_",
"m1min_",
"m2min_",
"mbreak_",
"mmax_",
"m1_alpha1_",
"m1_alpha2_",
"m1_break_",
"m1_delta_",
"m1_high_",
"m1_low_",
"m2_delta_",
"m2_low_",
])
if ct == "gpl":
all_params_names.extend([
"beta_",
"loc_",
"mmax_",
"mmin_",
"scale_",
"m1_high_",
"m1_loc_",
"m1_low_",
"m1_scale_",
])
if ct == "gg":
all_params_names.extend([
Expand Down
29 changes: 19 additions & 10 deletions src/gwkokab/analysis/subpopulation/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,24 +151,33 @@ def model_parameters(self) -> list[str]:
for ct, count in component_types_and_count:
all_params_names = [
"beta_",
"delta_m2_",
"m2min_",
"m2_delta_",
"m2_low_",
]

if ct == "spl":
all_params_names.extend(["alpha_", "mmax_", "mmin_"])
all_params_names.extend([
"m1_alpha_",
"m1_low_",
"m1_high_",
])

if ct == "bpl":
all_params_names.extend([
"alpha1_",
"alpha2_",
"m1break_",
"m1max_",
"m1min_",
"m1_alpha1_",
"m1_alpha2_",
"m1_break_",
"m1_high_",
"m1_low_",
])

if ct == "gpl":
all_params_names.extend(["m1_loc_", "m1_scale_", "m1_low_", "m1_high_"])
all_params_names.extend([
"m1_high_",
"m1_loc_",
"m1_low_",
"m1_scale_",
])

if self.use_spin_magnitude_mixture:
all_params_names.extend([
Expand Down Expand Up @@ -319,7 +328,7 @@ def model_parameters(self) -> list[str]:
all_params.extend([(name + ct, count) for name in all_params_names])

extended_params = [
"delta_m1",
"m1_delta",
"log_rate",
"m1max",
"m1min",
Expand Down
58 changes: 29 additions & 29 deletions src/gwkokab/models/hybrids/_ncombination.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,11 @@ def create_broken_powerlaws(
params: Dict[str, Array],
validate_args: Optional[bool] = None,
) -> List[Distribution]:
alpha1_name = "alpha1_" + component_type
alpha2_name = "alpha2_" + component_type
mbreak_name = "m1break_" + component_type
mmax_name = "m1max_" + component_type
mmin_name = "m1min_" + component_type
alpha1_name = parameter_name + "_alpha1_" + component_type
alpha2_name = parameter_name + "_alpha2_" + component_type
mbreak_name = parameter_name + "_break_" + component_type
mmax_name = parameter_name + "_high_" + component_type
mmin_name = parameter_name + "_low_" + component_type
return [
BrokenPowerlaw(
alpha1=_get_parameter(params, f"{alpha1_name}_{i}"), # type: ignore
Expand Down Expand Up @@ -288,9 +288,9 @@ def create_powerlaws(
params: Dict[str, Array],
validate_args: Optional[bool] = None,
) -> List[Distribution]:
alpha_name = "alpha_" + component_type
mmax_name = "mmax_" + component_type
mmin_name = "mmin_" + component_type
alpha_name = parameter_name + "_alpha_" + component_type
mmax_name = parameter_name + "_high_" + component_type
mmin_name = parameter_name + "_low_" + component_type

return [
DoublyTruncatedPowerLaw(
Expand Down Expand Up @@ -468,15 +468,15 @@ def create_smoothed_broken_powerlaws_mass_ratio_powerlaw(
) -> List[Distribution]:
collection = []

alpha1_name = "alpha1_" + component_type
alpha2_name = "alpha2_" + component_type
alpha1_name = "m1_alpha1_" + component_type
alpha2_name = "m1_alpha2_" + component_type
beta_name = "beta_" + component_type
delta_m1_name = "delta_m1_" + component_type
delta_m2_name = "delta_m2_" + component_type
m1min_name = "m1min_" + component_type
m2min_name = "m2min_" + component_type
mbreak_name = "mbreak_" + component_type
mmax_name = "mmax_" + component_type
delta_m1_name = "m1_delta_" + component_type
delta_m2_name = "m2_delta_" + component_type
m1min_name = "m1_low_" + component_type
m2min_name = "m2_low_" + component_type
mbreak_name = "m1_break_" + component_type
mmax_name = "m1_high_" + component_type

for i in range(N):
suffix = f"_{i}"
Expand Down Expand Up @@ -554,11 +554,11 @@ def create_gaussian_primary_mass_ratio(
) -> List[Distribution]:
collection = []

loc_name = "loc_" + component_type
scale_name = "scale_" + component_type
loc_name = "m1_loc_" + component_type
scale_name = "m1_scale_" + component_type
beta_name = "beta_" + component_type
mmin_name = "mmin_" + component_type
mmax_name = "mmax_" + component_type
mmin_name = "m1_low_" + component_type
mmax_name = "m1_high_" + component_type

for i in range(N):
suffix = f"_{i}"
Expand Down Expand Up @@ -590,13 +590,13 @@ def create_smoothed_powerlaw_primary_mass_ratio(
) -> List[Distribution]:
collection = []

alpha_name = "alpha_" + component_type
alpha_name = "m1_alpha_" + component_type
beta_name = "beta_" + component_type
delta_m1_name = "delta_m1_" + component_type
delta_m2_name = "delta_m2_" + component_type
m1min_name = "m1min_" + component_type
m2min_name = "m2min_" + component_type
mmax_name = "mmax_" + component_type
delta_m1_name = "m1_delta_" + component_type
delta_m2_name = "m2_delta_" + component_type
m1min_name = "m1_low_" + component_type
m2min_name = "m2_low_" + component_type
mmax_name = "m1_high_" + component_type

for i in range(N):
suffix = f"_{i}"
Expand Down Expand Up @@ -631,9 +631,9 @@ def create_generic_smoothed_powerlaw_mass_ratio(
) -> List[Distribution]:

beta_name = "beta_" + component_type
delta_m1_name = "delta_m1"
delta_m2_name = "delta_m2_" + component_type
m2min_name = "m2min_" + component_type
delta_m1_name = "m1_delta"
delta_m2_name = "m2_delta_" + component_type
m2min_name = "m2_low_" + component_type

delta_m1 = _get_parameter(params, delta_m1_name)

Expand Down
6 changes: 3 additions & 3 deletions src/gwkokab/models/hybrids/_subpopulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def _build_component_distributions(
if component_type == "spl":
_mass_distributions = create_powerlaws(
N=N,
parameter_name=None, # type: ignore # unused parameter
parameter_name="m1",
component_type=component_type,
params=params,
validate_args=validate_args,
Expand All @@ -137,7 +137,7 @@ def _build_component_distributions(
if component_type == "bpl":
_mass_distributions = create_broken_powerlaws(
N=N,
parameter_name=None, # type: ignore # unused parameter
parameter_name="m1",
component_type=component_type,
params=params,
validate_args=validate_args,
Expand Down Expand Up @@ -244,7 +244,7 @@ def SubPopulationModel(
_lambdas.append(1.0 - sum(_lambdas))
lambdas = jnp.stack(_lambdas, axis=-1)

delta_m1 = params.pop("delta_m1")
delta_m1 = params.pop("m1_delta")
log_rate = params.pop("log_rate")
m1max = params.pop("m1max")
m1min = params.pop("m1min")
Expand Down
Loading