This project uses a machine learning approach to predict the survival of patients with heart failure using clinical and demographic features.
The dataset is sourced from Kaggle and contains 12 features plus the target variable:
- age: Age of the patient
- anaemia: Presence of anaemia (1 = yes, 0 = no)
- creatinine_phosphokinase: Level of CPK enzyme in the blood (mcg/L)
- diabetes: Presence of diabetes (1 = yes, 0 = no)
- ejection_fraction: Percentage of blood leaving the heart at each contraction
- high_blood_pressure: Presence of hypertension (1 = yes, 0 = no)
- platelets: Platelet count (kiloplatelets/mL)
- serum_creatinine: Level of serum creatinine (mg/dL)
- serum_sodium: Level of serum sodium (mEq/L)
- sex: Gender of the patient (1 = male, 0 = female)
- smoking: Smoking status (1 = yes, 0 = no)
- time: Follow-up period (days)
- death_event: Target variable (1 = death, 0 = survival)

Cardiovascular diseases (CVDs) are the leading cause of death globally, responsible for an estimated 17.9 million deaths annually (31% of all deaths). Heart failure is a major outcome of CVDs, and predicting mortality risk can aid early intervention.
The project implements a neural network classifier using TensorFlow/Keras to predict patient survival based on clinical data. Key steps include:
- Data Loading and Inspection: Load CSV dataset and inspect structure.
- Data Preprocessing:
  - One-hot encode categorical variables.
  - Scale numeric features using StandardScaler.
  - Encode target labels for classification.
- Model Architecture:
  - Input layer matching the number of features.
  - Hidden layer with 12 neurons and ReLU activation.
  - Output layer with 2 neurons (YES/NO) and softmax activation.
- Training:
  - Loss function: Categorical Crossentropy.
  - Optimizer: Adam.
  - Metrics: Accuracy.
  - Epochs: 100, Batch size: 16.
- Evaluation:
  - Classification report with precision, recall, f1-score.
  - Accuracy on the test set.
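
The steps above can be sketched end to end as follows. This is a minimal illustration, not the project's actual script: it runs on a small synthetic table rather than the Kaggle CSV, and the helper names (`make_synthetic_data`, `preprocess`, `build_model`), the subset of columns, the 25% test split, and the random seed are all assumptions made for the example. The layer sizes, loss, optimizer, epochs, and batch size follow the list above.

```python
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from tensorflow.keras import Input, Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.utils import to_categorical

NUMERIC_COLS = ("age", "ejection_fraction", "serum_creatinine")


def make_synthetic_data(n_rows=120, seed=0):
    """Hypothetical stand-in for the real CSV, using a few of its column names."""
    rng = np.random.default_rng(seed)
    df = pd.DataFrame({
        "age": rng.integers(40, 95, n_rows).astype(float),
        "ejection_fraction": rng.integers(15, 80, n_rows).astype(float),
        "serum_creatinine": rng.uniform(0.5, 9.0, n_rows),
        "smoking": rng.integers(0, 2, n_rows),
    })
    # Toy rule so the target is learnable; it has no clinical meaning.
    df["death_event"] = np.where(df["ejection_fraction"] < 35, 1, 0)
    return df


def preprocess(df, target="death_event", numeric_cols=NUMERIC_COLS, seed=0):
    y = df[target]
    # One-hot encode string/categorical columns; columns already coded
    # as 0/1 integers pass through unchanged.
    X = pd.get_dummies(df.drop(columns=[target]))
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.25, random_state=seed, stratify=y
    )
    # Scale numeric columns only; fit on the training split to avoid leakage.
    # Scaled columns come first in the output, the rest follow.
    ct = ColumnTransformer(
        [("scale", StandardScaler(), list(numeric_cols))],
        remainder="passthrough",
    )
    X_train = ct.fit_transform(X_train).astype("float32")
    X_test = ct.transform(X_test).astype("float32")
    # Encode labels as integers, then as one-hot vectors for softmax output.
    le = LabelEncoder()
    y_train = to_categorical(le.fit_transform(y_train), num_classes=2)
    y_test = to_categorical(le.transform(y_test), num_classes=2)
    return X_train, X_test, y_train, y_test


def build_model(n_features):
    model = Sequential([
        Input(shape=(n_features,)),
        Dense(12, activation="relu"),
        Dense(2, activation="softmax"),
    ])
    model.compile(
        loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"]
    )
    return model


df = make_synthetic_data()
X_train, X_test, y_train, y_test = preprocess(df)
model = build_model(X_train.shape[1])
model.fit(X_train, y_train, epochs=100, batch_size=16, verbose=0)

loss, acc = model.evaluate(X_test, y_test, verbose=0)
y_pred = np.argmax(model.predict(X_test, verbose=0), axis=1)
y_true = np.argmax(y_test, axis=1)
print(classification_report(y_true, y_pred))
print("Test accuracy:", acc)
```

To run this against the real data, replace `make_synthetic_data()` with a `pd.read_csv(...)` call on the downloaded file and extend `NUMERIC_COLS` to cover all continuous columns. Results on the synthetic table say nothing about performance on the actual dataset.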

Classification report on the test set (75 samples):

| Class | Precision | Recall | F1-score | Support |
|---|---|---|---|---|
| 0 | 0.74 | 0.91 | 0.82 | 44 |
| 1 | 0.81 | 0.55 | 0.65 | 31 |
| Accuracy | - | - | 0.76 | 75 |
| Macro Avg | 0.78 | 0.73 | 0.74 | 75 |
| Weighted Avg | 0.77 | 0.76 | 0.75 | 75 |
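
To make the relationship between the rows of this table explicit, the sketch below recomputes every figure from a single confusion matrix. The matrix itself is an assumption: it was inferred from the rounded precision, recall, and support values and is not taken from the project's output. The function name `per_class` is likewise hypothetical.

```python
# Confusion matrix inferred from the reported metrics (assumption).
# Rows are true classes, columns are predicted classes.
cm = [[40, 4],
      [14, 17]]


def per_class(cm, k):
    """Return precision, recall, F1, and support for class k."""
    n = len(cm)
    tp = cm[k][k]
    fp = sum(cm[i][k] for i in range(n) if i != k)  # predicted k, truly other
    fn = sum(cm[k][j] for j in range(n) if j != k)  # truly k, predicted other
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1, tp + fn


rows = [per_class(cm, k) for k in range(2)]
total = sum(r[3] for r in rows)

# Accuracy: correct predictions over all samples.
accuracy = sum(cm[k][k] for k in range(2)) / total

# Macro average: unweighted mean over classes.
macro = [sum(r[m] for r in rows) / len(rows) for m in range(3)]

# Weighted average: mean over classes weighted by support.
weighted = [sum(r[m] * r[3] for r in rows) / total for m in range(3)]

for k, (p, r, f1, s) in enumerate(rows):
    print(k, round(p, 2), round(r, 2), round(f1, 2), s)
print("accuracy", round(accuracy, 2), total)
print("macro", [round(v, 2) for v in macro])
print("weighted", [round(v, 2) for v in weighted])
```

Class 1 (death) has a recall of 0.55, meaning 14 of the 31 deaths in the test set were predicted as survival under this inferred matrix. The weighted-average recall equals the accuracy by construction.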

Requirements:

- Python 3.x
- Pandas
- NumPy
- scikit-learn
- TensorFlow/Keras

Install the Python packages with `pip install pandas numpy scikit-learn tensorflow`.