|
5 | 5 | hashing/dedup logic. Uses temp files — no conda installation needed. |
6 | 6 | """ |
7 | 7 |
|
| 8 | +import fcntl |
8 | 9 | import json |
9 | 10 | import os |
10 | 11 | import tempfile |
|
13 | 14 | import pytest |
14 | 15 |
|
15 | 16 |
|
16 | | -class TestManifestConcurrentWrites: |
17 | | - """Test that manifest file operations are thread-safe.""" |
| 17 | +def _flock_on_data_file(manifest_path, key, value): |
| 18 | + """Old pattern: flock the data file itself + seek/truncate. |
| 19 | +
|
| 20 | + Serializes correctly but not crash-safe (truncate then write). |
| 21 | + """ |
| 22 | + with open(manifest_path, "r+") as f: |
| 23 | + fcntl.flock(f.fileno(), fcntl.LOCK_EX) |
| 24 | + try: |
| 25 | + data = json.load(f) |
| 26 | + data[key] = value |
| 27 | + f.seek(0) |
| 28 | + f.truncate() |
| 29 | + json.dump(data, f) |
| 30 | + f.flush() |
| 31 | + os.fsync(f.fileno()) |
| 32 | + finally: |
| 33 | + fcntl.flock(f.fileno(), fcntl.LOCK_UN) |
18 | 34 |
|
19 | | - def test_concurrent_writes_no_corruption(self): |
20 | | - """Multiple threads writing to the manifest should not corrupt it.""" |
21 | | - from metaflow.util import atomic_json_update |
22 | 35 |
|
23 | | - errors = [] |
24 | | - num_threads = 10 |
25 | | - writes_per_thread = 5 |
26 | | - |
27 | | - with tempfile.NamedTemporaryFile( |
28 | | - mode="w", suffix=".manifest", delete=False |
29 | | - ) as f: |
30 | | - manifest_path = f.name |
31 | | - json.dump({}, f) |
32 | | - |
33 | | - def write_to_manifest(thread_id): |
34 | | - try: |
35 | | - for i in range(writes_per_thread): |
36 | | - key = f"env_{thread_id}_{i}" |
37 | | - atomic_json_update( |
38 | | - manifest_path, |
39 | | - lambda d, k=key, t=thread_id: { |
40 | | - **d, |
41 | | - k: {"platform": "linux-64", "thread": t}, |
42 | | - }, |
43 | | - ) |
44 | | - except Exception as e: |
45 | | - errors.append(e) |
46 | | - |
47 | | - threads = [ |
48 | | - threading.Thread(target=write_to_manifest, args=(tid,)) |
49 | | - for tid in range(num_threads) |
50 | | - ] |
51 | | - for t in threads: |
52 | | - t.start() |
53 | | - for t in threads: |
54 | | - t.join() |
| 36 | +def _flock_on_data_file_with_replace(manifest_path, key, value): |
| 37 | + """Broken pattern: flock the data file + os.replace. |
55 | 38 |
|
| 39 | + The os.replace creates a new inode, so the flock on the old file |
| 40 | + descriptor doesn't prevent other threads from acquiring their own |
| 41 | + lock on the new inode. Writes are lost under contention. |
| 42 | + """ |
| 43 | + with open(manifest_path, "r+") as f: |
| 44 | + fcntl.flock(f.fileno(), fcntl.LOCK_EX) |
56 | 45 | try: |
57 | | - assert len(errors) == 0, f"Errors during concurrent writes: {errors}" |
| 46 | + data = json.load(f) |
| 47 | + data[key] = value |
| 48 | + tmp = manifest_path + ".tmp" |
| 49 | + with open(tmp, "w") as tmp_f: |
| 50 | + json.dump(data, tmp_f) |
| 51 | + os.replace(tmp, manifest_path) |
| 52 | + finally: |
| 53 | + fcntl.flock(f.fileno(), fcntl.LOCK_UN) |
| 54 | + |
| 55 | + |
| 56 | +def _no_locking(manifest_path, key, value): |
| 57 | + """Old production_token.py pattern: no locking at all.""" |
| 58 | + with open(manifest_path, "r") as f: |
| 59 | + data = json.load(f) |
| 60 | + data[key] = value |
| 61 | + with open(manifest_path, "w") as f: |
| 62 | + json.dump(data, f) |
| 63 | + |
58 | 64 |
|
def _atomic_json_update_wrapper(manifest_path, key, value):
    """Fixed pattern: delegate to atomic_json_update (lock file + atomic replace)."""
    # Imported lazily so importing this test module doesn't require metaflow.
    from metaflow.util import atomic_json_update

    def _merge(current):
        # Return a new dict with the single key applied; never mutate input.
        merged = dict(current)
        merged[key] = value
        return merged

    atomic_json_update(manifest_path, _merge)
| 70 | + |
| 71 | + |
| 72 | +def _run_concurrent_writes(write_fn, num_threads=10, writes_per_thread=5): |
| 73 | + """Run concurrent writes using the given write function. |
| 74 | +
|
| 75 | + Returns (num_keys_written, expected_keys, errors). |
| 76 | + """ |
| 77 | + errors = [] |
| 78 | + |
| 79 | + with tempfile.NamedTemporaryFile(mode="w", suffix=".manifest", delete=False) as f: |
| 80 | + manifest_path = f.name |
| 81 | + json.dump({}, f) |
| 82 | + |
| 83 | + def writer(thread_id): |
| 84 | + try: |
| 85 | + for i in range(writes_per_thread): |
| 86 | + key = f"env_{thread_id}_{i}" |
| 87 | + write_fn(manifest_path, key, {"thread": thread_id}) |
| 88 | + except Exception as e: |
| 89 | + errors.append(e) |
| 90 | + |
| 91 | + threads = [ |
| 92 | + threading.Thread(target=writer, args=(tid,)) for tid in range(num_threads) |
| 93 | + ] |
| 94 | + for t in threads: |
| 95 | + t.start() |
| 96 | + for t in threads: |
| 97 | + t.join() |
| 98 | + |
| 99 | + try: |
| 100 | + try: |
59 | 101 | with open(manifest_path, "r") as f: |
60 | 102 | data = json.load(f) |
| 103 | + return len(data), num_threads * writes_per_thread, errors |
| 104 | + except (json.JSONDecodeError, FileNotFoundError) as e: |
| 105 | + # File was corrupted or deleted by the broken pattern |
| 106 | + errors.append(e) |
| 107 | + return 0, num_threads * writes_per_thread, errors |
| 108 | + finally: |
| 109 | + if os.path.exists(manifest_path): |
| 110 | + os.unlink(manifest_path) |
| 111 | + for suffix in (".lock", ".tmp"): |
| 112 | + p = manifest_path + suffix |
| 113 | + if os.path.exists(p): |
| 114 | + os.unlink(p) |
61 | 115 |
|
62 | | - expected_keys = num_threads * writes_per_thread |
63 | | - assert len(data) == expected_keys, ( |
64 | | - f"Expected {expected_keys} entries, got {len(data)}. " |
65 | | - "Some writes may have been lost due to race conditions." |
| 116 | + |
class TestManifestConcurrentWrites:
    """Contrast the legacy manifest-locking schemes with the fixed helper.

    The broken schemes should drop writes under contention; the fixed
    atomic_json_update and the same-inode flock scheme must never do so.
    """

    @pytest.mark.parametrize(
        "write_fn,description",
        [
            (
                _flock_on_data_file_with_replace,
                "flock+replace (inode changes break lock)",
            ),
            (_no_locking, "no locking (old production_token.py pattern)"),
        ],
    )
    def test_broken_patterns_lose_writes(self, write_fn, description):
        """A known-broken pattern should drop writes or corrupt the file.

        The race is timing-dependent, so we retry under heavy contention and
        stop at the first observed failure. If no trial exposes the race we
        skip rather than fail: the pattern is still wrong, the scheduler just
        never interleaved the threads badly enough this run. The important
        guarantee is that the fixed pattern
        (test_atomic_json_update_no_lost_writes) always passes.
        """
        for _ in range(30):
            actual, expected, errors = _run_concurrent_writes(
                write_fn, num_threads=20, writes_per_thread=10
            )
            if errors or actual < expected:
                # Race reproduced: writes were lost or the file was corrupted.
                return
        pytest.skip(
            f"Race condition in '{description}' did not trigger in 30 trials "
            f"(non-deterministic). The pattern is still broken; the OS scheduler "
            f"just didn't interleave threads enough to expose it this time."
        )

    def test_atomic_json_update_no_lost_writes(self):
        """The fixed helper must retain every write, every trial."""
        for _ in range(5):
            actual, expected, errors = _run_concurrent_writes(
                _atomic_json_update_wrapper, num_threads=10, writes_per_thread=5
            )
            assert len(errors) == 0, f"Unexpected errors: {errors}"
            assert actual == expected, (
                f"Expected {expected} entries, got {actual}. "
                "atomic_json_update lost writes."
            )

    def test_flock_seek_truncate_no_lost_writes(self):
        """flock + seek/truncate keeps one inode, so writes serialize correctly."""
        for _ in range(5):
            actual, expected, errors = _run_concurrent_writes(
                _flock_on_data_file, num_threads=10, writes_per_thread=5
            )
            assert len(errors) == 0, f"Unexpected errors: {errors}"
            assert actual == expected, (
                f"Expected {expected} entries, got {actual}. "
                "flock+seek/truncate lost writes."
            )
67 | | - finally: |
68 | | - os.unlink(manifest_path) |
69 | | - # Clean up lock file |
70 | | - lock_path = manifest_path + ".lock" |
71 | | - if os.path.exists(lock_path): |
72 | | - os.unlink(lock_path) |
73 | 177 |
|
74 | 178 |
|
75 | 179 | class TestAtomicJsonUpdate: |
@@ -114,28 +218,3 @@ def bad_update(d): |
114 | 218 |
|
115 | 219 | with open(path, "r") as f: |
116 | 220 | assert json.load(f) == {"original": True} |
117 | | - |
118 | | - |
119 | | -class TestCleanupCondaFile: |
120 | | - """Test temp file cleanup patterns used by the conda environment.""" |
121 | | - |
122 | | - def test_cleanup_temp_file(self): |
123 | | - """Verify temp files are cleaned up properly.""" |
124 | | - with tempfile.NamedTemporaryFile(suffix=".conda_tmp", delete=False) as f: |
125 | | - tmp_path = f.name |
126 | | - f.write(b"test data") |
127 | | - |
128 | | - assert os.path.exists(tmp_path) |
129 | | - os.unlink(tmp_path) |
130 | | - assert not os.path.exists(tmp_path) |
131 | | - |
132 | | - def test_cleanup_nonexistent_file_no_error(self): |
133 | | - """Cleaning up a file that doesn't exist should not raise.""" |
134 | | - tmp_path = "/tmp/nonexistent_conda_test_file_12345.tmp" |
135 | | - if os.path.exists(tmp_path): |
136 | | - os.unlink(tmp_path) |
137 | | - # Should not raise |
138 | | - try: |
139 | | - os.unlink(tmp_path) |
140 | | - except FileNotFoundError: |
141 | | - pass # Expected behavior |
|
0 commit comments