Skip to content

Commit 6b2e150

Browse files
committed
ml: create ml module; add scikit spam model
1 parent a470c63 commit 6b2e150

File tree

7 files changed

+217
-0
lines changed

7 files changed

+217
-0
lines changed

site/setup.cfg

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,14 @@ invenio_base.apps =
4848
zenodo_rdm_moderation = zenodo_rdm.moderation.ext:ZenodoModeration
4949
invenio_openaire = zenodo_rdm.openaire.ext:OpenAIRE
5050
zenodo_rdm_stats = zenodo_rdm.stats.ext:ZenodoStats
51+
zenodo_rdm_ml = zenodo_rdm.ml.ext:ZenodoML
5152
invenio_base.api_apps =
5253
zenodo_rdm_legacy = zenodo_rdm.legacy.ext:ZenodoLegacy
5354
profiler = zenodo_rdm.profiler:Profiler
5455
zenodo_rdm_metrics = zenodo_rdm.metrics.ext:ZenodoMetrics
5556
zenodo_rdm_moderation = zenodo_rdm.moderation.ext:ZenodoModeration
5657
invenio_openaire = zenodo_rdm.openaire.ext:OpenAIRE
58+
zenodo_rdm_ml = zenodo_rdm.ml.ext:ZenodoML
5759
invenio_base.api_blueprints =
5860
zenodo_rdm_legacy = zenodo_rdm.legacy.views:blueprint
5961
zenodo_rdm_legacy_records = zenodo_rdm.legacy.views:create_legacy_records_bp

site/zenodo_rdm/ml/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# -*- coding: utf-8 -*-
2+
#
3+
# Copyright (C) 2024 CERN.
4+
#
5+
# Zenodo-RDM is free software; you can redistribute it and/or modify
6+
# it under the terms of the MIT License; see LICENSE file for more details.
7+
"""Machine learning module."""

site/zenodo_rdm/ml/base.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# -*- coding: utf-8 -*-
2+
#
3+
# Copyright (C) 2024 CERN.
4+
#
5+
# Zenodo-RDM is free software; you can redistribute it and/or modify
6+
# it under the terms of the MIT License; see LICENSE file for more details.
7+
"""Base class for ML models."""
8+
9+
10+
class MLModel:
    """Base class for ML models.

    Subclasses must implement ``predict`` and may override ``preprocess``
    and ``postprocess``; ``process`` chains the three steps together.
    """

    def __init__(self, version=None, **kwargs):
        """Constructor.

        :param version: Optional model version, stored on the instance.
        :param kwargs: Accepted for subclass compatibility; not used here.
        """
        self.version = version

    def process(self, data, preprocess=None, postprocess=None, raise_exc=True):
        """Pipeline function to call pre/post process with predict.

        :param data: Raw input handed to the preprocessor.
        :param preprocess: Optional callable used instead of ``self.preprocess``.
        :param postprocess: Optional callable used instead of ``self.postprocess``.
        :param raise_exc: When true (default) any exception raised by a step
            propagates to the caller; when false the error is suppressed.
        :returns: The postprocessed prediction, or ``None`` if a step failed
            and ``raise_exc`` is false.
        """
        try:
            preprocessor = preprocess or self.preprocess
            postprocessor = postprocess or self.postprocess

            preprocessed = preprocessor(data)
            prediction = self.predict(preprocessed)
            return postprocessor(prediction)
        except Exception:
            if raise_exc:
                # Bare ``raise`` re-raises the active exception unchanged.
                raise
            return None

    def predict(self, data):
        """Predict method to be implemented by subclass.

        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def preprocess(self, data):
        """Preprocess data. The base implementation returns it unchanged."""
        return data

    def postprocess(self, data):
        """Postprocess data. The base implementation returns it unchanged."""
        return data

site/zenodo_rdm/ml/config.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# -*- coding: utf-8 -*-
2+
#
3+
# Copyright (C) 2024 CERN.
4+
#
5+
# ZenodoRDM is free software; you can redistribute it and/or modify
6+
# it under the terms of the MIT License; see LICENSE file for more details.
7+
8+
"""Machine learning config."""
9+
10+
from .models import SpamDetectorScikit
11+
12+
# Registry of model classes, keyed by the model name that is looked up in
# ``ZenodoML.models`` (the part of the identifier before any ":" suffix).
ML_MODELS = {
13+
"spam_scikit": SpamDetectorScikit,
14+
}
15+
"""Machine learning models."""
16+
17+
# NOTE: The model URL and model host must be ``str.format`` templates: the
# positional field ``{0}`` is filled with the model reference
# ("<MODEL_NAME>-<version>") when a prediction request is sent.
18+
ML_KUBEFLOW_MODEL_URL = "CHANGE-{0}-ME"
19+
ML_KUBEFLOW_MODEL_HOST = "{0}-CHANGE"
# Sent as the bearer token in the ``Authorization`` header of prediction
# requests. The value below is a placeholder; the extension registers it with
# ``setdefault``, so a value already present in the app config wins.
20+
ML_KUBEFLOW_TOKEN = "CHANGE SECRET"
21+
"""Kubeflow connection config."""

site/zenodo_rdm/ml/ext.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# -*- coding: utf-8 -*-
2+
#
3+
# Copyright (C) 2024 CERN.
4+
#
5+
# ZenodoRDM is free software; you can redistribute it and/or modify
6+
# it under the terms of the MIT License; see LICENSE file for more details.
7+
8+
"""ZenodoRDM machine learning module."""
9+
10+
from flask import current_app
11+
12+
from . import config
13+
14+
15+
class ZenodoML:
    """Zenodo machine learning extension."""

    def __init__(self, app=None):
        """Extension initialization.

        :param app: Optional Flask application; when given, ``init_app`` is
            called on it immediately.
        """
        if app:
            self.init_app(app)

    @staticmethod
    def init_config(app):
        """Initialize configuration.

        Copies every ``ML_``-prefixed attribute of the ``config`` module into
        ``app.config`` unless the key is already set there.
        """
        for k in dir(config):
            if k.startswith("ML_"):
                app.config.setdefault(k, getattr(config, k))

    def init_app(self, app):
        """Flask application initialization.

        Loads the default configuration and registers this extension under
        ``app.extensions["zenodo-ml"]``.
        """
        self.init_config(app)
        app.extensions["zenodo-ml"] = self

    def _parse_model_name_version(self, model):
        """Parse model name and version.

        :param model: Identifier of the form ``"<name>"`` or
            ``"<name>:<version>"``.
        :returns: Tuple ``(name, version)``; ``version`` is ``None`` when the
            identifier contains no ``":"``.
        """
        # Split once, from the right: only the last ":" separates the version.
        # Without ``maxsplit`` extra segments would be silently dropped.
        vals = model.rsplit(":", 1)
        version = vals[1] if len(vals) > 1 else None
        return vals[0], version

    def models(self, model, **kwargs):
        """Return model based on model name.

        :param model: Model identifier, optionally suffixed with
            ``":<version>"``.
        :param kwargs: Extra keyword arguments forwarded to the model class.
        :returns: A new instance of the class registered in ``ML_MODELS``.
        :raises ValueError: If the model name is not registered.
        """
        models = current_app.config.get("ML_MODELS", {})
        model_name, version = self._parse_model_name_version(model)

        if model_name not in models:
            raise ValueError("Model not found/registered.")

        return models[model_name](version=version, **kwargs)

site/zenodo_rdm/ml/models.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# -*- coding: utf-8 -*-
2+
#
3+
# Copyright (C) 2024 CERN.
4+
#
5+
# ZenodoRDM is free software; you can redistribute it and/or modify
6+
# it under the terms of the MIT License; see LICENSE file for more details.
7+
"""Model definitions."""
8+
9+
10+
import json
11+
import string
12+
13+
import requests
14+
from bs4 import BeautifulSoup
15+
from flask import current_app
16+
17+
from .base import MLModel
18+
19+
20+
class SpamDetectorScikit(MLModel):
    """Spam detection model based on Sklearn.

    Input text is cleaned locally, sent to the model endpoint configured via
    the ``ML_KUBEFLOW_*`` settings, and the response is reduced to a
    ``{"spam": ..., "ham": ...}`` mapping.
    """

    MODEL_NAME = "sklearn-spam"
    MAX_WORDS = 4000
    # Seconds to wait for the prediction endpoint. Without a timeout
    # ``requests`` can block indefinitely on an unresponsive server.
    REQUEST_TIMEOUT = 30

    def __init__(self, version, **kwargs):
        """Constructor. Makes version required.

        :param version: Model version; combined with ``MODEL_NAME`` to build
            the model reference used for the request URL and host.
        """
        super().__init__(version, **kwargs)

    def preprocess(self, data):
        """Preprocess data.

        Parse HTML, replace punctuation with spaces, lowercase and truncate
        to at most ``MAX_WORDS`` space-separated parts.

        :param data: HTML or plain text to clean.
        :returns: The cleaned text as a single string.
        """
        text = BeautifulSoup(data, "html.parser").get_text()
        trans_table = str.maketrans(string.punctuation, " " * len(string.punctuation))
        # NOTE(review): splitting on a single space keeps empty parts for runs
        # of spaces, and those count towards MAX_WORDS -- confirm this matches
        # the preprocessing the model was trained with before changing it.
        parts = text.translate(trans_table).lower().strip().split(" ")
        # Slicing is a no-op for shorter inputs, so no length check is needed.
        return " ".join(parts[: self.MAX_WORDS])

    def postprocess(self, data):
        """Postprocess data.

        Gives spam and ham probability.

        :param data: Decoded JSON response of the prediction request.
        :returns: Dict with keys ``"spam"`` and ``"ham"``, taken from the
            first and second value of the first output.
        """
        scores = data["outputs"][0]["data"]
        return {
            "spam": scores[0],
            "ham": scores[1],
        }

    def _send_request_kubeflow(self, data):
        """Send predict request to Kubeflow.

        :param data: Preprocessed text to classify.
        :returns: The decoded JSON body of the response.
        :raises ValueError: If the instance has no model version.
        :raises requests.RequestException: If the request fails, times out or
            the endpoint does not answer with HTTP 200.
        """
        if not self.version:
            # Fail with a clear message instead of a TypeError from the string
            # concatenation below when no version was given.
            raise ValueError("A model version is required to send a prediction.")

        payload = {
            "inputs": [
                {
                    "name": "input-0",
                    "shape": [1],
                    "datatype": "BYTES",
                    "data": [f"{data}"],
                }
            ]
        }
        model_ref = self.MODEL_NAME + "-" + self.version
        url = current_app.config.get("ML_KUBEFLOW_MODEL_URL").format(model_ref)
        host = current_app.config.get("ML_KUBEFLOW_MODEL_HOST").format(model_ref)
        access_token = current_app.config.get("ML_KUBEFLOW_TOKEN")
        r = requests.post(
            url,
            headers={
                "Authorization": f"Bearer {access_token}",
                "Content-Type": "application/json",
                "Host": host,
            },
            json=payload,
            timeout=self.REQUEST_TIMEOUT,
        )
        if r.status_code != 200:
            # ``r`` is a Response, so it must be attached as ``response``;
            # the exception derives ``request`` from it.
            raise requests.RequestException(
                "Prediction was not successful.", response=r
            )
        return json.loads(r.text)

    def predict(self, data):
        """Get prediction from model.

        :param data: Preprocessed text.
        :returns: The decoded JSON response from the prediction endpoint.
        """
        return self._send_request_kubeflow(data)

site/zenodo_rdm/ml/proxies.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# -*- coding: utf-8 -*-
2+
#
3+
# Copyright (C) 2024 CERN.
4+
#
5+
# ZenodoRDM is free software; you can redistribute it and/or modify
6+
# it under the terms of the MIT License; see LICENSE file for more details.
7+
"""Proxy objects for easier access to application objects."""
8+
9+
from flask import current_app
10+
from werkzeug.local import LocalProxy
11+
12+
def _get_ml_extension():
    """Return the ML extension registered on the current Flask application."""
    return current_app.extensions["zenodo-ml"]


# Resolved on each access, so it can be imported at module level.
current_ml_models = LocalProxy(_get_ml_extension)

0 commit comments

Comments
 (0)