Skip to content

Commit 018964e

Browse files
committed
Generate 10 principal components for PCA.
1 parent b82d9b5 commit 018964e

3 files changed

Lines changed: 31 additions & 35 deletions

File tree

conf/base/parameters_experimental.yml

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ pipelines:
2828

2929
reducer:
3030
name: PCA
31-
n_components: 2
31+
n_components: 10
3232

3333
scaler:
3434
name: SparsityAwareScaler
@@ -198,7 +198,7 @@ pipelines:
198198
# <<: *knn5d_pacmap_besthdbscanflat
199199
# reducer:
200200
# name: PCA
201-
# n_components: 2
201+
# n_components: 10
202202
# random_state: ${globals:random_state}
203203

204204
knn5d_pacmap_bestkmeans: &knn5d_pacmap_bestkmeans
@@ -236,7 +236,7 @@ pipelines:
236236
<<: *knn5d_pacmap_bestkmeans
237237
reducer:
238238
name: PCA
239-
n_components: 2
239+
n_components: 10
240240
random_state: ${globals:random_state}
241241

242242
mean_umap_bestkmeans:

src/kedro_polis_classic/pipelines/experimental/nodes.py

Lines changed: 22 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -321,9 +321,9 @@ def _create_scatter_plot(
321321
# Update axis labels
322322
fig.update_layout(
323323
scene=dict(
324-
xaxis_title=f"{str(x_col).upper()} Component",
325-
yaxis_title=f"{str(y_col).upper()} Component",
326-
zaxis_title=f"{str(z_col).upper()} Component",
324+
xaxis_title=f"{str(x_col).upper()}",
325+
yaxis_title=f"{str(y_col).upper()}",
326+
zaxis_title=f"{str(z_col).upper()}",
327327
),
328328
width=800,
329329
height=600,
@@ -360,8 +360,8 @@ def _create_scatter_plot(
360360

361361
# Update axis labels and layout
362362
fig.update_layout(
363-
xaxis_title=f"{str(x_col).upper()} Component",
364-
yaxis_title=f"{str(y_col).upper()} Component",
363+
xaxis_title=f"{str(x_col).upper()}",
364+
yaxis_title=f"{str(y_col).upper()}",
365365
width=800,
366366
height=600,
367367
plot_bgcolor="white",
@@ -408,10 +408,7 @@ def create_scatter_plot(
408408
if isinstance(filter_output, np.ndarray):
409409
# Create generic column names based on dimensions
410410
n_components = filter_output.shape[1] if len(filter_output.shape) > 1 else 1
411-
if n_components <= 3:
412-
column_names = ["x", "y", "z"][:n_components]
413-
else:
414-
column_names = [f"PC{i + 1}" for i in range(n_components)]
411+
column_names = [f"comp{i + 1}" for i in range(n_components)]
415412

416413
# Create DataFrame with actual participant IDs as index
417414
data = pd.DataFrame(
@@ -424,6 +421,10 @@ def create_scatter_plot(
424421
data = filter_output.copy()
425422
data.index = included_participant_ids
426423

424+
# For plotting, only use the first 2 components (even if more are available)
425+
if len(data.columns) > 2:
426+
data = data.iloc[:, :2]
427+
427428
# Convert cluster labels to pandas Series of strings for categorical coloring
428429
# Make sure the cluster labels have the same index as the data DataFrame
429430
if isinstance(clusterer_output, np.ndarray):
@@ -488,10 +489,7 @@ def create_scatter_plot_by_participant_id(
488489
if isinstance(filter_output, np.ndarray):
489490
# Create generic column names based on dimensions
490491
n_components = filter_output.shape[1] if len(filter_output.shape) > 1 else 1
491-
if n_components <= 3:
492-
column_names = ["x", "y", "z"][:n_components]
493-
else:
494-
column_names = [f"PC{i + 1}" for i in range(n_components)]
492+
column_names = [f"comp{i + 1}" for i in range(n_components)]
495493

496494
# Create DataFrame with actual participant IDs as index
497495
data = pd.DataFrame(
@@ -504,6 +502,10 @@ def create_scatter_plot_by_participant_id(
504502
data = filter_output.copy()
505503
data.index = included_participant_ids
506504

505+
# For plotting, only use the first 2 components (even if more are available)
506+
if len(data.columns) > 2:
507+
data = data.iloc[:, :2]
508+
507509
# Get participant IDs as numeric values for continuous color scale
508510
participant_ids = pd.Series(data.index, index=data.index)
509511

@@ -552,10 +554,7 @@ def create_scatter_plot_by_vote_proportions(
552554
if isinstance(filter_output, np.ndarray):
553555
# Create generic column names based on dimensions
554556
n_components = filter_output.shape[1] if len(filter_output.shape) > 1 else 1
555-
if n_components <= 3:
556-
column_names = ["x", "y", "z"][:n_components]
557-
else:
558-
column_names = [f"PC{i + 1}" for i in range(n_components)]
557+
column_names = [f"comp{i + 1}" for i in range(n_components)]
559558

560559
# Create DataFrame with actual participant IDs as index
561560
data = pd.DataFrame(
@@ -568,6 +567,10 @@ def create_scatter_plot_by_vote_proportions(
568567
data = filter_output.copy()
569568
data.index = included_participant_ids
570569

570+
# For plotting, only use the first 2 components (even if more are available)
571+
if len(data.columns) > 2:
572+
data = data.iloc[:, :2]
573+
571574
# Calculate total number of votes cast by each included participant
572575
# Vote values: 1 = agree, -1 = disagree, 0 = pass, NaN = no vote
573576
# Count all non-NaN values (any vote cast) for the included participants only
@@ -689,11 +692,8 @@ def save_projections_json(
689692
# If it's a DataFrame, get the values
690693
X_clustered = filter_output.values
691694

692-
# Ensure we have 2D coordinates (take first 2 dimensions if more)
693-
if X_clustered.shape[1] > 2:
694-
X_clustered = X_clustered[:, :2]
695-
696-
# Create the format: [[participant_id, [x, y]], ...]
695+
# Save all components to disk (don't truncate to 2D)
696+
# Create the format: [[participant_id, [comp1, comp2, comp3, ...]], ...]
697697
X_with_ids = []
698698
for i, participant_id in enumerate(included_participant_ids):
699699
coords = X_clustered[i].tolist()

src/kedro_polis_classic/pipelines/polis_legacy/nodes.py

Lines changed: 6 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -145,9 +145,9 @@ def _create_scatter_plot(
145145
fig.update_layout(
146146
title=title,
147147
scene=dict(
148-
xaxis_title=f"{x_col.upper()} Component",
149-
yaxis_title=f"{y_col.upper()} Component",
150-
zaxis_title=f"{z_col.upper()} Component",
148+
xaxis_title=f"{str(x_col).upper()}",
149+
yaxis_title=f"{str(y_col).upper()}",
150+
zaxis_title=f"{str(z_col).upper()}",
151151
),
152152
width=800,
153153
height=600,
@@ -174,8 +174,8 @@ def _create_scatter_plot(
174174

175175
fig.update_layout(
176176
title=title,
177-
xaxis_title=f"{x_col.upper()} Component",
178-
yaxis_title=f"{y_col.upper()} Component",
177+
xaxis_title=f"{str(x_col).upper()}",
178+
yaxis_title=f"{str(y_col).upper()}",
179179
width=800,
180180
height=600,
181181
plot_bgcolor="white",
@@ -289,11 +289,7 @@ def reduce_with_pca(
289289
components = pca.fit_transform(imputed_vote_matrix)
290290

291291
# Create column names based on number of components
292-
DIMENSION_COLS = ["x", "y", "z"]
293-
if n_components <= 3:
294-
column_names = DIMENSION_COLS[:n_components]
295-
else:
296-
column_names = [f"PC{i + 1}" for i in range(n_components)]
292+
column_names = [f"comp{i + 1}" for i in range(n_components)]
297293

298294
return pd.DataFrame(
299295
components, index=imputed_vote_matrix.index, columns=pd.Index(column_names)

0 commit comments

Comments
 (0)