Commit 4cec393
added training classifier
1 parent db4cdfa commit 4cec393

File tree: 13 files changed, +490 -67 lines

app.py

Lines changed: 57 additions & 2 deletions
@@ -8,10 +8,11 @@
 from datetime import timezone
 from db import get_database
 import time
+from functools import wraps
 
 app = Flask(__name__, static_folder="static", template_folder="templates")
-app.config['SECRET_KEY'] = os.getenv('SECRET_KEY', 'default-secret')
 app.db = get_database()
+SECRET_KEY = os.getenv("POSTS_SECRET_KEY", "******")
 
 # Logging
 app.logger = get_logger("app")
@@ -24,6 +25,15 @@
 # test log
 register_middlewares(app)
 
+def require_secret_key(f):
+    @wraps(f)
+    def decorated(*args, **kwargs):
+        key = request.headers.get("X-SECRET-KEY")
+        if key != SECRET_KEY:
+            return jsonify({"status": "error", "message": "Unauthorized"}), 401
+        return f(*args, **kwargs)
+    return decorated
+
 @app.route("/")
 def index():
     return render_template("index.html", time=time.time)
@@ -48,6 +58,10 @@ def subscribe():
     techteams = data.get('techteams')
     individuals = data.get('individuals')
     communities = data.get('communities')
+    frequency = data.get('frquency')
+
+    if not frequency:
+        frequency = 3
 
     if not email or not topic or (not techteams and not individuals and not communities):
         return jsonify({"status": "error", "message": "Missing email or topic or publisher"
@@ -68,7 +82,7 @@ def subscribe():
 
     existing_subscriptions = app.db.get_subscriptions_by_email(conn, email)
     if not any(sub["publisher"]["id"] == publisher["id"] and sub["topic"] == topic for sub in existing_subscriptions):
-        app.db.add_subscription(conn, email, topic, publisher['id'])
+        app.db.add_subscription(conn, email, topic, publisher['id'], frequency=frequency)
 
     conn.commit()
     return jsonify({
@@ -163,6 +177,47 @@ def robots_txt():
 def sitemap_xml():
     return send_from_directory(app.static_folder, "sitemap.xml")
 
+@app.route("/postview.html")
+def postview():
+    return send_from_directory(app.template_folder, "posts.html")
+
+@app.route("/posts", methods=["GET"])
+@require_secret_key
+def get_posts():
+    conn = app.db.get_connection()
+    try:
+        posts = app.db.get_posts(conn)
+        result = []
+        for post in posts:
+            result.append({
+                "id": post["id"],
+                "url": post["url"],
+                "title": post["title"],
+                "topic": post["topic"],
+                "labelled": post['labelled']
+            })
+        return jsonify(result)
+    finally:
+        conn.close()
+
+@app.route("/posts/<int:post_id>", methods=["PATCH"])
+def update_post(post_id):
+    key = request.headers.get("X-SECRET-KEY")
+    if key != SECRET_KEY:
+        return jsonify({"status": "error", "message": "Unauthorized"}), 401
+
+    data = request.get_json()
+    topic = data.get("topic")
+    if not topic:
+        return jsonify({"status": "error", "message": "No topic provided"}), 400
+
+    conn = app.db.get_connection()
+    try:
+        app.db.update_post_label(conn, post_id, topic)
+        return jsonify({"status": "success", "message": f"Post {post_id} updated"})
+    finally:
+        conn.close()
+
 if __name__ == "__main__":
     if os.getenv("FLASK_ENV") == "Production":
         app.run()
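
Usage note: a minimal sketch of how the new key-protected endpoints might be exercised from a client. The base URL, the post id, and the topic value are illustrative assumptions, not part of the commit.

import os
import requests

# Assumed local dev server; the key must match POSTS_SECRET_KEY on the server.
BASE = "http://localhost:5000"
KEY = os.getenv("POSTS_SECRET_KEY", "******")

# GET /posts is guarded by @require_secret_key; a missing or wrong
# X-SECRET-KEY header yields a 401.
resp = requests.get(f"{BASE}/posts", headers={"X-SECRET-KEY": KEY})
print(resp.status_code, resp.json() if resp.ok else resp.text)

# PATCH /posts/<id> relabels a post; id 1 and the topic string are hypothetical.
resp = requests.patch(
    f"{BASE}/posts/1",
    headers={"X-SECRET-KEY": KEY},
    json={"topic": "software_engineering"},  # assumed category value
)
print(resp.status_code, resp.json())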

classifier.py

Lines changed: 52 additions & 29 deletions
@@ -1,11 +1,19 @@
+# classifier_model.py
+import os
 from sentence_transformers import SentenceTransformer, util
 from db import enums
-import torch
+import pickle
+from logger_config import get_logger
 
-# 1. Load a stronger model
-model = SentenceTransformer('all-mpnet-base-v2')
+env = os.getenv('FLASK_ENV', 'development')
+MODEL_PATH = os.getenv("MODEL_PATH") if env == 'production' else 'data/dev/trained_classifier.pkl'
+CONFIDENCE_THRESHOLD = 0.7
 
-# 2. Concise category descriptions
+logger = get_logger("classifier")
+# Load embedding model
+embedding_model = SentenceTransformer('all-mpnet-base-v2')
+
+# Category descriptions
 categories = {
     enums.PublisherCategory.SOFTWARE_ENGINEERING.value: (
         "frontend, backend, APIs, microservices, databases, relational databases, cloud databases, DevOps, system design, CI/CD, containers, scalability, performance, distributed systems, mobile, UI/UX"
@@ -27,17 +35,17 @@
     )
 }
 
-# 3. Encode category descriptions
+# Precompute embeddings for baseline
 category_embeddings = {
-    cat: model.encode(desc, convert_to_tensor=True)
+    cat: embedding_model.encode(desc, convert_to_tensor=True)
     for cat, desc in categories.items()
 }
 
-# Optional: simple keyword mapping to override embeddings
+# Optional keyword mapping
 keywords_map = {
     enums.PublisherCategory.SOFTWARE_ENGINEERING.value: [
-        "react", "angular", "vue", "node.js", "django", "java", "go",
-        "microservices", "api", "devops", "kubernetes",
+        "react", "angular", "vue", "node.js", "django", "java", "go",
+        "microservices", "api", "devops", "kubernetes",
         "aurora", "rds", "cloud database", "postgresql", "mysql", "mongodb", "redis", "database"
     ],
     enums.PublisherCategory.SOFTWARE_TESTING.value: [
@@ -54,41 +62,56 @@
     ]
 }
 
-def classify_post(post_title, tags="", content=""):
-    """
-    Classify a post into a category using title + tags + first 100 chars of content.
-    Uses embeddings similarity with optional keyword boost.
-    """
+# ===== Load trained classifier if exists =====
+trained_clf = None
+label_encoder = None
+os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
 
-    # 1. Prepare text
-    content_snippet = content[:100] if content else ""
-    combined_text = f"Title: {post_title}. Tags: {tags}. Content: {content_snippet}".lower()
+if os.path.exists(MODEL_PATH):
+    with open(MODEL_PATH, "rb") as f:
+        trained_clf, label_encoder = pickle.load(f)
+    logger.info(f"[Classifier] Loaded trained classifier from {MODEL_PATH}")
 
-    # 2. Encode input
-    text_embedding = model.encode(combined_text, convert_to_tensor=True)
+# ===== Baseline classifier =====
+def classify_with_embeddings(title, tags="", content=""):
+    content_snippet = content[:100] if content else ""
+    combined_text = f"Title: {title}. Tags: {tags}. Content: {content_snippet}".lower()
+    text_embedding = embedding_model.encode(combined_text, convert_to_tensor=True)
 
-    # 3. Compute similarity
-    scores = {
-        cat: util.cos_sim(text_embedding, emb).item()
-        for cat, emb in category_embeddings.items()
-    }
+    scores = {cat: util.cos_sim(text_embedding, emb).item()
+              for cat, emb in category_embeddings.items()}
 
-    # 4. Keyword boost: add 0.1 if a keyword exists in title/tags/content
     combined_lower = combined_text.lower()
     for cat, kw_list in keywords_map.items():
         for kw in kw_list:
             if kw in combined_lower:
                 scores[cat] += 0.1
                 break
 
-    # 5. Assign category with highest similarity
     best_cat = max(scores, key=scores.get)
-
-    # 6. Adaptive fallback: check relative score
     sorted_scores = sorted(scores.values(), reverse=True)
     top_score = sorted_scores[0]
     second_score = sorted_scores[1] if len(sorted_scores) > 1 else 0.0
     if top_score < 0.25 or (top_score - second_score) < 0.05:
         return enums.PublisherCategory.GENERAL.value
-
     return best_cat
+
+# ===== Unified classifier =====
+def classify_post(title, tags="", content=""):
+    global trained_clf, label_encoder
+
+    if trained_clf and label_encoder:
+        logger.info("Attempt to use trained classifier")
+
+        text_embedding = embedding_model.encode(f"Title: {title}. Tags: {tags}. Content: {content[:100]}")
+        pred_proba = trained_clf.predict_proba([text_embedding])[0]
+        max_prob = pred_proba.max()
+        if max_prob >= CONFIDENCE_THRESHOLD:
+            logger.info("Good confidence score with trained classifier")
+            pred_label = trained_clf.predict([text_embedding])[0]
+            return label_encoder.inverse_transform([pred_label])[0]
+        else:
+            logger.info(f"Fallback to normal without trained mode due to low confidence: {max_prob}")
+
+    # fallback
+    return classify_with_embeddings(title, tags, content)
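
Training note: the pickled (trained_clf, label_encoder) pair consumed above is produced elsewhere in this commit; the training script is not part of this diff. A minimal sketch of how such a file could be built, assuming scikit-learn, the same embedding model, and hypothetical labelled samples:

import os
import pickle
from sentence_transformers import SentenceTransformer
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import LabelEncoder

# Hypothetical labelled data: (combined text, category value) pairs.
samples = [
    ("Title: Scaling Postgres. Tags: database.", "software_engineering"),
    ("Title: Fuzzing 101. Tags: testing.", "software_testing"),
]

embedding_model = SentenceTransformer("all-mpnet-base-v2")
X = embedding_model.encode([text for text, _ in samples])

label_encoder = LabelEncoder()
y = label_encoder.fit_transform([label for _, label in samples])

# Any estimator with predict/predict_proba satisfies the classify_post contract.
trained_clf = LogisticRegression(max_iter=1000).fit(X, y)

os.makedirs("data/dev", exist_ok=True)
with open("data/dev/trained_classifier.pkl", "wb") as f:
    pickle.dump((trained_clf, label_encoder), f)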

db/sqlite.py

Lines changed: 23 additions & 6 deletions
@@ -230,15 +230,32 @@ def get_notifications_by_email(self, conn, email):
         """, (email,))
         rows = c.fetchall()
         return [dict(row) for row in rows]
+
+    def get_active_notifications_by_email_and_url(self, conn, email, url):
+        c = conn.cursor()
+        c.execute("""
+            SELECT *
+            FROM notifications
+            WHERE email = ? and post_url = ? and deleted=0
+        """, (email,url))
+        row = c.fetchone()
+        return dict(row) if row else None
 
     def add_notification(self, conn, email, heading, style_version, post_url, post_title, maturity_date):
         logger.info(f"Adding notification: {email}, type: {post_title}")
-        c = conn.cursor()
-        c.execute("""
-            INSERT INTO notifications (email, heading, style_version, post_url, post_title, maturity_date)
-            VALUES (?, ?, ?, ?, ?, ?)
-        """, (email, heading, style_version, post_url, post_title, maturity_date))
-        logger.info("notification added successfully!")
+
+        notf = self.get_active_notifications_by_email_and_url(conn, email, post_url)
+
+        if not notf:
+            c = conn.cursor()
+            c.execute("""
+                INSERT INTO notifications (email, heading, style_version, post_url, post_title, maturity_date)
+                VALUES (?, ?, ?, ?, ?, ?)
+            """, (email, heading, style_version, post_url, post_title, maturity_date))
+            logger.info("notification added successfully!")
+        else:
+            logger.info("notification already existed!")
+
 
     def delete_notification(self, conn, email, post_url):
         logger.info(f"Deleting notification: {email}, url: {post_url}")

handlers/base.py

Lines changed: 24 additions & 22 deletions
@@ -1,4 +1,3 @@
-from datetime import datetime
 import feedparser
 from datetime import timezone
 import ssl
@@ -8,7 +7,7 @@
 
 HEADERS = {'User-Agent': 'Mozilla/5.0'}
 
-logger = get_logger("handlers")
+logger = get_logger("base-handler")
 
 class BaseScraper:
     def get_feed_url(self):
@@ -39,25 +38,28 @@ def search_blog_posts(self, category, last_scan_time):
 
             try:
                 # Parse published date using the correct format
-                published = parsedate_to_datetime(entry.published)
-            except ValueError as e:
-                logger.error(f"Date parse error: {entry.published} -> {e}")
-                continue
-
-            if last_scan_time.tzinfo is None:
-                last_scan_time = last_scan_time.replace(tzinfo=timezone.utc)
-            if published <= last_scan_time:
-                logger.debug(f"Skipping {entry.title}: article published on {published} before last scan time: {last_scan_time}")
-                continue
-
-            # full_content = entry.content[0].value if entry.content else ""
-            # content = full_content[:100] # truncate to first 100 chars
-
-            matching_posts.append({
-                "title": entry.title,
-                "url": entry.link,
-                "published": published.isoformat(),
-                "tags": categories
-            })
+                published = None
+                if hasattr(entry, "published"):
+                    published = parsedate_to_datetime(entry.published)
+                elif hasattr(entry, "updated"):
+                    published = parsedate_to_datetime(entry.updated)
+
+                if published is None:
+                    published = self.get_date_from_url(entry)
+
+                if last_scan_time.tzinfo is None:
+                    last_scan_time = last_scan_time.replace(tzinfo=timezone.utc)
+                if published <= last_scan_time:
+                    logger.debug(f"Skipping {entry.title}: article published on {published} before last scan time: {last_scan_time}")
+                    continue
+
+                matching_posts.append({
+                    "title": entry.title,
+                    "url": entry.link,
+                    "published": published.isoformat(),
+                    "tags": categories
+                })
+            except Exception:
+                logger.exception(f"Date parse error: {entry}")
 
         return matching_posts
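
Behaviour note: a small sanity check of the new fallback order (entry.published, then entry.updated, then self.get_date_from_url). FakeEntry is hypothetical scaffolding, and the RFC 2822 dates are the format most RSS feeds emit.

from email.utils import parsedate_to_datetime

class FakeEntry:
    def __init__(self, **fields):
        self.__dict__.update(fields)

with_published = FakeEntry(published="Mon, 18 Aug 2025 12:00:00 GMT")
updated_only = FakeEntry(updated="Thu, 24 Jul 2025 09:30:00 GMT")

# Both parse to timezone-aware datetimes, so the comparison against
# last_scan_time is safe once tzinfo is normalised.
print(parsedate_to_datetime(with_published.published))  # 2025-08-18 12:00:00+00:00
print(parsedate_to_datetime(updated_only.updated))      # 2025-07-24 09:30:00+00:00

# An entry with neither attribute falls through to get_date_from_url(entry),
# which subclasses such as GoogleScraper override (see handlers/google.py).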

handlers/google.py

Lines changed: 65 additions & 3 deletions
@@ -1,8 +1,70 @@
 from .base import BaseScraper
+from datetime import datetime
+from datetime import timezone
+from logger_config import get_logger
+import requests
+from bs4 import BeautifulSoup
 
-BASE_URL = "https://blog.google/rss/"
+BASE_URL = "https://developers.googleblog.com/rss"
+
+logger = get_logger("google-handler")
 
 class GoogleScraper(BaseScraper):
-
+
     def get_feed_url(self):
-        return BASE_URL
+        return BASE_URL
+
+    def parse_google_blog_date(self, date_str: str):
+        """
+        Parse Google Developers Blog dates like:
+        - 'AUG. 18, 2025'
+        - 'JULY 24, 2025'
+        Returns a timezone-aware datetime in UTC.
+        """
+        if not date_str:
+            return None
+
+        # Clean string: remove dot, normalize case
+        clean_str = date_str.replace('.', '').title()  # 'Aug 18, 2025' or 'July 24, 2025'
+
+        # Try full month name first (%B), then abbreviated (%b)
+        for fmt in ("%B %d, %Y", "%b %d, %Y"):
+            try:
+                dt = datetime.strptime(clean_str, fmt)
+                return dt.replace(tzinfo=timezone.utc)
+            except ValueError:
+                continue
+
+        logger.warning(f"Unable to parse date from Google blog: '{date_str}'")
+        return None
+
+    def get_date_from_url(self, entry):
+        """Fetch published date from Google Developers Blog post HTML."""
+        try:
+            url = entry.link
+            title = entry.title
+            resp = requests.get(url, timeout=5)
+            if resp.status_code != 200:
+                logger.warning(f"Non-200 response for {title}: {resp.status_code}")
+                return None
+
+            soup = BeautifulSoup(resp.text, "html.parser")
+
+            # Target the div with class "published-date glue-font-weight-medium"
+            div_date = soup.find("div", class_="published-date glue-font-weight-medium")
+            if div_date and div_date.text.strip():
+                published_text = div_date.text.strip()
+                logger.info(f"Published date for title {title}: {published_text}")
+                return self.parse_google_blog_date(published_text)
+
+            # fallback to HTTP Last-Modified header
+            last_mod = resp.headers.get("Last-Modified")
+            if last_mod:
+                logger.info(f"Published date for title {title}: {last_mod}")
+                return self.parse_google_blog_date(last_mod)
+
+        except Exception:
+            logger.exception(f"Failed to get published date from {url}")
+
+        logger.info(f"No published date found for {url}")
+        return None
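
For reference, the normalisation inside parse_google_blog_date behaves as in this standalone sketch, with the parsing loop inlined and the sample dates taken from the docstring above:

from datetime import datetime, timezone

for raw in ("AUG. 18, 2025", "JULY 24, 2025"):
    clean = raw.replace('.', '').title()  # 'Aug 18, 2025' / 'July 24, 2025'
    for fmt in ("%B %d, %Y", "%b %d, %Y"):
        try:
            dt = datetime.strptime(clean, fmt).replace(tzinfo=timezone.utc)
            print(raw, "->", dt.isoformat())  # e.g. 2025-08-18T00:00:00+00:00
            break
        except ValueError:
            continue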
