Skip to content

Commit f3dc736

Browse files
committed
Add user authentication support to archive_sync command as well
The previous modification 8c1c2b0 added support for user authentication on the local webserver but missed the archive_sync command.
1 parent b93b132 commit f3dc736

File tree

3 files changed

+53
-45
lines changed

3 files changed

+53
-45
lines changed

pghoard/archive_sync.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
import os
1111
import sys
1212

13-
import requests
13+
from requests import Session
14+
from requests.auth import HTTPBasicAuth
1415
from rohmu.errors import InvalidConfigurationError
1516

1617
from pghoard.common import get_pg_wal_directory
@@ -34,19 +35,23 @@ def __init__(self):
3435
self.site = None
3536
self.backup_site = None
3637
self.base_url = None
38+
self.session = None
3739

3840
def set_config(self, config_file, site):
3941
self.config = config.read_json_config_file(config_file, check_commands=False)
4042
self.site = config.get_site_from_config(self.config, site)
4143
self.backup_site = self.config["backup_sites"][self.site]
4244
self.base_url = "http://127.0.0.1:{}/{}".format(self.config["http_port"], self.site)
45+
self.session = Session()
46+
if self.config.get("webserver_username") and self.config.get("webserver_password"):
47+
self.session.auth = HTTPBasicAuth(self.config["webserver_username"], self.config["webserver_password"])
4348

4449
def get_current_wal_file(self):
4550
# identify the (must be) local database
4651
return wal.get_current_lsn(self.backup_site["nodes"][0]).walfile_name
4752

4853
def get_first_required_wal_segment(self):
49-
resp = requests.get("{base}/basebackup".format(base=self.base_url))
54+
resp = self.session.get("{base}/basebackup".format(base=self.base_url))
5055
if resp.status_code != 200:
5156
self.log.error("Error looking up basebackups")
5257
return None, None
@@ -106,7 +111,7 @@ def check_and_upload_missing_local_files(self, max_hash_checks):
106111
archive_type = "WAL"
107112

108113
if archive_type:
109-
resp = requests.head("{base}/archive/{file}".format(base=self.base_url, file=wal_file))
114+
resp = self.session.head("{base}/archive/{file}".format(base=self.base_url, file=wal_file))
110115
if resp.status_code == 200:
111116
remote_hash = resp.headers.get("metadata-hash")
112117
hash_algorithm = resp.headers.get("metadata-hash-algorithm")
@@ -147,7 +152,7 @@ def check_and_upload_missing_local_files(self, max_hash_checks):
147152
need_archival.append(wal_file)
148153

149154
for wal_file in sorted(need_archival): # sort oldest to newest
150-
resp = requests.put("{base}/archive/{file}".format(base=self.base_url, file=wal_file))
155+
resp = self.session.put("{base}/archive/{file}".format(base=self.base_url, file=wal_file))
151156
archive_type = "TIMELINE" if ".history" in wal_file else "WAL"
152157
if resp.status_code != 201:
153158
self.log.error("%s file %r archival failed with status code %r", archive_type, wal_file, resp.status_code)
@@ -175,7 +180,7 @@ def check_wal_archive_integrity(self, new_backup_on_failure):
175180
# Decrement one segment if we're on a valid timeline
176181
current_lsn = current_lsn.previous_walfile_start_lsn
177182
wal_file = current_lsn.walfile_name
178-
resp = requests.head("{base}/archive/{file}".format(base=self.base_url, file=wal_file))
183+
resp = self.session.head("{base}/archive/{file}".format(base=self.base_url, file=wal_file))
179184
if resp.status_code == 200:
180185
self.log.info("%s file %r correctly archived", archive_type, wal_file)
181186
file_count += 1
@@ -201,7 +206,7 @@ def check_wal_archive_integrity(self, new_backup_on_failure):
201206
current_lsn = current_lsn.at_timeline(current_lsn.timeline_id - 1)
202207

203208
def request_basebackup(self):
204-
resp = requests.put("{base}/archive/basebackup".format(base=self.base_url))
209+
resp = self.session.put("{base}/archive/basebackup".format(base=self.base_url))
205210
if resp.status_code != 201:
206211
self.log.error("Request for a new backup for site: %r failed", self.site)
207212
else:

test/test_archivesync.py

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import hashlib
22
import os
3-
from unittest.mock import Mock, patch
3+
from unittest.mock import Mock
44

55
import pytest
66

@@ -43,9 +43,7 @@ def requests_head_call_return(*args, **kwargs): # pylint: disable=unused-argume
4343
return HTTPResult(status_code)
4444

4545

46-
@patch("requests.head")
47-
@patch("requests.put")
48-
def test_check_wal_archive_integrity(requests_put_mock, requests_head_mock, tmpdir):
46+
def test_check_wal_archive_integrity(tmpdir):
4947
from pghoard.archive_sync import ArchiveSync, SyncError
5048

5149
# Instantiate a fake PG data directory
@@ -57,59 +55,57 @@ def test_check_wal_archive_integrity(requests_put_mock, requests_head_mock, tmpd
5755
write_json_file(config_file, {"http_port": 8080, "backup_sites": {"foo": {"pg_data_directory": pg_data_directory}}})
5856
arsy = ArchiveSync()
5957
arsy.set_config(config_file, site="foo")
60-
requests_put_mock.return_value = HTTPResult(201) # So the backup request succeeds
61-
requests_head_mock.side_effect = requests_head_call_return
58+
arsy.session.put = Mock(return_value=HTTPResult(201)) # So the backup requests succeeds
59+
arsy.session.head = Mock(side_effect=requests_head_call_return)
6260

6361
# Check integrity within same timeline
6462
arsy.get_current_wal_file = Mock(return_value="00000005000000000000008F")
6563
arsy.get_first_required_wal_segment = Mock(return_value=("00000005000000000000008C", 90300))
6664
assert arsy.check_wal_archive_integrity(new_backup_on_failure=False) == 0
67-
assert requests_head_mock.call_count == 3
68-
assert requests_put_mock.call_count == 0
65+
assert arsy.session.head.call_count == 3
66+
assert arsy.session.put.call_count == 0
6967

7068
# Check integrity when timeline has changed
71-
requests_head_mock.call_count = 0
72-
requests_put_mock.call_count = 0
69+
arsy.session.head.call_count = 0
70+
arsy.session.put.call_count = 0
7371
arsy.get_current_wal_file = Mock(return_value="000000090000000000000008")
7472
arsy.get_first_required_wal_segment = Mock(return_value=("000000080000000000000005", 90300))
7573
assert arsy.check_wal_archive_integrity(new_backup_on_failure=False) == 0
76-
assert requests_head_mock.call_count == 4
74+
assert arsy.session.head.call_count == 4
7775

78-
requests_head_mock.call_count = 0
79-
requests_put_mock.call_count = 0
76+
arsy.session.head.call_count = 0
77+
arsy.session.put.call_count = 0
8078
arsy.get_current_wal_file = Mock(return_value="000000030000000000000008")
8179
arsy.get_first_required_wal_segment = Mock(return_value=("000000030000000000000005", 90300))
8280
with pytest.raises(SyncError):
8381
arsy.check_wal_archive_integrity(new_backup_on_failure=False)
84-
assert requests_put_mock.call_count == 0
82+
assert arsy.session.put.call_count == 0
8583
assert arsy.check_wal_archive_integrity(new_backup_on_failure=True) == 0
86-
assert requests_put_mock.call_count == 1
84+
assert arsy.session.put.call_count == 1
8785

88-
requests_head_mock.call_count = 0
89-
requests_put_mock.call_count = 0
86+
arsy.session.head.call_count = 0
87+
arsy.session.put.call_count = 0
9088
arsy.get_current_wal_file = Mock(return_value="000000070000000000000002")
9189
arsy.get_first_required_wal_segment = Mock(return_value=("000000060000000000000001", 90300))
9290
assert arsy.check_wal_archive_integrity(new_backup_on_failure=False) == 0
93-
assert requests_put_mock.call_count == 0
91+
assert arsy.session.put.call_count == 0
9492

95-
requests_head_mock.call_count = 0
96-
requests_put_mock.call_count = 0
93+
arsy.session.head.call_count = 0
94+
arsy.session.put.call_count = 0
9795
arsy.get_current_wal_file = Mock(return_value="000000020000000B00000000")
9896
arsy.get_first_required_wal_segment = Mock(return_value=("000000020000000A000000FD", 90200))
9997
assert arsy.check_wal_archive_integrity(new_backup_on_failure=False) == 0
100-
assert requests_put_mock.call_count == 0
98+
assert arsy.session.put.call_count == 0
10199

102-
requests_head_mock.call_count = 0
103-
requests_put_mock.call_count = 0
100+
arsy.session.head.call_count = 0
101+
arsy.session.put.call_count = 0
104102
arsy.get_current_wal_file = Mock(return_value="000000020000000B00000000")
105103
arsy.get_first_required_wal_segment = Mock(return_value=("000000020000000A000000FD", 90300))
106104
assert arsy.check_wal_archive_integrity(new_backup_on_failure=True) == 0
107-
assert requests_put_mock.call_count == 1
105+
assert arsy.session.put.call_count == 1
108106

109107

110-
@patch("requests.head")
111-
@patch("requests.put")
112-
def test_check_and_upload_missing_local_files(requests_put_mock, requests_head_mock, tmpdir):
108+
def test_check_and_upload_missing_local_files(tmpdir):
113109
from pghoard.archive_sync import ArchiveSync
114110

115111
data_dir = str(tmpdir)
@@ -157,8 +153,8 @@ def requests_put(*args, **kwargs): # pylint: disable=unused-argument
157153
write_json_file(config_file, {"http_port": 8080, "backup_sites": {"foo": {"pg_data_directory": data_dir}}})
158154
arsy = ArchiveSync()
159155
arsy.set_config(config_file, site="foo")
160-
requests_put_mock.side_effect = requests_put
161-
requests_head_mock.side_effect = requests_head
156+
arsy.session.put = Mock(side_effect=requests_put)
157+
arsy.session.head = Mock(side_effect=requests_head)
162158
arsy.get_current_wal_file = Mock(return_value="00000000000000000000001A")
163159
arsy.get_first_required_wal_segment = Mock(return_value=("000000000000000000000001", 90300))
164160

test/test_webserver.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ def _switch_wal(self, db, count):
242242
conn.close()
243243
return start_wal, end_wal
244244

245-
def test_archive_sync(self, db, pghoard, pg_version: str):
245+
def _test_archive_sync(self, db, pghoard, pg_version: str):
246246
log = logging.getLogger("test_archive_sync")
247247
store = pghoard.transfer_agents[0].get_object_storage(pghoard.test_site)
248248

@@ -350,36 +350,43 @@ def write_dummy_wal(inc):
350350
db.run_pg()
351351
db.run_cmd("pg_ctl", "-D", db.pgdata, "promote")
352352
time.sleep(5) # TODO: instead of sleeping, poll the db until ready
353-
# we should have a single timeline file in pg_xlog/pg_wal now
353+
# we should have one or more timeline files in pg_xlog/pg_wal now
354354
pg_wal_timelines = {f for f in os.listdir(pg_wal_dir) if wal.TIMELINE_RE.match(f)}
355355
assert len(pg_wal_timelines) > 0
356-
# but there should be nothing archived as archive_command wasn't setup
356+
# but there should be one fewer archived, as archive_command wasn't set up/active
357357
archived_timelines = set(list_archive("timeline"))
358-
assert len(archived_timelines) == 0
358+
assert len(archived_timelines) == len(pg_wal_timelines) - 1
359359
# let's hit archive sync
360360
arsy.run(["--site", pghoard.test_site, "--config", pghoard.config_path])
361361
# now we should have an archived timeline
362362
archived_timelines = set(list_archive("timeline"))
363363
assert archived_timelines.issuperset(pg_wal_timelines)
364-
assert "00000002.history" in archived_timelines
365364

366365
# let's take a new basebackup
367366
self._run_and_wait_basebackup(pghoard, db, "basic")
367+
368368
# nuke archives and resync them
369369
for name in list_archive(folder="timeline"):
370370
store.delete_key(os.path.join(pghoard.test_site, "timeline", name))
371371
for name in list_archive(folder="xlog"):
372372
store.delete_key(os.path.join(pghoard.test_site, "xlog", name))
373-
self._switch_wal(db, 1)
373+
374+
start_wal, _ = self._switch_wal(db, 1)
375+
pg_wals = {f for f in os.listdir(pg_wal_dir) if wal.WAL_RE.match(f) and f > start_wal}
376+
pg_wal_timelines = {f for f in os.listdir(pg_wal_dir) if wal.TIMELINE_RE.match(f)}
374377

375378
arsy.run(["--site", pghoard.test_site, "--config", pghoard.config_path])
376379

377380
archived_wals = set(list_archive("xlog"))
378-
# assume the same timeline file as before and one to three wal files
379-
assert len(archived_wals) >= 1
380-
assert len(archived_wals) <= 3
381+
assert archived_wals.issuperset(pg_wals)
381382
archived_timelines = set(list_archive("timeline"))
382-
assert list(archived_timelines) == ["00000002.history"]
383+
assert archived_timelines.issuperset(pg_wal_timelines)
384+
385+
def test_archive_sync(self, db, pghoard, pg_version: str):
386+
self._test_archive_sync(db, pghoard, pg_version)
387+
388+
def test_archive_sync_with_userauth(self, db, pghoard_with_userauth, pg_version: str):
389+
self._test_archive_sync(db, pghoard_with_userauth, pg_version)
383390

384391
def test_archive_command_with_invalid_file(self, pghoard):
385392
# only WAL and timeline (.history) files can be archived

0 commit comments

Comments
 (0)