Skip to content

Commit cb78ce0

Browse files
authored
feat: support rss datasource (#13721)
### What problem does this PR solve? Adds support for public RSS/Atom feed URLs as data sources for RagFlow. Links issue #12313. ### Type of change - [x] New Feature (non-breaking change which adds functionality)
1 parent f32a832 commit cb78ce0

File tree

11 files changed

+395
-1
lines changed

11 files changed

+395
-1
lines changed

common/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ class ParserType(StrEnum):
114114
class FileSource(StrEnum):
115115
LOCAL = ""
116116
KNOWLEDGEBASE = "knowledgebase"
117+
RSS = "rss"
117118
S3 = "s3"
118119
NOTION = "notion"
119120
DISCORD = "discord"

common/data_source/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
"""
2525

2626
from .blob_connector import BlobStorageConnector
27+
from .rss_connector import RSSConnector
2728
from .slack_connector import SlackConnector
2829
from .gmail_connector import GmailConnector
2930
from .notion_connector import NotionConnector
@@ -55,6 +56,7 @@
5556

5657
__all__ = [
5758
"BlobStorageConnector",
59+
"RSSConnector",
5860
"SlackConnector",
5961
"GmailConnector",
6062
"NotionConnector",

common/data_source/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class BlobType(str, Enum):
4040

4141
class DocumentSource(str, Enum):
4242
"""Document sources"""
43+
RSS = "rss"
4344
S3 = "s3"
4445
NOTION = "notion"
4546
R2 = "r2"
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
import hashlib
import ipaddress
import socket
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime
from time import struct_time
from typing import Any
from urllib.parse import urljoin, urlparse

import bs4
import feedparser
import requests

from common.data_source.config import INDEX_BATCH_SIZE, REQUEST_TIMEOUT_SECONDS, DocumentSource
from common.data_source.interfaces import LoadConnector, PollConnector
from common.data_source.models import Document, GenerateDocumentsOutput, SecondsSinceUnixEpoch
17+
18+
19+
def _is_private_ip(ip: str) -> bool:
20+
try:
21+
ip_obj = ipaddress.ip_address(ip)
22+
return ip_obj.is_private or ip_obj.is_link_local or ip_obj.is_loopback
23+
except ValueError:
24+
return False
25+
26+
27+
def _validate_url_no_ssrf(url: str) -> None:
28+
parsed = urlparse(url)
29+
hostname = parsed.hostname
30+
if not hostname:
31+
raise ValueError("URL must have a valid hostname")
32+
33+
try:
34+
ip = socket.gethostbyname(hostname)
35+
if _is_private_ip(ip):
36+
raise ValueError(f"URL resolves to private/internal IP address: {ip}")
37+
except socket.gaierror as e:
38+
raise ValueError(f"Failed to resolve hostname: {hostname}") from e
39+
40+
41+
class RSSConnector(LoadConnector, PollConnector):
    """Connector that ingests documents from a public RSS/Atom feed URL.

    Supports full loads (``load_from_state``) and incremental polling
    (``poll_source``). The parsed feed is cached per connector instance since
    settings validation and loading both need it.
    """

    def __init__(self, feed_url: str, batch_size: int = INDEX_BATCH_SIZE) -> None:
        self.feed_url = feed_url.strip()
        self.batch_size = batch_size
        self.credentials: dict[str, Any] = {}
        # Parsed feedparser result, cached after the first successful fetch.
        self._cached_feed: Any | None = None

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        """Store credentials for interface parity; public feeds need none."""
        self.credentials = credentials or {}
        return None

    def validate_connector_settings(self) -> None:
        """Validate the URL and batch size, then fetch the feed requiring entries.

        Raises:
            ValueError: invalid URL, non-positive batch size, unparseable feed,
                or a feed with no entries.
        """
        self._validate_feed_url()
        if self.batch_size < 1:
            raise ValueError("batch_size must be greater than 0")
        self._read_feed(require_entries=True)

    def load_from_state(self) -> GenerateDocumentsOutput:
        """Yield every feed entry as batches of Documents."""
        yield from self._load_entries()

    def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> GenerateDocumentsOutput:
        """Yield only entries whose timestamp falls in the (start, end] window."""
        yield from self._load_entries(start=start, end=end)

    def _load_entries(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
    ) -> GenerateDocumentsOutput:
        """Convert feed entries to Documents, time-filtered, in batches of batch_size."""
        feed = self._read_feed(require_entries=False)
        batch: list[Document] = []

        for entry in feed.entries:
            updated_at = self._resolve_entry_time(entry)
            ts = updated_at.timestamp()

            # Half-open window (start, end]: an entry at exactly `start` was
            # already delivered by the previous poll.
            if start is not None and ts <= start:
                continue
            if end is not None and ts > end:
                continue

            batch.append(self._build_document(entry, updated_at))

            if len(batch) >= self.batch_size:
                yield batch
                batch = []

        if batch:
            yield batch

    def _validate_feed_url(self) -> None:
        """Reject empty, non-HTTP(S), or private/internal-resolving feed URLs."""
        if not self.feed_url:
            raise ValueError("feed_url is required")

        parsed = urlparse(self.feed_url)
        if parsed.scheme not in {"http", "https"} or not parsed.netloc:
            raise ValueError("feed_url must be a valid http or https URL")

        _validate_url_no_ssrf(self.feed_url)

    def _read_feed(self, require_entries: bool) -> Any:
        """Fetch and parse the feed, caching the result.

        Redirects are followed manually so every hop is SSRF-validated BEFORE
        it is requested — with ``allow_redirects=True`` a redirect to a
        private host would already have been contacted by the time the final
        URL could be checked.

        Raises:
            ValueError: feed cannot be parsed, or (when ``require_entries``)
                the feed has no entries, or a redirect target fails SSRF checks.
            requests.HTTPError: non-2xx response from the feed server.
        """
        if self._cached_feed is not None:
            if require_entries and not self._cached_feed.entries:
                raise ValueError("RSS feed contains no entries")
            return self._cached_feed

        self._validate_feed_url()

        response = self._fetch_with_validated_redirects(self.feed_url)
        response.raise_for_status()

        feed = feedparser.parse(response.content)
        if getattr(feed, "bozo", False) and not feed.entries:
            error = getattr(feed, "bozo_exception", None)
            if error:
                raise ValueError(f"Failed to parse RSS feed: {error}") from error
            raise ValueError("Failed to parse RSS feed")
        if require_entries and not feed.entries:
            raise ValueError("RSS feed contains no entries")

        self._cached_feed = feed
        return feed

    @staticmethod
    def _fetch_with_validated_redirects(url: str, max_redirects: int = 5):
        """GET *url*, SSRF-validating each redirect target before following it."""
        for _ in range(max_redirects + 1):
            response = requests.get(url, timeout=REQUEST_TIMEOUT_SECONDS, allow_redirects=False)
            if not response.is_redirect and not response.is_permanent_redirect:
                return response
            location = response.headers.get("Location")
            if not location:
                # Malformed redirect with no target: let the caller's
                # raise_for_status surface the 3xx status.
                return response
            url = urljoin(url, location)
            _validate_url_no_ssrf(url)
        raise ValueError(f"Too many redirects while fetching RSS feed: {url}")

    def _build_document(self, entry: Any, updated_at: datetime) -> Document:
        """Build a Document from one feed entry; ID is stable across re-syncs."""
        link = (entry.get("link") or "").strip()
        title = (entry.get("title") or "").strip()
        # Prefer the feed-provided GUID, then link/title, so re-fetching the
        # same entry produces the same document ID.
        stable_key = (entry.get("id") or link or title or self.feed_url).strip()
        semantic_identifier = title or link or stable_key
        content = self._build_content(entry, semantic_identifier)
        blob = content.encode("utf-8")

        metadata: dict[str, Any] = {"feed_url": self.feed_url}
        if link:
            metadata["link"] = link
        author = entry.get("author")
        if author:
            metadata["author"] = author

        categories = []
        for tag in entry.get("tags", []):
            if not isinstance(tag, dict):
                continue
            term = tag.get("term")
            if isinstance(term, str) and term:
                categories.append(term)
        if categories:
            metadata["categories"] = categories

        return Document(
            # MD5 used only as a stable fingerprint, not for security.
            id=f"rss:{hashlib.md5(stable_key.encode('utf-8')).hexdigest()}",
            source=DocumentSource.RSS,
            semantic_identifier=semantic_identifier,
            extension=".txt",
            blob=blob,
            doc_updated_at=updated_at,
            size_bytes=len(blob),
            metadata=metadata,
        )

    def _build_content(self, entry: Any, semantic_identifier: str) -> str:
        """Return title plus entry body text, HTML stripped, joined by blank lines."""
        parts = [semantic_identifier]
        content_blocks = entry.get("content") or []

        for block in content_blocks:
            value = block.get("value") if isinstance(block, dict) else None
            normalized = self._normalize_text(value)
            if normalized:
                parts.append(normalized)

        # No usable <content> blocks: fall back to summary/description.
        if len(parts) == 1:
            fallback = entry.get("summary") or entry.get("description") or ""
            normalized = self._normalize_text(fallback)
            if normalized:
                parts.append(normalized)

        return "\n\n".join(part for part in parts if part).strip()

    def _resolve_entry_time(self, entry: Any) -> datetime:
        """Best-effort UTC timestamp for an entry; falls back to now().

        Prefers feedparser's pre-parsed struct_time fields, then RFC 2822
        date strings, then the current time so filtering never crashes.
        """
        for field in ("updated_parsed", "published_parsed"):
            value = entry.get(field)
            if value:
                return self._struct_time_to_utc(value)

        for field in ("updated", "published"):
            value = entry.get(field)
            if isinstance(value, str) and value.strip():
                try:
                    parsed = parsedate_to_datetime(value)
                except (TypeError, ValueError, IndexError):
                    continue
                if parsed.tzinfo is None:
                    # Naive timestamps are assumed UTC — TODO confirm against
                    # the feeds this is deployed on.
                    parsed = parsed.replace(tzinfo=timezone.utc)
                return parsed.astimezone(timezone.utc)

        return datetime.now(timezone.utc)

    @staticmethod
    def _normalize_text(value: Any) -> str:
        """Strip HTML markup from *value*, returning plain newline-joined text."""
        if not isinstance(value, str):
            return ""
        return bs4.BeautifulSoup(value, "html.parser").get_text("\n", strip=True)

    @staticmethod
    def _struct_time_to_utc(value: struct_time | tuple[Any, ...]) -> datetime:
        """Convert a feedparser struct_time (already UTC per feedparser docs) to datetime."""
        dt = datetime(*value[:6], tzinfo=timezone.utc)
        return dt.astimezone(timezone.utc)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ dependencies = [
3131
"editdistance==0.8.1",
3232
"elasticsearch-dsl==8.12.0",
3333
"exceptiongroup>=1.3.0,<2.0.0",
34+
"feedparser>=6.0.11,<7.0.0",
3435
"extract-msg>=0.39.0",
3536
"ffmpeg-python>=0.2.0",
3637
"flasgger>=0.9.7.1,<0.10.0",

rag/svr/sync_data_source.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from common.config_utils import show_configs
4444
from common.data_source import (
4545
BlobStorageConnector,
46+
RSSConnector,
4647
NotionConnector,
4748
DiscordConnector,
4849
GoogleDriveConnector,
@@ -243,6 +244,26 @@ class GOOGLE_CLOUD_STORAGE(_BlobLikeBase):
243244
DEFAULT_BUCKET_TYPE: str = "google_cloud_storage"
244245

245246

247+
class RSS(SyncBase):
    """Sync adapter that pulls documents from a public RSS/Atom feed."""

    SOURCE_NAME: str = FileSource.RSS

    async def _generate(self, task: dict):
        """Build and validate the connector, then return a document generator.

        A full load is used on explicit reindex or when no previous poll
        window exists; otherwise the feed is polled from the last poll start
        up to the current time.
        """
        connector = RSSConnector(
            feed_url=self.conf["feed_url"],
            batch_size=self.conf.get("batch_size", INDEX_BATCH_SIZE),
        )
        self.connector = connector
        connector.load_credentials(self.conf.get("credentials", {}))
        connector.validate_connector_settings()

        needs_full_load = task["reindex"] == "1" or not task["poll_range_start"]
        if needs_full_load:
            return connector.load_from_state()

        window_start = task["poll_range_start"].timestamp()
        window_end = datetime.now(timezone.utc).timestamp()
        return connector.poll_source(window_start, window_end)
265+
266+
246267
class Confluence(SyncBase):
247268
SOURCE_NAME: str = FileSource.CONFLUENCE
248269

@@ -1347,6 +1368,7 @@ async def _generate(self, task: dict):
13471368

13481369

13491370
func_factory = {
1371+
FileSource.RSS: RSS,
13501372
FileSource.S3: S3,
13511373
FileSource.R2: R2,
13521374
FileSource.OCI_STORAGE: OCI_STORAGE,

0 commit comments

Comments
 (0)