-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathclustering_analysis.py
More file actions
71 lines (59 loc) · 2.02 KB
/
clustering_analysis.py
File metadata and controls
71 lines (59 loc) · 2.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# How does the total Medicaid spending per drug compare across different companies?
# Import libraries
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
import plotly.express as px
# Load the dataset
drug_df = pd.read_csv('drug_data.csv')

# Clustering features: one Medicaid spending column per year, 2018-2022.
clustering_features = [
    'medicaid_spending_2018',
    'medicaid_spending_2019',
    'medicaid_spending_2020',
    'medicaid_spending_2021',
    'medicaid_spending_2022'
]

# Preprocess data for clustering: keep only rows that have every year present.
# .copy() gives df_cluster its own data, so the column assignments below do not
# raise pandas' SettingWithCopyWarning (assignment into a filtered frame).
df_cluster = drug_df.dropna(subset=clustering_features).copy()
X_cluster = df_cluster[clustering_features]

# Standardize the data (zero mean, unit variance per column) so each year's
# spending contributes on the same scale to the KMeans distance.
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_cluster)

# Perform KMeans clustering.
# n_init is pinned explicitly: scikit-learn changed its default from 10 to
# 'auto' in 1.4 (warning in 1.2-1.3), which would otherwise change results
# depending on the installed version.
kmeans = KMeans(n_clusters=3, n_init=10, random_state=42)
df_cluster['kmeans_cluster'] = kmeans.fit_predict(X_scaled)

# Plotly applies color_discrete_sequence only to non-numeric color columns;
# an integer cluster id would be drawn with a continuous color scale instead.
df_cluster['cluster_label'] = df_cluster['kmeans_cluster'].astype(str)
# Legend order 0, 1, 2, ... (numeric order, not order of first appearance).
cluster_order = [str(k) for k in range(kmeans.n_clusters)]

# Perform PCA: project the standardized yearly spending onto 2 components
# purely for visualization; clustering above used all standardized features.
pca = PCA(n_components=2)
X_pca = pca.fit_transform(X_scaled)
df_cluster['pca1'] = X_pca[:, 0]
df_cluster['pca2'] = X_pca[:, 1]

# Create the scatter plot from the frame that holds the PCA/cluster columns.
fig = px.scatter(df_cluster,
                 x='pca1',
                 y='pca2',
                 color='cluster_label',
                 hover_data=['company'],
                 category_orders={'cluster_label': cluster_order},
                 title='Comparison of Medicaid Drug Spending Per Drug Across Companies (2018–2022)',
                 labels={
                     'pca1': 'Principal Component 1',
                     'pca2': 'Principal Component 2',
                     'cluster_label': 'Cluster Label'
                 },
                 color_discrete_sequence=['#478ce6', '#f74a7e', '#37ad82'])

# Set plot size
fig.update_layout(width=1300, height=600)

# Increase marker size and decrease opacity
fig.update_traces(marker=dict(size=15, opacity=0.7))

# Update font sizes
fig.update_layout(
    title_font=dict(size=24),
    legend_title_font=dict(size=20),
    legend_font=dict(size=16),
    xaxis_title_font=dict(size=20),
    yaxis_title_font=dict(size=20),
    xaxis_tickfont=dict(size=16),
    yaxis_tickfont=dict(size=16)
)
fig.show()