-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathArtifactRemoverTransformer.py
45 lines (35 loc) · 1.76 KB
/
ArtifactRemoverTransformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import joblib
from sklearn.base import BaseEstimator, TransformerMixin
from artifact_detection_model.transformer.DoNotReplaceArtifacts import DoNotReplaceArtifacts
from artifact_detection_model.transformer.ReplaceButKeepExceptionNames import ReplaceButKeepExceptionNames
from artifact_detection_model.transformer.SimpleReplace import SimpleReplace
from file_anchor import root_dir
# Names of the available replacement strategies; these are the values accepted
# for ``ArtifactRemoverTransformer(replacement_strategy=...)``.
SIMPLE = 'simple'
KEEP_EXCEPTION_NAMES = 'keep_exception_names'
DO_NOT_REPLACE = 'no_replacements'

# Registry mapping each strategy name to one shared strategy instance.
# The transformer calls ``get_replacement(line)`` on the selected instance.
replacement_strategies = {
    SIMPLE: SimpleReplace(),
    KEEP_EXCEPTION_NAMES: ReplaceButKeepExceptionNames(),
    DO_NOT_REPLACE: DoNotReplaceArtifacts(),
}
class ArtifactRemoverTransformer(BaseEstimator, TransformerMixin):
    """Rewrite documents line by line, driven by a per-line classifier.

    Every document is split into lines and the lines are passed to
    ``classifier.predict``. Lines predicted as ``1`` are kept verbatim; every
    other line is handed to the configured replacement strategy and the
    returned replacement is kept only when it is not blank.

    Parameters
    ----------
    classiefier : object
        Object exposing ``predict(lines)`` that returns one label per line.
        It is stored as ``self.classifier``.
    replacement_strategy : str, default=SIMPLE
        Key into the module-level ``replacement_strategies`` mapping.
    """

    # NOTE(review): the constructor keyword is spelled ``classiefier`` while the
    # attribute is ``classifier``. scikit-learn's ``get_params``/``clone`` look
    # attributes up by constructor-parameter name, so they are expected to fail
    # on this estimator -- confirm before renaming, because the keyword is part
    # of the public interface.
    def __init__(self, classiefier, replacement_strategy=SIMPLE):
        self.replacement_strategy = replacement_strategy
        self.classifier = classiefier

    def fit(self, X, y=None):
        """Do nothing and return ``self``; the transformer has no fitted state."""
        return self

    def transform(self, X):
        """Return the cleaned version of every document in ``X``.

        With the ``DO_NOT_REPLACE`` strategy ``X`` is returned unchanged (the
        very same object, the classifier is not called). Otherwise a new list
        with one cleaned string per document is returned.
        """
        if self.replacement_strategy == DO_NOT_REPLACE:
            return X
        return [self.predict_and_remove(i) for i in X]

    def predict_and_remove(self, issue):
        """Return ``issue`` with its non-text lines replaced or dropped.

        Parameters
        ----------
        issue : str
            A single document.

        Returns
        -------
        str
            The kept and replaced lines joined with ``'\\n'``.

        Raises
        ------
        KeyError
            If ``self.replacement_strategy`` is not a key of
            ``replacement_strategies``.
        """
        strategy = replacement_strategies[self.replacement_strategy]

        # Split exactly once; re-splitting inside the loop made the method
        # quadratic in the number of lines.
        lines = issue.splitlines()
        if not lines:
            # Nothing to classify; many classifiers reject an empty sample
            # list, and the result would be the empty string anyway.
            return ''

        prediction = self.classifier.predict(lines)
        # A set gives O(1) membership tests. Lines without a matching
        # prediction are treated as artifacts.
        text_indices = {i for i, label in enumerate(prediction) if label == 1}

        cleaned_issue = []
        for i, line in enumerate(lines):
            if i in text_indices:
                cleaned_issue.append(line)
                continue
            replacement = strategy.get_replacement(line)
            # Blank replacements are dropped instead of leaving empty lines.
            if replacement.strip():
                cleaned_issue.append(replacement)
        return '\n'.join(cleaned_issue)