-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathtest_scraper.py
More file actions
157 lines (117 loc) · 4.8 KB
/
test_scraper.py
File metadata and controls
157 lines (117 loc) · 4.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
from src.database_manager import DatabaseManager
from src.scraper import Scraper
class DummyDB(DatabaseManager):
    """Stand-in for ``DatabaseManager`` that stores nothing.

    Both ``__init__`` and ``__del__`` are overridden with no-ops, so none of
    the parent class's construction or teardown logic ever runs (presumably
    database connection handling — not visible from this file).
    """

    def __init__(self):
        """Deliberately skip ``DatabaseManager.__init__``."""

    def __del__(self):
        """Deliberately skip any parent-class teardown."""

    def insert_link(self, url, visited=False):
        """Pretend ``url`` was stored; always reports success."""
        return True

    def get_unvisited_links(self):
        """Report that nothing remains to be crawled."""
        return []

    def mark_link_visited(self, url):
        """Accept the visit notification and ignore it."""
        return None
def test_is_valid_link():
    """Same-host URLs are accepted; excluded paths and foreign hosts are not."""
    scraper = Scraper(
        base_url='https://example.com',
        exclude_patterns=['/exclude'],
        db_manager=DummyDB(),
    )

    assert scraper.is_valid_link('https://example.com/page')

    # An excluded path on the same host, then a different host entirely.
    for rejected in ('https://example.com/exclude/page', 'https://other.com/'):
        assert not scraper.is_valid_link(rejected)
def test_fetch_links():
    """Absolute and root-relative hrefs are collected; excluded paths are dropped.

    The root-relative ``/page2`` is expected back as an absolute URL on the
    base host, and the ``/exclude/...`` link must not appear.
    """
    scraper = Scraper(
        base_url='https://example.com',
        exclude_patterns=['/exclude'],
        db_manager=DummyDB(),
    )
    html = '''<html><body>
<a href="https://example.com/page1">1</a>
<a href="/page2">2</a>
<a href="https://example.com/exclude/hidden">3</a>
</body></html>'''
    expected = {'https://example.com/page1', 'https://example.com/page2'}

    found = scraper.fetch_links(url='https://example.com', html=html)

    assert found == expected
from unittest.mock import patch, MagicMock
...
@patch('os.remove')
@patch('tempfile.NamedTemporaryFile')
def test_scrape_page_parses_content_and_metadata(mock_tempfile, mock_os_remove):
    """``scrape_page`` returns the converted body text and the page ``<title>``.

    ``MarkItDown`` is patched where ``src.scraper`` looks it up, and
    ``tempfile.NamedTemporaryFile`` / ``os.remove`` are patched so that any
    temp-file handling the scraper does (presumed, not visible here) stays
    in memory.
    """
    fake_tmp = MagicMock()
    # Must be an attribute assignment: MagicMock(name=...) sets the repr name.
    fake_tmp.name = "dummy_path"
    # Object yielded by ``with tempfile.NamedTemporaryFile(...) as f``.
    mock_tempfile.return_value.__enter__.return_value = fake_tmp

    scraper = Scraper(base_url='http://example.com', exclude_patterns=[], db_manager=DummyDB())
    page_html = '<html><head><title>Test</title></head><body><p>Hello</p></body></html>'

    with patch('src.scraper.MarkItDown') as markitdown_cls:
        converter = markitdown_cls.return_value
        converter.convert.return_value = "Hello"
        content, metadata = scraper.scrape_page(page_html, 'http://example.com/test')

    assert 'Hello' in content
    assert metadata.get('title') == 'Test'
@patch('os.remove')
@patch('tempfile.NamedTemporaryFile')
def test_scrape_page_with_markitdown(mock_tempfile, mock_os_remove):
    """The converter's Markdown output is returned verbatim as the page content.

    The title metadata is still expected to come from the HTML ``<title>``,
    not from the ``<h1>`` heading.
    """
    fake_tmp = MagicMock()
    # Must be an attribute assignment: MagicMock(name=...) sets the repr name.
    fake_tmp.name = "dummy_path"
    # Object yielded by ``with tempfile.NamedTemporaryFile(...) as f``.
    mock_tempfile.return_value.__enter__.return_value = fake_tmp

    scraper = Scraper(base_url='http://example.com', exclude_patterns=[], db_manager=DummyDB())
    page_html = '<html><head><title>Test</title></head><body><h1>A Title</h1><p>This is a paragraph with <strong>bold</strong> text.</p></body></html>'
    expected_markdown = "# A Title\n\nThis is a paragraph with **bold** text."

    with patch('src.scraper.MarkItDown') as markitdown_cls:
        converter = markitdown_cls.return_value
        converter.convert.return_value = expected_markdown
        content, metadata = scraper.scrape_page(page_html, 'http://example.com/test')

    assert content == expected_markdown
    assert metadata.get('title') == 'Test'
import requests
import tqdm
class ListDB(DummyDB):
    """In-memory database fake that actually records what the scraper does.

    Like ``DummyDB`` it never runs ``DatabaseManager.__init__``; instead it
    keeps plain Python containers that tests can inspect directly.
    """

    def __init__(self):
        self.pages = []       # (url, content, metadata) per insert_page call
        self.visited = set()  # URLs passed to mark_link_visited
        self.links = []       # known URLs, first-seen order, no duplicates

    def insert_link(self, url, visited=False):
        """Record ``url`` (a single URL or a list of URLs) if not already known.

        ``visited`` is accepted for signature compatibility but ignored.
        Returns True when at least one previously unknown URL was added.
        """
        batch = url if isinstance(url, list) else [url]
        added = False
        for candidate in batch:
            if candidate in self.links:
                continue
            self.links.append(candidate)
            added = True
        return added

    def get_unvisited_links(self):
        """Return not-yet-visited URLs as 1-tuples, in insertion order.

        The 1-tuple shape presumably mirrors rows from the real database.
        """
        pending = [link for link in self.links if link not in self.visited]
        return [(link,) for link in pending]

    def mark_link_visited(self, url):
        """Remember that ``url`` has been crawled."""
        self.visited.add(url)

    def get_links_count(self):
        """Number of distinct URLs ever inserted."""
        return len(self.links)

    def get_visited_links_count(self):
        """Number of distinct URLs marked visited."""
        return len(self.visited)

    def insert_page(self, url, content, metadata):
        """Append a scraped page record."""
        self.pages.append((url, content, metadata))

    def get_all_pages(self):
        """Return the live list of stored page records (not a copy)."""
        return self.pages
def test_start_scraping_process(monkeypatch):
    """Crawl a single start URL with network, parsing and progress bar stubbed.

    ``fetch_links`` discovers nothing and ``scrape_page`` returns canned
    output. After crawling the start URL, exactly one link should be known,
    that link should be marked visited, and one page should be stored under it.
    """
    db = ListDB()
    scraper = Scraper(base_url='http://example.com', exclude_patterns=[], db_manager=db)

    # Stubs tolerate extra positional/keyword arguments so optional arguments
    # added at call sites later do not break this test.
    monkeypatch.setattr(
        Scraper, 'fetch_links',
        lambda self, url, html=None, *args, **kwargs: set(),
    )
    monkeypatch.setattr(
        Scraper, 'scrape_page',
        lambda self, html, url, *args, **kwargs: ('# MD', {'url': url}),
    )

    class DummyResp:
        """Minimal successful HTML response."""
        status_code = 200
        # requests.Response.headers is case-insensitive; a plain dict would
        # make e.g. a 'Content-Type' lookup miss.
        headers = requests.structures.CaseInsensitiveDict({'content-type': 'text/html'})
        content = b'<html></html>'
        text = '<html></html>'

    # Accept timeout=, headers=, etc., as the real requests.get does.
    # NOTE: patching module attributes (requests.get, tqdm.tqdm) only affects
    # code that looks them up through the module at call time, not names
    # bound earlier via ``from ... import``.
    monkeypatch.setattr(requests, 'get', lambda url, *args, **kwargs: DummyResp())

    class DummyTqdm:
        """No-op progress bar mimicking the parts of tqdm's API used here."""

        def __init__(self, *args, **kwargs):
            self.total = kwargs.get('total', 0)

        def update(self, n=1):  # tqdm's default increment is 1
            pass

        def refresh(self):
            pass

        def close(self):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *exc_info):
            self.close()
            return False

    # The class itself is a drop-in callable for tqdm.tqdm(...).
    monkeypatch.setattr(tqdm, 'tqdm', DummyTqdm)

    scraper.start_scraping(url='http://example.com/page')

    assert db.get_links_count() == 1
    assert db.get_visited_links_count() == 1
    assert db.pages[0][0] == 'http://example.com/page'