Skip to content

Commit 42628b0

Browse files
committed
Write clean_factory method
1 parent 2a3f493 commit 42628b0

File tree

2 files changed

+20
-4
lines changed

2 files changed

+20
-4
lines changed

main.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,10 @@
346346
raise FileNotFoundError(Fore.RED + "The plugin {} is not found. Please verify that you are using the correct file path.".format(
347347
clean_control["plot"][key]["plugin"]))
348348

349+
# Create plot directory to save plots
350+
print_info("Creating directory to save plots.")
351+
os.makedirs("data/plots", exist_ok=True)
352+
349353
print_info("Generating directive list for worker nodes.")
350354
# Generate and slice directive list that will be sent out to the workers
351355
clean_directive_list = sst.generate_clean(clean_control["plot"], ROOT_PATH + "/data/plots", ROOT_PATH + "/data")

utils/workerops/paramfactory.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,16 +142,28 @@ def attack_train_factory(adver_features: List[str], model_labels: np.ndarray,
142142
return pickle_path
143143

144144

145-
def clean_factory() -> str:
145+
def clean_factory(models: List[str], plot_name: str, save_path: str, root_path: str) -> str:
146146
"""
147147
Generate parameter dictionary that will be sent out to the cleaning plugins for the cleaning stage.
148148
Save as a pickle and return a file path reference to that pickle.
149149
150150
### Parameters:
151-
- TODO
151+
:param models: List of root model directories containing data for plots.
152+
:param plot_name: Name to use for user-generated plot file.
153+
:param save_path: System location save the adversarial examples.
154+
:param root_path: Root directory of Jespipe.
152155
153156
### Returns:
154157
:return: System file path reference to pickled parameter dictionary.
155158
"""
156-
# TODO: Update this function once you revisit the cleaning stage next week
157-
pass
159+
d = dict()
160+
161+
d["model_list"] = models
162+
d["plot_name"] = plot_name
163+
d["save_path"] = save_path
164+
165+
# Establish path to file in .tmp directory and dump dictionary
166+
pickle_path = root_path + "/data/.tmp/" + str(uuid.uuid4()) + ".pkl"
167+
joblib.dump(d, pickle_path)
168+
169+
return pickle_path

0 commit comments

Comments
 (0)