Skip to content

Commit bb91e54

Browse files
committed
Refactor: Modularized train_model.py for separation of concerns
1 parent 3ac9049 commit bb91e54

File tree

6 files changed

+127
-42
lines changed

6 files changed

+127
-42
lines changed

src/data_ingestion.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import pandas as pd
2+
3+
def load_raw_data(data_path='data/data.csv'):
    """Read the raw dataset into a DataFrame.

    Args:
        data_path: Location of the CSV data, passed straight to
            ``pandas.read_csv``. Defaults to 'data/data.csv'.

    Returns:
        The DataFrame that ``pandas.read_csv`` produces for ``data_path``.
    """
    return pd.read_csv(data_path)

src/data_preprocessing.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import pandas as pd
2+
3+
def drop_unnecessary_columns(df):
    """Return ``df`` without the 'id' and 'Unnamed: 32' columns.

    Either column may be absent; missing ones are skipped instead of
    raising. The input DataFrame itself is not modified.
    """
    unwanted = ['id', 'Unnamed: 32']
    return df.drop(columns=unwanted, errors='ignore')
6+
7+
def map_diagnosis_to_numerical(df):
    """Return a copy of ``df`` with 'diagnosis' encoded numerically (M=1, B=0).

    The input DataFrame is left untouched; the encoded column replaces
    'diagnosis' at its existing position in the returned frame.

    Args:
        df: DataFrame with a 'diagnosis' column holding 'M' / 'B' labels.

    Returns:
        A new DataFrame whose 'diagnosis' column is 1 for 'M' and 0 for 'B'.
        Labels that are already missing remain missing.

    Raises:
        KeyError: If ``df`` has no 'diagnosis' column.
        ValueError: If 'diagnosis' contains a non-missing value other than
            'M' or 'B'.
    """
    label_codes = {'M': 1, 'B': 0}
    labels = df['diagnosis']
    encoded = labels.map(label_codes)

    # Series.map turns every unmapped label into NaN. Surface those here
    # instead of letting them silently corrupt the training target.
    unrecognized = labels[labels.notna() & encoded.isna()]
    if not unrecognized.empty:
        raise ValueError(
            f"Unrecognized diagnosis labels: {unrecognized.unique().tolist()!r}; "
            "expected 'M' or 'B'."
        )

    # assign() builds a new frame, so the caller's DataFrame is not mutated.
    return df.assign(diagnosis=encoded)
11+
12+
def prepare_features_and_target(df):
    """Split a preprocessed DataFrame into model inputs and labels.

    Args:
        df: DataFrame that includes a 'diagnosis' column.

    Returns:
        A ``(X, y)`` tuple: ``X`` is ``df`` without the 'diagnosis' column
        and ``y`` is the 'diagnosis' column itself.
    """
    target_column = 'diagnosis'
    features = df.drop(columns=[target_column])
    target = df[target_column]
    return features, target

src/model.py

Lines changed: 0 additions & 42 deletions
This file was deleted.

src/model_inference.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import joblib
2+
import pandas as pd
3+
import os
4+
5+
# The saved pipeline wraps drop_unnecessary_columns in a FunctionTransformer.
6+
# Pickle (used by joblib) stores functions by module path and name, so data_preprocessing must be importable when the pipeline is loaded.
7+
# Unpickling performs that import itself; this explicit import documents the dependency and fails early if the module is missing.
8+
from .data_preprocessing import drop_unnecessary_columns
9+
10+
def load_pipeline(model_path='models/model.joblib'):
    """Load the persisted scikit-learn pipeline from disk.

    Args:
        model_path: Path of the joblib file to load.

    Returns:
        The object that ``joblib.load`` restores from ``model_path``.

    Raises:
        FileNotFoundError: If nothing exists at ``model_path``.
    """
    if os.path.exists(model_path):
        # NOTE: joblib.load unpickles the file, so only load trusted artifacts.
        return joblib.load(model_path)
    raise FileNotFoundError(f"Model pipeline not found at {model_path}. Please train the model first.")
15+
16+
def predict(raw_data, model_path='models/model.joblib'):
    """Predict on new raw data with the saved pipeline.

    The pipeline is loaded from ``model_path`` on every call.

    Args:
        raw_data: Input handed unchanged to the pipeline's ``predict``.
        model_path: Path of the saved pipeline file.

    Returns:
        Whatever the loaded pipeline's ``predict`` returns for ``raw_data``.

    Raises:
        FileNotFoundError: Propagated from ``load_pipeline`` when no file
            exists at ``model_path``.
    """
    return load_pipeline(model_path).predict(raw_data)
21+
22+
if __name__ == "__main__":
    print("This module is for inference. Please run model_training.py to train the model.")

    # One example observation (30 feature columns) used as a smoke test.
    sample_row = {
        # "mean" features
        'radius_mean': 17.99,
        'texture_mean': 10.38,
        'perimeter_mean': 122.8,
        'area_mean': 1001.0,
        'smoothness_mean': 0.1184,
        'compactness_mean': 0.2776,
        'concavity_mean': 0.3001,
        'concave points_mean': 0.1471,
        'symmetry_mean': 0.2419,
        'fractal_dimension_mean': 0.07871,
        # "se" features
        'radius_se': 1.095,
        'texture_se': 0.9053,
        'perimeter_se': 8.589,
        'area_se': 153.4,
        'smoothness_se': 0.006399,
        'compactness_se': 0.04904,
        'concavity_se': 0.05373,
        'concave points_se': 0.01587,
        'symmetry_se': 0.03003,
        'fractal_dimension_se': 0.006193,
        # "worst" features
        'radius_worst': 25.38,
        'texture_worst': 17.33,
        'perimeter_worst': 184.6,
        'area_worst': 2019.0,
        'smoothness_worst': 0.1622,
        'compactness_worst': 0.6656,
        'concavity_worst': 0.7119,
        'concave points_worst': 0.2654,
        'symmetry_worst': 0.4601,
        'fractal_dimension_worst': 0.1189,
    }

    try:
        # Single-row DataFrame standing in for new raw data.
        prediction = predict(pd.DataFrame([sample_row]))
    except FileNotFoundError as missing_model:
        # No saved pipeline yet: show the explanatory message, no traceback.
        print(missing_model)
    else:
        print(f"Prediction for sample data: {prediction[0]} (0: Benign, 1: Malignant)")

src/model_training.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import pandas as pd
2+
from sklearn.model_selection import train_test_split
3+
from sklearn.metrics import accuracy_score
4+
import joblib
5+
import os
6+
7+
from .data_ingestion import load_raw_data
8+
from .data_preprocessing import map_diagnosis_to_numerical, prepare_features_and_target
9+
from .pipeline_utils import create_breast_cancer_pipeline
10+
11+
def train_and_save_pipeline(data_path='data/data.csv', model_path='models/model.joblib'):
    """Train the breast cancer pipeline on a CSV dataset and save it.

    Loads the raw data, encodes the diagnosis labels, splits features from
    the target, fits the pipeline on an 80/20 train/test split, prints the
    test accuracy, and writes the fitted pipeline to ``model_path``.

    Args:
        data_path: Path of the raw CSV dataset.
        model_path: Destination file for the fitted pipeline. Its parent
            directory is created if needed; a bare filename (no directory
            part) is written to the current working directory.

    Returns:
        None. The fitted pipeline is written to disk and progress is printed.
    """
    df_raw = load_raw_data(data_path)

    # Hand the mapping step a copy so the raw frame is never modified by it.
    df_mapped = map_diagnosis_to_numerical(df_raw.copy())

    X, y = prepare_features_and_target(df_mapped)

    # Fixed random_state keeps the split reproducible between runs.
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    pipeline = create_breast_cancer_pipeline()
    pipeline.fit(X_train, y_train)

    # Evaluate on the held-out split.
    y_pred = pipeline.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    print(f"Pipeline Accuracy: {accuracy:.4f}")

    # os.path.dirname() is '' for a bare filename and os.makedirs('') raises
    # FileNotFoundError, so only create the directory when there is one.
    model_dir = os.path.dirname(model_path)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    joblib.dump(pipeline, model_path)
    print(f"Trained pipeline saved to {model_path}")


if __name__ == "__main__":
    # This module uses relative imports, so it must be run as part of its
    # package (python -m <package>.model_training), not as a standalone file.
    train_and_save_pipeline()

src/pipeline_utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from sklearn.pipeline import Pipeline
2+
from sklearn.preprocessing import FunctionTransformer
3+
from sklearn.ensemble import RandomForestClassifier
4+
import pandas as pd
5+
6+
from .data_preprocessing import drop_unnecessary_columns
7+
8+
def create_breast_cancer_pipeline():
    """Build the unfitted scikit-learn pipeline for breast cancer prediction.

    Returns:
        A ``Pipeline`` with two steps:
        'preprocessor' - a nested ``Pipeline`` whose single step, 'drop_cols',
        applies ``drop_unnecessary_columns`` through a ``FunctionTransformer``;
        'classifier' - a ``RandomForestClassifier`` with ``random_state=42``.
    """
    column_dropper = FunctionTransformer(drop_unnecessary_columns, validate=False)

    # Further preprocessing steps (e.g. scaling) can be added to this list.
    preprocessing_steps = [
        ('drop_cols', column_dropper),
    ]

    forest = RandomForestClassifier(random_state=42)

    # Preprocessing runs first and its output feeds the classifier.
    return Pipeline([
        ('preprocessor', Pipeline(preprocessing_steps)),
        ('classifier', forest),
    ])

0 commit comments

Comments
 (0)