-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Expand file tree
/
Copy pathtrain.py
More file actions
101 lines (84 loc) · 3.16 KB
/
train.py
File metadata and controls
101 lines (84 loc) · 3.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# Copyright 2021 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Covertype Classifier trainer script."""
import os
import pickle
import subprocess
import sys
import fire
import hypertune
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.linear_model import SGDClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
# GCS directory where the trained model artifact is uploaded. Vertex AI sets
# this env var for custom training jobs; the read happens at import time, so
# the script fails fast (KeyError) if run outside that environment.
AIP_MODEL_DIR = os.environ["AIP_MODEL_DIR"]
# Local filename for the pickled sklearn pipeline before upload.
MODEL_FILENAME = "model.pkl"
def train_evaluate(
    training_dataset_path, validation_dataset_path, alpha, max_iter, hptune
):
    """Trains the Covertype classifier and evaluates or exports it.

    Args:
        training_dataset_path: Path (local or GCS) to the training CSV.
        validation_dataset_path: Path (local or GCS) to the validation CSV.
        alpha: Regularization strength for SGDClassifier (coerced to float).
        max_iter: Maximum number of training epochs (coerced to int).
        hptune: If truthy, score the model on the validation set and report
            accuracy to the hypertune service; otherwise train on the
            combined train+validation data and upload the pickled pipeline
            to AIP_MODEL_DIR.
    """
    # Coerce numeric hyperparameters up front: values may arrive as strings
    # (e.g. when passed through a CLI). The original coerced max_iter only.
    alpha = float(alpha)
    max_iter = int(max_iter)

    df_train = pd.read_csv(training_dataset_path)
    df_validation = pd.read_csv(validation_dataset_path)

    # Final (non-tuning) run: train on all available labeled data.
    if not hptune:
        df_train = pd.concat([df_train, df_validation])

    # Covertype column layout: first 10 columns numeric, next 2 categorical.
    numeric_feature_indexes = slice(0, 10)
    categorical_feature_indexes = slice(10, 12)

    preprocessor = ColumnTransformer(
        transformers=[
            ("num", StandardScaler(), numeric_feature_indexes),
            ("cat", OneHotEncoder(), categorical_feature_indexes),
        ]
    )
    pipeline = Pipeline(
        [
            ("preprocessor", preprocessor),
            ("classifier", SGDClassifier(loss="log_loss")),
        ]
    )

    # Force numeric columns to float64 so StandardScaler sees a uniform dtype.
    num_features_type_map = {
        feature: "float64"
        for feature in df_train.columns[numeric_feature_indexes]
    }
    df_train = df_train.astype(num_features_type_map)
    df_validation = df_validation.astype(num_features_type_map)

    print(f"Starting training: alpha={alpha}, max_iter={max_iter}")
    # pylint: disable-next=invalid-name
    X_train = df_train.drop("Cover_Type", axis=1)
    y_train = df_train["Cover_Type"]
    pipeline.set_params(
        classifier__alpha=alpha, classifier__max_iter=max_iter
    )
    pipeline.fit(X_train, y_train)

    if hptune:
        # Tuning trial: evaluate on held-out data and report the metric so
        # the hyperparameter tuning service can rank this trial.
        # pylint: disable-next=invalid-name
        X_validation = df_validation.drop("Cover_Type", axis=1)
        y_validation = df_validation["Cover_Type"]
        accuracy = pipeline.score(X_validation, y_validation)
        print(f"Model accuracy: {accuracy}")
        hpt = hypertune.HyperTune()
        hpt.report_hyperparameter_tuning_metric(
            hyperparameter_metric_tag="accuracy", metric_value=accuracy
        )
    else:
        # Export run: pickle the fitted pipeline and copy it to the Vertex AI
        # model directory so the platform can register/serve it.
        with open(MODEL_FILENAME, "wb") as model_file:
            pickle.dump(pipeline, model_file)
        subprocess.check_call(
            ["gsutil", "cp", MODEL_FILENAME, AIP_MODEL_DIR], stderr=sys.stdout
        )
        print(f"Saved model in: {AIP_MODEL_DIR}")
if __name__ == "__main__":
    # Expose train_evaluate as a CLI: fire maps command-line flags
    # (e.g. --alpha=0.001 --hptune=True) onto the function's parameters.
    fire.Fire(train_evaluate)