Skip to content

Commit 54494e4

Browse files
variably choose which dimensions to plot
1 parent da40297 commit 54494e4

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

scikitplot/decomposition.py

+12-11
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
properties shared by scikit-learn estimators. The specific requirements are
66
documented per function.
77
"""
8-
from __future__ import absolute_import, division, print_function, \
9-
unicode_literals
8+
from __future__ import (absolute_import, division, print_function,
9+
unicode_literals)
1010

1111
import matplotlib.pyplot as plt
1212
import numpy as np
@@ -95,6 +95,7 @@ def plot_pca_component_variance(clf, title='PCA Component Explained Variances',
9595

9696

9797
def plot_pca_2d_projection(clf, X, y, title='PCA 2-D Projection',
98+
dimensions=[0, 1],
9899
biplot=False, feature_labels=None,
99100
ax=None, figsize=None, cmap='Spectral',
100101
title_fontsize="large", text_fontsize="medium",
@@ -172,31 +173,31 @@ def plot_pca_2d_projection(clf, X, y, title='PCA 2-D Projection',
172173
colors = plt.cm.get_cmap(cmap)(np.linspace(0, 1, len(classes)))
173174

174175
for label, color in zip(classes, colors):
175-
ax.scatter(transformed_X[y == label, 0], transformed_X[y == label, 1],
176+
ax.scatter(transformed_X[y == label, dimensions[0]], transformed_X[y == label, dimensions[1]],
176177
alpha=0.8, lw=2, label=label, color=color)
177178

178179
if label_dots:
179-
for dot in transformed_X[y == label, 0:2]:
180+
for dot in transformed_X[y == label, dimensions]:
180181
ax.text(*dot, label)
181182

182183
if biplot:
183-
xs = transformed_X[:, 0]
184-
ys = transformed_X[:, 1]
185-
vectors = np.transpose(clf.components_[:2, :])
184+
xs = transformed_X[:, dimensions[0]]
185+
ys = transformed_X[:, dimensions[1]]
186+
vectors = np.transpose(clf.components_[dimensions, :])
186187
vectors_scaled = vectors * [xs.max(), ys.max()]
187188
for i in range(vectors.shape[0]):
188-
ax.annotate("", xy=(vectors_scaled[i, 0], vectors_scaled[i, 1]),
189+
ax.annotate("", xy=(vectors_scaled[i, dimensions[0]], vectors_scaled[i, dimensions[1]]),
189190
xycoords='data', xytext=(0, 0), textcoords='data',
190191
arrowprops={'arrowstyle': '-|>', 'ec': 'r'})
191192

192-
ax.text(vectors_scaled[i, 0] * 1.05, vectors_scaled[i, 1] * 1.05,
193+
ax.text(vectors_scaled[i, dimensions[0]] * 1.05, vectors_scaled[i, dimensions[1]] * 1.05,
193194
feature_labels[i] if feature_labels else "Variable" + str(i),
194195
color='b', fontsize=text_fontsize)
195196

196197
ax.legend(loc='best', shadow=False, scatterpoints=1,
197198
fontsize=text_fontsize)
198-
ax.set_xlabel('First Principal Component', fontsize=text_fontsize)
199-
ax.set_ylabel('Second Principal Component', fontsize=text_fontsize)
199+
ax.set_xlabel(f'Principal Component {dimensions[0]+1}', fontsize=text_fontsize)
200+
ax.set_ylabel(f'Principal Component {dimension[1]+1}', fontsize=text_fontsize)
200201
ax.tick_params(labelsize=text_fontsize)
201202

202203
return ax

0 commit comments

Comments
 (0)