Skip to content

Commit ff70392

Browse files
committed
ml: create ml module; add scikit spam model
1 parent a470c63 commit ff70392

File tree

7 files changed

+215
-0
lines changed

7 files changed

+215
-0
lines changed

site/setup.cfg

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,14 @@ invenio_base.apps =
4848
zenodo_rdm_moderation = zenodo_rdm.moderation.ext:ZenodoModeration
4949
invenio_openaire = zenodo_rdm.openaire.ext:OpenAIRE
5050
zenodo_rdm_stats = zenodo_rdm.stats.ext:ZenodoStats
51+
zenodo_rdm_ml = zenodo_rdm.ml.ext:ZenodoML
5152
invenio_base.api_apps =
5253
zenodo_rdm_legacy = zenodo_rdm.legacy.ext:ZenodoLegacy
5354
profiler = zenodo_rdm.profiler:Profiler
5455
zenodo_rdm_metrics = zenodo_rdm.metrics.ext:ZenodoMetrics
5556
zenodo_rdm_moderation = zenodo_rdm.moderation.ext:ZenodoModeration
5657
invenio_openaire = zenodo_rdm.openaire.ext:OpenAIRE
58+
zenodo_rdm_ml = zenodo_rdm.ml.ext:ZenodoML
5759
invenio_base.api_blueprints =
5860
zenodo_rdm_legacy = zenodo_rdm.legacy.views:blueprint
5961
zenodo_rdm_legacy_records = zenodo_rdm.legacy.views:create_legacy_records_bp

site/zenodo_rdm/ml/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# -*- coding: utf-8 -*-
2+
#
3+
# Copyright (C) 2024 CERN.
4+
#
5+
# Zenodo-RDM is free software; you can redistribute it and/or modify
6+
# it under the terms of the MIT License; see LICENSE file for more details.
7+
"""Machine learning module."""

site/zenodo_rdm/ml/base.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# -*- coding: utf-8 -*-
2+
#
3+
# Copyright (C) 2024 CERN.
4+
#
5+
# Zenodo-RDM is free software; you can redistribute it and/or modify
6+
# it under the terms of the MIT License; see LICENSE file for more details.
7+
"""Base class for ML models."""
8+
9+
10+
class MLModel:
    """Base class for ML models.

    Subclasses implement :meth:`predict` and may override :meth:`preprocess`
    and :meth:`postprocess`. :meth:`process` chains the three steps.
    """

    def __init__(self, version=None, **kwargs):
        """Constructor.

        :param version: Model version; stored as ``self.version``.
        :param kwargs: Accepted but not used by the base class.
        """
        self.version = version

    def process(self, data, preprocess=None, postprocess=None, raise_exc=True):
        """Pipeline function to call pre/post process with predict.

        :param data: Raw input handed to the preprocessor.
        :param preprocess: Optional callable used instead of ``self.preprocess``.
        :param postprocess: Optional callable used instead of
            ``self.postprocess``.
        :param raise_exc: If true (default) any exception raised by the
            pipeline propagates; otherwise it is logged and ``None`` is
            returned.
        :returns: The postprocessed prediction, or ``None`` when the pipeline
            failed and ``raise_exc`` is false.
        """
        preprocessor = preprocess or self.preprocess
        postprocessor = postprocess or self.postprocess

        try:
            preprocessed = preprocessor(data)
            prediction = self.predict(preprocessed)
            return postprocessor(prediction)
        except Exception:
            if raise_exc:
                # Bare ``raise`` re-raises the active exception unchanged.
                raise
            # Best-effort mode: keep returning ``None`` but leave a trace so
            # failures are not completely invisible.
            import logging

            logging.getLogger(__name__).warning(
                "ML model processing failed.", exc_info=True
            )
            return None

    def predict(self, data):
        """Predict method to be implemented by subclass.

        :raises NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError()

    def preprocess(self, data):
        """Preprocess data (identity by default)."""
        return data

    def postprocess(self, data):
        """Postprocess data (identity by default)."""
        return data

site/zenodo_rdm/ml/config.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# -*- coding: utf-8 -*-
2+
#
3+
# Copyright (C) 2024 CERN.
4+
#
5+
# ZenodoRDM is free software; you can redistribute it and/or modify
6+
# it under the terms of the MIT License; see LICENSE file for more details.
7+
8+
"""Machine learning config."""
9+
10+
from .models import SpamDetectorScikit
11+
12+
ML_MODELS = {
13+
"spam_scikit": SpamDetectorScikit,
14+
}
15+
"""Machine learning models."""
16+
17+
# NOTE: The model URL and host must be format strings; "{0}" is replaced with the model reference (model name and version).
18+
ML_KUBEFLOW_MODEL_URL = "CHANGE-{0}-ME"
19+
ML_KUBEFLOW_MODEL_HOST = "{0}-CHANGE"
20+
ML_KUBEFLOW_TOKEN = "CHANGE SECRET"
21+
"""Kubeflow connection config."""

site/zenodo_rdm/ml/ext.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# -*- coding: utf-8 -*-
2+
#
3+
# Copyright (C) 2024 CERN.
4+
#
5+
# ZenodoRDM is free software; you can redistribute it and/or modify
6+
# it under the terms of the MIT License; see LICENSE file for more details.
7+
8+
"""ZenodoRDM machine learning module."""
9+
10+
from flask import current_app
11+
12+
from . import config
13+
14+
15+
class ZenodoML:
    """Zenodo machine learning extension."""

    def __init__(self, app=None):
        """Extension initialization.

        :param app: Optional Flask application; when given, the extension is
            initialized on it immediately.
        """
        if app:
            self.init_app(app)

    @staticmethod
    def init_config(app):
        """Initialize configuration.

        Copies every ``ML_``-prefixed attribute of the ``config`` module into
        ``app.config`` unless the key is already set.
        """
        for k in dir(config):
            if k.startswith("ML_"):
                app.config.setdefault(k, getattr(config, k))

    def init_app(self, app):
        """Flask application initialization.

        Loads the default configuration and registers this instance under
        ``app.extensions["zenodo-ml"]``.
        """
        self.init_config(app)
        app.extensions["zenodo-ml"] = self

    def _parse_model_name_version(self, model):
        """Parse model name and version.

        :param model: String of the form ``"<name>"`` or ``"<name>:<version>"``.
        :returns: Tuple ``(name, version)``; ``version`` is ``None`` when the
            string contains no ``":"``.
        """
        # Split on the LAST colon only, so everything before it is the name
        # and nothing after the first separator is silently dropped.
        name, *rest = model.rsplit(":", 1)
        version = rest[0] if rest else None
        return name, version

    def models(self, model, **kwargs):
        """Return model based on model name.

        :param model: Model identifier, optionally suffixed with ``:<version>``.
        :param kwargs: Extra keyword arguments passed to the model class.
        :returns: An instance of the class registered in ``ML_MODELS``.
        :raises ValueError: If the model name is not registered.
        """
        models = current_app.config.get("ML_MODELS", {})
        model_name, version = self._parse_model_name_version(model)

        if model_name not in models:
            raise ValueError("Model not found/registered.")

        return models[model_name](version=version, **kwargs)

site/zenodo_rdm/ml/models.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# -*- coding: utf-8 -*-
2+
#
3+
# Copyright (C) 2024 CERN.
4+
#
5+
# ZenodoRDM is free software; you can redistribute it and/or modify
6+
# it under the terms of the MIT License; see LICENSE file for more details.
7+
"""Model definitions."""
8+
9+
10+
import json
11+
import string
12+
13+
import requests
14+
from bs4 import BeautifulSoup
15+
from flask import current_app
16+
17+
from .base import MLModel
18+
19+
20+
class SpamDetectorScikit(MLModel):
    """Spam detector that requests predictions from a Kubeflow-hosted model."""

    MODEL_NAME = "sklearn-spam"
    MAX_WORDS = 4000
    # Seconds to wait for the prediction service before giving up; without a
    # timeout ``requests`` may block indefinitely.
    # NOTE(review): 30s is a conservative default, tune to the deployment.
    REQUEST_TIMEOUT = 30

    def __init__(self, version, **kwargs):
        """Constructor. Makes version required."""
        super().__init__(version, **kwargs)

    def preprocess(self, data):
        """Preprocess data.

        Parse HTML, replace punctuation with spaces, lowercase and truncate
        to at most ``MAX_WORDS`` space-separated tokens.

        :param data: HTML (or plain text) string.
        :returns: The normalized text as a single string.
        """
        text = BeautifulSoup(data, "html.parser").get_text()
        trans_table = str.maketrans(string.punctuation, " " * len(string.punctuation))
        # NOTE(review): splitting on a single space keeps empty tokens for
        # consecutive spaces and does not split on newlines/tabs. Left as-is
        # because changing it would alter the text sent to the model.
        parts = text.translate(trans_table).lower().strip().split(" ")
        # Slicing is a no-op when there are fewer than MAX_WORDS tokens.
        return " ".join(parts[: self.MAX_WORDS])

    def postprocess(self, data):
        """Postprocess data.

        Gives spam and ham probability.

        :param data: Decoded response; the first two values of
            ``data["outputs"][0]["data"]`` are read.
        :returns: Dict with ``"spam"`` and ``"ham"`` keys.
        """
        scores = data["outputs"][0]["data"]
        return {
            "spam": scores[0],
            "ham": scores[1],
        }

    def _send_request_kubeflow(self, data):
        """Send predict request to Kubeflow.

        :param data: Preprocessed text to classify.
        :returns: The decoded JSON body of the response.
        :raises ValueError: If no model version was provided.
        :raises requests.RequestException: If the request fails, times out or
            the service answers with a non-200 status code.
        """
        if self.version is None:
            # The version is part of the model reference used in URL and host.
            raise ValueError("Model version is required.")

        payload = {
            "inputs": [
                {
                    "name": "input-0",
                    "shape": [1],
                    "datatype": "BYTES",
                    "data": [f"{data}"],
                }
            ]
        }
        model_ref = f"{self.MODEL_NAME}-{self.version}"
        url = current_app.config.get("ML_KUBEFLOW_MODEL_URL").format(model_ref)
        host = current_app.config.get("ML_KUBEFLOW_MODEL_HOST").format(model_ref)
        access_token = current_app.config.get("ML_KUBEFLOW_TOKEN")
        r = requests.post(
            url,
            headers={
                "Authorization": f"Bearer {access_token}",
                "Content-Type": "application/json",
                "Host": host,
            },
            json=payload,
            timeout=self.REQUEST_TIMEOUT,
        )
        if r.status_code != 200:
            # Attach the response (not the request) so handlers can inspect it.
            raise requests.RequestException(
                "Prediction was not successful.", response=r
            )
        return r.json()

    def predict(self, data):
        """Get prediction from model.

        :param data: Preprocessed text.
        :returns: The decoded response of the Kubeflow prediction request.
        """
        return self._send_request_kubeflow(data)

site/zenodo_rdm/ml/proxies.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# -*- coding: utf-8 -*-
2+
#
3+
# Copyright (C) 2024 CERN.
4+
#
5+
# ZenodoRDM is free software; you can redistribute it and/or modify
6+
# it under the terms of the MIT License; see LICENSE file for more details.
7+
"""Proxy objects for easier access to application objects."""
8+
9+
from flask import current_app
10+
from werkzeug.local import LocalProxy
11+
12+
# Lazily resolves to ``current_app.extensions["zenodo-ml"]`` each time it is used.
current_ml_models = LocalProxy(lambda: current_app.extensions["zenodo-ml"])

0 commit comments

Comments
 (0)