Skip to content

Commit c18aad4

Browse files
authored
[Cherry-Pick-Main][Server] Catch another edge case for compute_feature_stats (#48)
1 parent 8b94c17 commit c18aad4

File tree

3 files changed

+15
-12
lines changed

3 files changed

+15
-12
lines changed

func_tests/pytest_tests/pipeline_tests/test_Loading_a_Trained_Neuron_Array.py

+2
Original file line number | Diff line number | Diff line change
@@ -123,6 +123,8 @@ def test_LoadTrainedNeuron(dsk_random_project, DataDir):
123123

124124
results, stats = dsk.pipeline.execute()
125125

126+
print(results)
127+
126128
results.summarize()
127129
model = results.configurations[0].models[0]
128130
assert model.neurons == neuron_array

src/server/datamanager/sandbox.py

+6-3
Original file line number | Diff line number | Diff line change
@@ -143,9 +143,12 @@ def calculate_feature_stats(feature_data, feature_table, label_column, sandbox_u
143143
if not isinstance(feature_data, DataFrame):
144144
return {}
145145

146-
feature_statistics = model_generator.compute_feature_stats(
147-
feature_data[selected_feature_cols], label_in_df=False
148-
)
146+
feature_statistics = {}
147+
if label_column and label_column in feature_data.columns:
148+
selected_feature_cols.append(label_column)
149+
feature_statistics = model_generator.compute_feature_stats(
150+
feature_data[selected_feature_cols]
151+
)
149152

150153
feature_summary = (
151154
selected_features.where(notnull(selected_features), NA)

src/server/library/model_generators/model_generator.py

+7-9
Original file line number | Diff line number | Diff line change
@@ -280,29 +280,27 @@ def compute_outliers(d):
280280
return outliers
281281

282282

283-
def compute_feature_stats(df, label_in_df=True):
283+
def compute_feature_stats(df):
284284
"""This assumes that the last column in the dataframe is the label column"""
285285

286-
if label_in_df:
287-
columns = df.columns[:-1]
288-
else:
289-
columns = df.columns
286+
label = df.columns[-1]
287+
features_names = df.columns[:-1]
290288

291-
g = df.groupby(columns)
289+
g = df.groupby(label)
292290

293291
M = {}
294292
for k, v in g.groups.items():
295293
M[k] = (
296-
g.get_group(k)[columns]
294+
g.get_group(k)[features_names]
297295
.describe(percentiles=[0.045, 0.25, 0.5, 0.75, 0.955])
298296
.round(2)
299297
.fillna(0)
300298
.to_dict()
301299
)
302300

303-
outliers = compute_outliers(g.get_group(k)[columns])
301+
outliers = compute_outliers(g.get_group(k)[features_names])
304302
for feature in M[k].keys():
305-
M[k][feature]["median"] = g.get_group(k)[columns][feature].median()
303+
M[k][feature]["median"] = g.get_group(k)[features_names][feature].median()
306304
M[k][feature]["outlier"] = outliers[feature]
307305

308306
l = {k: {} for k in M[next(iter(M))].keys()}

0 commit comments

Comments
 (0)