Skip to content

Commit 738f571

Browse files
committed
Add user authentication support to archive_sync command as well
The previous modification, 8c1c2b0, added support for user authentication on the local webserver but missed archive_sync command.
1 parent 69d1115 commit 738f571

File tree

3 files changed

+56
-48
lines changed

3 files changed

+56
-48
lines changed

pghoard/archive_sync.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
import os
1111
import sys
1212

13-
import requests
13+
from requests import Session
14+
from requests.auth import HTTPBasicAuth
1415
from rohmu.errors import InvalidConfigurationError
1516

1617
from pghoard.common import get_pg_wal_directory
@@ -34,19 +35,23 @@ def __init__(self):
3435
self.site = None
3536
self.backup_site = None
3637
self.base_url = None
38+
self.session = None
3739

3840
def set_config(self, config_file, site):
3941
self.config = config.read_json_config_file(config_file, check_commands=False)
4042
self.site = config.get_site_from_config(self.config, site)
4143
self.backup_site = self.config["backup_sites"][self.site]
4244
self.base_url = "http://127.0.0.1:{}/{}".format(self.config["http_port"], self.site)
45+
self.session = Session()
46+
if self.config.get("webserver_username") and self.config.get("webserver_password"):
47+
self.session.auth = HTTPBasicAuth(self.config["webserver_username"], self.config["webserver_password"])
4348

4449
def get_current_wal_file(self):
4550
# identify the (must be) local database
4651
return wal.get_current_lsn(self.backup_site["nodes"][0]).walfile_name
4752

4853
def get_first_required_wal_segment(self):
49-
resp = requests.get("{base}/basebackup".format(base=self.base_url))
54+
resp = self.session.get("{base}/basebackup".format(base=self.base_url))
5055
if resp.status_code != 200:
5156
self.log.error("Error looking up basebackups")
5257
return None, None
@@ -106,7 +111,7 @@ def check_and_upload_missing_local_files(self, max_hash_checks):
106111
archive_type = "WAL"
107112

108113
if archive_type:
109-
resp = requests.head("{base}/archive/{file}".format(base=self.base_url, file=wal_file))
114+
resp = self.session.head("{base}/archive/{file}".format(base=self.base_url, file=wal_file))
110115
if resp.status_code == 200:
111116
remote_hash = resp.headers.get("metadata-hash")
112117
hash_algorithm = resp.headers.get("metadata-hash-algorithm")
@@ -147,7 +152,7 @@ def check_and_upload_missing_local_files(self, max_hash_checks):
147152
need_archival.append(wal_file)
148153

149154
for wal_file in sorted(need_archival): # sort oldest to newest
150-
resp = requests.put("{base}/archive/{file}".format(base=self.base_url, file=wal_file))
155+
resp = self.session.put("{base}/archive/{file}".format(base=self.base_url, file=wal_file))
151156
archive_type = "TIMELINE" if ".history" in wal_file else "WAL"
152157
if resp.status_code != 201:
153158
self.log.error("%s file %r archival failed with status code %r", archive_type, wal_file, resp.status_code)
@@ -175,7 +180,7 @@ def check_wal_archive_integrity(self, new_backup_on_failure):
175180
# Decrement one segment if we're on a valid timeline
176181
current_lsn = current_lsn.previous_walfile_start_lsn
177182
wal_file = current_lsn.walfile_name
178-
resp = requests.head("{base}/archive/{file}".format(base=self.base_url, file=wal_file))
183+
resp = self.session.head("{base}/archive/{file}".format(base=self.base_url, file=wal_file))
179184
if resp.status_code == 200:
180185
self.log.info("%s file %r correctly archived", archive_type, wal_file)
181186
file_count += 1
@@ -201,7 +206,7 @@ def check_wal_archive_integrity(self, new_backup_on_failure):
201206
current_lsn = current_lsn.at_timeline(current_lsn.timeline_id - 1)
202207

203208
def request_basebackup(self):
204-
resp = requests.put("{base}/archive/basebackup".format(base=self.base_url))
209+
resp = self.session.put("{base}/archive/basebackup".format(base=self.base_url))
205210
if resp.status_code != 201:
206211
self.log.error("Request for a new backup for site: %r failed", self.site)
207212
else:

test/test_archivesync.py

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import hashlib
22
import os
3-
from unittest.mock import Mock, patch
3+
from unittest.mock import Mock
44

55
import pytest
66

@@ -43,9 +43,7 @@ def requests_head_call_return(*args, **kwargs): # pylint: disable=unused-argume
4343
return HTTPResult(status_code)
4444

4545

46-
@patch("requests.head")
47-
@patch("requests.put")
48-
def test_check_wal_archive_integrity(requests_put_mock, requests_head_mock, tmpdir):
46+
def test_check_wal_archive_integrity(tmpdir):
4947
from pghoard.archive_sync import ArchiveSync, SyncError
5048

5149
# Instantiate a fake PG data directory
@@ -57,59 +55,57 @@ def test_check_wal_archive_integrity(requests_put_mock, requests_head_mock, tmpd
5755
write_json_file(config_file, {"http_port": 8080, "backup_sites": {"foo": {"pg_data_directory": pg_data_directory}}})
5856
arsy = ArchiveSync()
5957
arsy.set_config(config_file, site="foo")
60-
requests_put_mock.return_value = HTTPResult(201) # So the backup requests succeeds
61-
requests_head_mock.side_effect = requests_head_call_return
58+
arsy.session.put = Mock(return_value=HTTPResult(201)) # So the backup request succeeds
59+
arsy.session.head = Mock(side_effect=requests_head_call_return)
6260

6361
# Check integrity within same timeline
6462
arsy.get_current_wal_file = Mock(return_value="00000005000000000000008F")
6563
arsy.get_first_required_wal_segment = Mock(return_value=("00000005000000000000008C", 90300))
6664
assert arsy.check_wal_archive_integrity(new_backup_on_failure=False) == 0
67-
assert requests_head_mock.call_count == 3
68-
assert requests_put_mock.call_count == 0
65+
assert arsy.session.head.call_count == 3
66+
assert arsy.session.put.call_count == 0
6967

7068
# Check integrity when timeline has changed
71-
requests_head_mock.call_count = 0
72-
requests_put_mock.call_count = 0
69+
arsy.session.head.call_count = 0
70+
arsy.session.put.call_count = 0
7371
arsy.get_current_wal_file = Mock(return_value="000000090000000000000008")
7472
arsy.get_first_required_wal_segment = Mock(return_value=("000000080000000000000005", 90300))
7573
assert arsy.check_wal_archive_integrity(new_backup_on_failure=False) == 0
76-
assert requests_head_mock.call_count == 4
74+
assert arsy.session.head.call_count == 4
7775

78-
requests_head_mock.call_count = 0
79-
requests_put_mock.call_count = 0
76+
arsy.session.head.call_count = 0
77+
arsy.session.put.call_count = 0
8078
arsy.get_current_wal_file = Mock(return_value="000000030000000000000008")
8179
arsy.get_first_required_wal_segment = Mock(return_value=("000000030000000000000005", 90300))
8280
with pytest.raises(SyncError):
8381
arsy.check_wal_archive_integrity(new_backup_on_failure=False)
84-
assert requests_put_mock.call_count == 0
82+
assert arsy.session.put.call_count == 0
8583
assert arsy.check_wal_archive_integrity(new_backup_on_failure=True) == 0
86-
assert requests_put_mock.call_count == 1
84+
assert arsy.session.put.call_count == 1
8785

88-
requests_head_mock.call_count = 0
89-
requests_put_mock.call_count = 0
86+
arsy.session.head.call_count = 0
87+
arsy.session.put.call_count = 0
9088
arsy.get_current_wal_file = Mock(return_value="000000070000000000000002")
9189
arsy.get_first_required_wal_segment = Mock(return_value=("000000060000000000000001", 90300))
9290
assert arsy.check_wal_archive_integrity(new_backup_on_failure=False) == 0
93-
assert requests_put_mock.call_count == 0
91+
assert arsy.session.put.call_count == 0
9492

95-
requests_head_mock.call_count = 0
96-
requests_put_mock.call_count = 0
93+
arsy.session.head.call_count = 0
94+
arsy.session.put.call_count = 0
9795
arsy.get_current_wal_file = Mock(return_value="000000020000000B00000000")
9896
arsy.get_first_required_wal_segment = Mock(return_value=("000000020000000A000000FD", 90200))
9997
assert arsy.check_wal_archive_integrity(new_backup_on_failure=False) == 0
100-
assert requests_put_mock.call_count == 0
98+
assert arsy.session.put.call_count == 0
10199

102-
requests_head_mock.call_count = 0
103-
requests_put_mock.call_count = 0
100+
arsy.session.head.call_count = 0
101+
arsy.session.put.call_count = 0
104102
arsy.get_current_wal_file = Mock(return_value="000000020000000B00000000")
105103
arsy.get_first_required_wal_segment = Mock(return_value=("000000020000000A000000FD", 90300))
106104
assert arsy.check_wal_archive_integrity(new_backup_on_failure=True) == 0
107-
assert requests_put_mock.call_count == 1
105+
assert arsy.session.put.call_count == 1
108106

109107

110-
@patch("requests.head")
111-
@patch("requests.put")
112-
def test_check_and_upload_missing_local_files(requests_put_mock, requests_head_mock, tmpdir):
108+
def test_check_and_upload_missing_local_files(tmpdir):
113109
from pghoard.archive_sync import ArchiveSync
114110

115111
data_dir = str(tmpdir)
@@ -157,8 +153,8 @@ def requests_put(*args, **kwargs): # pylint: disable=unused-argument
157153
write_json_file(config_file, {"http_port": 8080, "backup_sites": {"foo": {"pg_data_directory": data_dir}}})
158154
arsy = ArchiveSync()
159155
arsy.set_config(config_file, site="foo")
160-
requests_put_mock.side_effect = requests_put
161-
requests_head_mock.side_effect = requests_head
156+
arsy.session.put = Mock(side_effect=requests_put)
157+
arsy.session.head = Mock(side_effect=requests_head)
162158
arsy.get_current_wal_file = Mock(return_value="00000000000000000000001A")
163159
arsy.get_first_required_wal_segment = Mock(return_value=("000000000000000000000001", 90300))
164160

test/test_webserver.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ def _switch_wal(self, db, count):
242242
conn.close()
243243
return start_wal, end_wal
244244

245-
def test_archive_sync(self, db, pghoard):
245+
def _test_archive_sync(self, db, pghoard):
246246
log = logging.getLogger("test_archive_sync")
247247
store = pghoard.transfer_agents[0].get_object_storage(pghoard.test_site)
248248

@@ -267,13 +267,13 @@ def list_archive(folder):
267267
self._run_and_wait_basebackup(pghoard, db, "pipe")
268268

269269
# force a couple of wal segment switches
270-
start_wal, _ = self._switch_wal(db, 4)
270+
start_wal, end_wal = self._switch_wal(db, 4)
271271
# we should have at least 4 WAL files now (there may be more in
272272
# case other tests created them -- we share a single postgresql
273273
# cluster between all tests)
274274
pg_wal_dir = get_pg_wal_directory(pghoard.config["backup_sites"][pghoard.test_site])
275275
pg_wals = {f for f in os.listdir(pg_wal_dir) if wal.WAL_RE.match(f) and f > start_wal}
276-
assert len(pg_wals) >= 4
276+
assert len(pg_wals) == int(end_wal, 16) - int(start_wal, 16)
277277

278278
# create a couple of "recycled" xlog files that we must ignore
279279
last_wal = sorted(pg_wals)[-1]
@@ -291,7 +291,7 @@ def write_dummy_wal(inc):
291291
# check what we have archived, there should be at least the three
292292
# above WALs that are NOT there at the moment
293293
archived_wals = set(list_archive("xlog"))
294-
assert len(pg_wals - archived_wals) >= 4
294+
assert len(pg_wals - archived_wals) >= 3
295295
# now perform an archive sync
296296
arsy = ArchiveSync()
297297
arsy.run(["--site", pghoard.test_site, "--config", pghoard.config_path])
@@ -346,36 +346,43 @@ def write_dummy_wal(inc):
346346
db.run_pg()
347347
db.run_cmd("pg_ctl", "-D", db.pgdata, "promote")
348348
time.sleep(5) # TODO: instead of sleeping, poll the db until ready
349-
# we should have a single timeline file in pg_xlog/pg_wal now
349+
# we should have one or more timeline files in pg_xlog/pg_wal now
350350
pg_wal_timelines = {f for f in os.listdir(pg_wal_dir) if wal.TIMELINE_RE.match(f)}
351351
assert len(pg_wal_timelines) > 0
352-
# but there should be nothing archived as archive_command wasn't setup
352+
# but there should be one less archived as archive_command wasn't setup/active
353353
archived_timelines = set(list_archive("timeline"))
354-
assert len(archived_timelines) == 0
354+
assert len(archived_timelines) == len(pg_wal_timelines) - 1
355355
# let's hit archive sync
356356
arsy.run(["--site", pghoard.test_site, "--config", pghoard.config_path])
357357
# now we should have an archived timeline
358358
archived_timelines = set(list_archive("timeline"))
359359
assert archived_timelines.issuperset(pg_wal_timelines)
360-
assert "00000002.history" in archived_timelines
361360

362361
# let's take a new basebackup
363362
self._run_and_wait_basebackup(pghoard, db, "basic")
363+
364364
# nuke archives and resync them
365365
for name in list_archive(folder="timeline"):
366366
store.delete_key(os.path.join(pghoard.test_site, "timeline", name))
367367
for name in list_archive(folder="xlog"):
368368
store.delete_key(os.path.join(pghoard.test_site, "xlog", name))
369-
self._switch_wal(db, 1)
369+
370+
start_wal, _ = self._switch_wal(db, 1)
371+
pg_wals = {f for f in os.listdir(pg_wal_dir) if wal.WAL_RE.match(f) and f > start_wal}
372+
pg_wal_timelines = {f for f in os.listdir(pg_wal_dir) if wal.TIMELINE_RE.match(f)}
370373

371374
arsy.run(["--site", pghoard.test_site, "--config", pghoard.config_path])
372375

373376
archived_wals = set(list_archive("xlog"))
374-
# assume the same timeline file as before and one to three wal files
375-
assert len(archived_wals) >= 1
376-
assert len(archived_wals) <= 3
377+
assert archived_wals.issuperset(pg_wals)
377378
archived_timelines = set(list_archive("timeline"))
378-
assert list(archived_timelines) == ["00000002.history"]
379+
assert archived_timelines.issuperset(pg_wal_timelines)
380+
381+
def test_archive_sync(self, db, pghoard):
382+
self._test_archive_sync(db, pghoard)
383+
384+
def test_archive_sync_with_userauth(self, db, pghoard_with_userauth):
385+
self._test_archive_sync(db, pghoard_with_userauth)
379386

380387
def test_archive_command_with_invalid_file(self, pghoard):
381388
# only WAL and timeline (.history) files can be archived

0 commit comments

Comments
 (0)