Skip to content

Commit 3b349f0

Browse files
committed
elements by diameter and discovery
1 parent d214be8 commit 3b349f0

File tree

1 file changed

+136
-83
lines changed

1 file changed

+136
-83
lines changed

src/adam_impact_study/analysis/plots.py

Lines changed: 136 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -2458,13 +2458,21 @@ def make_analysis_plots(
24582458
window_results: WindowResult,
24592459
out_dir: str,
24602460
) -> None:
2461+
2462+
for diameter in summary.orbit.diameter.unique().to_pylist():
2463+
fig, ax = plot_observed_vs_unobserved_elements(summary, diameter=diameter)
2464+
fig.savefig(
2465+
os.path.join(out_dir, f"observed_vs_unobserved_elements_{diameter}km.jpg"),
2466+
bbox_inches="tight",
2467+
dpi=200,
2468+
)
2469+
plt.close(fig)
2470+
24612471
fig, ax = plot_discovered_by_diameter_impact_period(
24622472
summary,
24632473
period="5year",
2464-
# max_impact_time=Timestamp.from_iso8601(["2070-01-01"])
24652474
)
24662475
fig.savefig(
2467-
# os.path.join(out_dir, "discovered_by_diameter_5year_2070.jpg"),
24682476
os.path.join(out_dir, "discovered_by_diameter_5year.jpg"),
24692477
bbox_inches="tight",
24702478
dpi=200,
@@ -2492,7 +2500,8 @@ def make_analysis_plots(
24922500
plt.close(fig)
24932501

24942502
fig, ax = plot_max_impact_probability_histograms_by_diameter_decade(
2495-
summary, include_undiscovered=True
2503+
summary, include_undiscovered=True,
2504+
x_log_scale=True
24962505
)
24972506
fig.savefig(
24982507
os.path.join(
@@ -2501,13 +2510,13 @@ def make_analysis_plots(
25012510
bbox_inches="tight",
25022511
dpi=200,
25032512
)
2504-
plt.close(fig)
25052513
logger.info(
25062514
"Generated max impact probability histograms by diameter decade plot (all)"
25072515
)
25082516

25092517
fig, ax = plot_max_impact_probability_histograms_by_diameter_decade(
2510-
summary, include_undiscovered=False
2518+
summary, include_undiscovered=False,
2519+
x_log_scale=True
25112520
)
25122521
fig.savefig(
25132522
os.path.join(
@@ -2670,6 +2679,8 @@ def make_analysis_plots(
26702679
def plot_max_impact_probability_histograms_by_diameter_decade(
26712680
summary: ImpactorResultSummary,
26722681
include_undiscovered: bool = False,
2682+
y_log_scale: bool = False,
2683+
x_log_scale: bool = False,
26732684
) -> Tuple[plt.Figure, plt.Axes]:
26742685
"""
26752686
Plot a grid of histograms showing the maximum impact probability distribution
@@ -2683,6 +2694,11 @@ def plot_max_impact_probability_histograms_by_diameter_decade(
26832694
include_undiscovered : bool, optional
26842695
Whether to include undiscovered objects in the histograms. If False,
26852696
only discovered objects are included. Default is False.
2697+
y_log_scale : bool, optional
2698+
Whether to use a log scale for the y-axis. Default is False.
2699+
x_log_scale : bool, optional
2700+
Whether to use a log scale for the x-axis. Default is False.
2701+
26862702
26872703
Returns
26882704
-------
@@ -2744,14 +2760,29 @@ def plot_max_impact_probability_histograms_by_diameter_decade(
27442760
continue
27452761

27462762
# Create histogram
2747-
ax.hist(
2748-
max_impact_probs,
2749-
range=(0, 1),
2750-
bins=20,
2751-
# bins="fd",
2752-
color=colors[i],
2753-
alpha=0.7,
2754-
)
2763+
if x_log_scale:
2764+
bins = [0.00001, 0.0001, 0.001, 0.01, 0.1, 1]
2765+
ax.hist(
2766+
np.where(max_impact_probs == 0, 0.00001, max_impact_probs),
2767+
bins=bins,
2768+
color=colors[i],
2769+
alpha=0.7,
2770+
)
2771+
ax.set_xscale('log')
2772+
# Set the x-axis ticks to match our bin edges
2773+
ax.set_xticks(bins)
2774+
# Format the tick labels to be more readable
2775+
ax.set_xticklabels([f"{x:.0e}" for x in bins])
2776+
# Rotate the tick labels
2777+
ax.tick_params(axis='x', rotation=45)
2778+
else:
2779+
ax.hist(
2780+
max_impact_probs,
2781+
range=(0, 1),
2782+
bins=20,
2783+
color=colors[i],
2784+
alpha=0.7,
2785+
)
27552786

27562787
# Get the y-limit
27572788
y_limit = ax.get_ylim()[1]
@@ -2766,23 +2797,29 @@ def plot_max_impact_probability_histograms_by_diameter_decade(
27662797

27672798
# Only add y-axis label to leftmost column
27682799
if j == 0:
2769-
ax.set_ylabel("Count (log scale)")
2800+
if y_log_scale:
2801+
ax.set_ylabel("Count (log scale)")
2802+
else:
2803+
ax.set_ylabel("Count")
27702804

27712805
# Collect all the y-limits and set them all to be the same max value
27722806
axes_flat = axes.flatten()
27732807
for ax in axes_flat:
27742808
ax.set_ylim(0.1, max_y * 1.1) # Start at 0.1 to avoid log(0) issues
27752809

2776-
# Make the y axis log scale with proper tick formatting
2777-
for ax in axes_flat:
2778-
ax.set_yscale("log")
2779-
# Set major ticks at powers of 10
2780-
ax.yaxis.set_major_locator(plt.LogLocator(base=10, numticks=5))
2781-
# Set minor ticks between major ticks
2782-
ax.yaxis.set_minor_locator(plt.LogLocator(base=10, subs=np.arange(2, 10) * 0.1, numticks=5))
2783-
# Format the tick labels to be more readable
2784-
ax.yaxis.set_major_formatter(plt.ScalarFormatter())
2785-
ax.grid(True, which='major', alpha=0.3, linestyle='--')
2810+
if y_log_scale:
2811+
# Make the y axis log scale with proper tick formatting
2812+
for ax in axes_flat:
2813+
ax.set_yscale("log")
2814+
# Set major ticks at powers of 10
2815+
ax.yaxis.set_major_locator(plt.LogLocator(base=10, numticks=5))
2816+
# Set minor ticks between major ticks
2817+
ax.yaxis.set_minor_locator(
2818+
plt.LogLocator(base=10, subs=np.arange(2, 10) * 0.1, numticks=5)
2819+
)
2820+
# Format the tick labels to be more readable
2821+
ax.yaxis.set_major_formatter(plt.ScalarFormatter())
2822+
ax.grid(True, which="major", alpha=0.3, linestyle="--")
27862823

27872824
# Add a single legend for all plots
27882825
legend_elements = [
@@ -3084,7 +3121,7 @@ def plot_observed_vs_unobserved_elements(
30843121
- Discovered (blue)
30853122
- Observed but not discovered (yellow)
30863123
- Unobserved (red)
3087-
3124+
30883125
Parameters
30893126
----------
30903127
summary : ImpactorResultSummary
@@ -3097,62 +3134,75 @@ def plot_observed_vs_unobserved_elements(
30973134
Tuple[plt.Figure, plt.Axes]
30983135
The figure and axes objects for the plot.
30993136
"""
3137+
summary = summary.apply_mask(summary.complete())
31003138
# Filter to only include the given diameter
3101-
summary = summary.apply_mask(pc.equal(summary.orbit.diameter, diameter))
31023139

3103-
if len(summary) == 0:
3104-
print(f"No data found for diameter {diameter} km.")
3105-
return
3140+
summary = summary.apply_mask(pc.equal(summary.orbit.diameter, diameter))
31063141

31073142
orbits_at_diameter = summary.apply_mask(pc.equal(summary.orbit.diameter, diameter))
3108-
print("num orbits:", len(orbits_at_diameter))
3109-
3143+
31103144
# Get Keplerian coordinates
31113145
kep_coordinates = summary.orbit.coordinates.to_keplerian()
31123146
a_au = kep_coordinates.a.to_numpy(zero_copy_only=False)
31133147
i_deg = kep_coordinates.i.to_numpy(zero_copy_only=False)
31143148
e = kep_coordinates.e.to_numpy(zero_copy_only=False)
31153149

3116-
#print discovered objedts by those with a non null discovery time
3117-
discovered_objects = summary.apply_mask(pc.invert(pc.is_null(summary.discovery_time.days)))
3118-
print(discovered_objects)
3119-
#do the same for undiscovered by the null discovery time
3120-
undiscovered_objects = summary.apply_mask(pc.is_null(summary.discovery_time.days))
3121-
print(undiscovered_objects)
3122-
3123-
#print the object names that are observed but not discovered
3150+
# print discovered objedts by those with a non null discovery time
3151+
discovered_objects_mask = pc.invert(
3152+
pc.is_null(summary.discovery_time.days)
3153+
).to_numpy(zero_copy_only=False)
3154+
# do the same for undiscovered by the null discovery time
3155+
observed_not_discovered_objects_mask = pc.and_(
3156+
pc.is_null(summary.discovery_time.days),
3157+
pc.greater(summary.observations, 0),
3158+
).to_numpy(zero_copy_only=False)
3159+
unobserved_objects_mask = pc.and_(
3160+
pc.is_null(summary.discovery_time.days), pc.equal(summary.observations, 0)
3161+
).to_numpy(zero_copy_only=False)
31243162

3125-
assert len(orbits_discovered) + len(orbits_observed_not_discovered) + len(orbits_unobserved) == len(orbits_at_diameter)
3163+
assert np.sum(discovered_objects_mask) + np.sum(
3164+
observed_not_discovered_objects_mask
3165+
) + np.sum(unobserved_objects_mask) == len(orbits_at_diameter)
31263166
# Create the plots
31273167
fig, axes = plt.subplots(1, 2, dpi=200, figsize=(18, 7))
31283168

3169+
scatter_dot_size = 10
3170+
scatter_dot_alpha = 0.3
3171+
# Uses contrasting colors that don't include yellow
3172+
# colors = plt.cm.coolwarm(np.linspace(0, 1, 3))
3173+
# colors = plt.cm.viridis(np.linspace(0, 1, 3))
3174+
colors = ["blue", "green", "red"]
3175+
31293176
# --- Plot a vs i ---
31303177
# Plot discovered (blue) first
31313178
axes[0].scatter(
3132-
a_au[discovered_mask],
3133-
i_deg[discovered_mask],
3134-
c='blue',
3135-
alpha=0.6,
3136-
label='Discovered',
3137-
s=20
3179+
a_au[discovered_objects_mask],
3180+
i_deg[discovered_objects_mask],
3181+
c=colors[0],
3182+
alpha=scatter_dot_alpha,
3183+
linewidths=0,
3184+
label="Discovered",
3185+
s=scatter_dot_size,
31383186
)
31393187
# Plot observed but not discovered (yellow) second
31403188
axes[0].scatter(
3141-
a_au[observed_not_discovered_mask],
3142-
i_deg[observed_not_discovered_mask],
3143-
c='yellow',
3144-
alpha=0.6,
3145-
label='Observed (Not Discovered)',
3146-
s=20
3189+
a_au[observed_not_discovered_objects_mask],
3190+
i_deg[observed_not_discovered_objects_mask],
3191+
c=colors[1],
3192+
alpha=scatter_dot_alpha,
3193+
linewidths=0,
3194+
label="Observed (Not Discovered)",
3195+
s=scatter_dot_size,
31473196
)
31483197
# Plot unobserved (red) last
31493198
axes[0].scatter(
3150-
a_au[unobserved_mask],
3151-
i_deg[unobserved_mask],
3152-
c='red',
3153-
alpha=0.6,
3154-
label='Unobserved',
3155-
s=20
3199+
a_au[unobserved_objects_mask],
3200+
i_deg[unobserved_objects_mask],
3201+
c=colors[2],
3202+
alpha=scatter_dot_alpha,
3203+
linewidths=0,
3204+
label="Unobserved",
3205+
s=scatter_dot_size,
31563206
)
31573207

31583208
axes[0].set_xlabel("Semimajor Axis (a) [AU]")
@@ -3164,30 +3214,33 @@ def plot_observed_vs_unobserved_elements(
31643214
# --- Plot a vs e ---
31653215
# Plot discovered (blue) first
31663216
axes[1].scatter(
3167-
a_au[discovered_mask],
3168-
e[discovered_mask],
3169-
c='blue',
3170-
alpha=0.6,
3171-
label='Discovered',
3172-
s=20
3217+
a_au[discovered_objects_mask],
3218+
e[discovered_objects_mask],
3219+
c=colors[0],
3220+
alpha=scatter_dot_alpha,
3221+
linewidths=0,
3222+
label="Discovered",
3223+
s=scatter_dot_size,
31733224
)
31743225
# Plot observed but not discovered (yellow) second
31753226
axes[1].scatter(
3176-
a_au[observed_not_discovered_mask],
3177-
e[observed_not_discovered_mask],
3178-
c='yellow',
3179-
alpha=0.6,
3180-
label='Observed (Not Discovered)',
3181-
s=20
3227+
a_au[observed_not_discovered_objects_mask],
3228+
e[observed_not_discovered_objects_mask],
3229+
c=colors[1],
3230+
alpha=scatter_dot_alpha,
3231+
linewidths=0,
3232+
label="Observed (Not Discovered)",
3233+
s=scatter_dot_size,
31823234
)
31833235
# Plot unobserved (red) last
31843236
axes[1].scatter(
3185-
a_au[unobserved_mask],
3186-
e[unobserved_mask],
3187-
c='red',
3188-
alpha=0.6,
3189-
label='Unobserved',
3190-
s=20
3237+
a_au[unobserved_objects_mask],
3238+
e[unobserved_objects_mask],
3239+
c=colors[2],
3240+
alpha=scatter_dot_alpha,
3241+
linewidths=0,
3242+
label="Unobserved",
3243+
s=scatter_dot_size,
31913244
)
31923245

31933246
axes[1].set_xlabel("Semimajor Axis (a) [AU]")
@@ -3197,16 +3250,16 @@ def plot_observed_vs_unobserved_elements(
31973250
axes[1].grid(True, alpha=0.3)
31983251

31993252
# Add overall title with statistics
3200-
n_total = len(observed_mask)
3201-
n_discovered = discovered_mask.sum()
3202-
n_observed_not_discovered = observed_not_discovered_mask.sum()
3203-
n_unobserved = unobserved_mask.sum()
3253+
n_total = len(orbits_at_diameter)
3254+
n_discovered = np.sum(discovered_objects_mask)
3255+
n_observed_not_discovered = np.sum(observed_not_discovered_objects_mask)
3256+
n_unobserved = np.sum(unobserved_objects_mask)
32043257
assert n_total == n_discovered + n_observed_not_discovered + n_unobserved
3205-
3258+
32063259
percent_discovered = (n_discovered / n_total) * 100
32073260
percent_observed_not_discovered = (n_observed_not_discovered / n_total) * 100
32083261
percent_unobserved = (n_unobserved / n_total) * 100
3209-
3262+
32103263
fig.suptitle(
32113264
f"Distribution of Objects (Diameter: {diameter} km)\n"
32123265
f"Total Objects: {n_total}, "
@@ -3216,4 +3269,4 @@ def plot_observed_vs_unobserved_elements(
32163269
)
32173270

32183271
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
3219-
return fig, axes
3272+
return fig, axes

0 commit comments

Comments
 (0)