Skip to content

Commit 8ed5a9c

Browse files
committed
use process pool for testing mutations
1 parent 1ba2a7e commit 8ed5a9c

File tree

5 files changed

+218
-50
lines changed

5 files changed

+218
-50
lines changed

mutmut/__main__.py

Lines changed: 51 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
import platform
2+
from mutmut.custom_process_pool import Task
3+
from mutmut.custom_process_pool import CustomProcessPool
14
import ast
25
import fnmatch
36
import gc
47
import inspect
58
import itertools
69
import json
7-
from multiprocessing import Pool, Process, set_start_method
10+
from multiprocessing import JoinableQueue, Pool, Process, Queue, set_start_method
811
import multiprocessing
912
import multiprocessing.connection
1013
import os
@@ -47,7 +50,9 @@
4750
)
4851
from typing import (
4952
Dict,
53+
Generic,
5054
List,
55+
TypeVar,
5156
Union,
5257
)
5358

@@ -373,6 +378,8 @@ def new_tests(self):
373378
return self.ids - collected_test_names()
374379

375380

381+
_pytest_initialized = False
382+
376383
class PytestRunner(TestRunner):
377384
# noinspection PyMethodMayBeStatic
378385
def execute_pytest(self, params: list[str], **kwargs):
@@ -386,6 +393,10 @@ def execute_pytest(self, params: list[str], **kwargs):
386393
print(' exit code', exit_code)
387394
if exit_code == 4:
388395
raise BadTestExecutionCommandsException(params)
396+
397+
global _pytest_initialized
398+
_pytest_initialized = True
399+
389400
return exit_code
390401

391402
def run_stats(self, *, tests):
@@ -889,7 +900,10 @@ def inner_timout_checker():
889900
@click.option('--max-children', type=int)
890901
@click.argument('mutant_names', required=False, nargs=-1)
891902
def run(mutant_names, *, max_children):
892-
set_start_method('spawn')
903+
if platform.system() == 'Windows':
904+
set_start_method('spawn')
905+
else:
906+
set_start_method('fork')
893907

894908
assert isinstance(mutant_names, (tuple, list)), mutant_names
895909
_run(mutant_names, max_children)
@@ -956,23 +970,6 @@ def _run(mutant_names: Union[tuple, list], max_children: Union[None, int]):
956970

957971
running_processes: set[Process] = set()
958972

959-
def handle_finished_processes() -> int:
960-
nonlocal running_processes
961-
sentinels = [p.sentinel for p in running_processes]
962-
multiprocessing.connection.wait(sentinels)
963-
964-
finished_processes = {p for p in running_processes if not p.is_alive()}
965-
running_processes -= finished_processes
966-
967-
for p in finished_processes:
968-
if mutmut.config.debug:
969-
print(' worker exit code', p.exitcode)
970-
source_file_mutation_data_by_pid[p.pid].register_result(pid=p.pid, exit_code=p.exitcode)
971-
972-
p.close()
973-
974-
return len(finished_processes)
975-
976973
source_file_mutation_data_by_pid: dict[int, SourceFileMutationData] = {} # many pids map to one MutationData
977974
running_children = 0
978975
count_tried = 0
@@ -996,6 +993,8 @@ def handle_finished_processes() -> int:
996993
# TODO: implement timeout for windows + unix
997994
# Thread(target=timeout_checker(mutants), daemon=True).start()
998995

996+
args: list[tuple[TestRunner, SourceFileMutationData, str, list[str], Config]] = []
997+
999998
# Now do mutation
1000999
for m, mutant_name, result in mutants:
10011000
print_stats(source_file_mutation_data_by_path)
@@ -1015,37 +1014,32 @@ def handle_finished_processes() -> int:
10151014
m.save()
10161015
continue
10171016

1018-
p = Process(target=_test_mutation, args=(runner, m, mutant_name, tests, mutmut.config))
1019-
running_processes.add(p)
1020-
p.start()
1021-
pid = p.pid
1022-
# in the parent
1023-
source_file_mutation_data_by_pid[pid] = m
1024-
m.register_pid(pid=pid, key=mutant_name, estimated_time_of_tests=estimated_time_of_tests)
1025-
running_children += 1
1026-
1027-
if running_children >= max_children:
1028-
count_finished = handle_finished_processes()
1029-
count_tried += count_finished
1030-
running_children -= count_finished
1031-
1032-
try:
1033-
while running_children:
1034-
print_stats(source_file_mutation_data_by_path)
1035-
count_finished = handle_finished_processes()
1036-
count_tried += count_finished
1037-
running_children -= count_finished
1038-
except ChildProcessError:
1039-
pass
1017+
args.append((runner, m, mutant_name, tests, mutmut.config))
1018+
source_file_mutation_data_by_pid[mutant_name] = m
1019+
m.register_pid(pid=mutant_name, key=mutant_name, estimated_time_of_tests=estimated_time_of_tests)
1020+
1021+
tasks: list[Task] = []
1022+
for arg in args:
1023+
tasks.append(Task(id=arg[2], args=arg, timeout_seconds=1000))
1024+
pool = CustomProcessPool(tasks, _test_mutation, max_children)
1025+
done = 0
1026+
for finished_task in pool.run():
1027+
done += 1
1028+
# print(f'Finished {done} tasks')
1029+
if finished_task.error:
1030+
print(finished_task)
1031+
# print(finished_task)
1032+
source_file_mutation_data_by_pid[finished_task.id].register_result(pid=finished_task.id, exit_code=finished_task.result)
1033+
print_stats(source_file_mutation_data_by_path)
10401034
except KeyboardInterrupt:
1035+
pool.shutdown()
10411036
print('Stopping...')
1042-
stop_all_children(mutants)
10431037

10441038
t = datetime.now() - start
10451039

10461040
print_stats(source_file_mutation_data_by_path, force_output=True)
10471041
print()
1048-
print(f'{count_tried / t.total_seconds():.2f} mutations/second')
1042+
print(f'{len(tasks) / t.total_seconds():.2f} mutations/second')
10491043

10501044
if mutant_names:
10511045
print()
@@ -1062,12 +1056,18 @@ def handle_finished_processes() -> int:
10621056
print()
10631057

10641058

1065-
def _test_mutation(runner: TestRunner, m: SourceFileMutationData, mutant_name: str, tests, config):
1059+
def _test_mutation(task: Task):
1060+
args: tuple[TestRunner, SourceFileMutationData, str, list[str], Config] = task.args
1061+
runner, m, mutant_name, tests, config = args
10661062
try:
10671063
mutmut.config = config
10681064

1069-
with CatchOutput():
1070-
runner.list_all_tests()
1065+
# ensure that we imported all files at least once per process
1066+
# before we set MUTANT_UNDER_TEST (so everything that runs at import
1067+
# time is not mutated)
1068+
if not _pytest_initialized:
1069+
with CatchOutput():
1070+
runner.list_all_tests()
10711071

10721072
os.environ['MUTANT_UNDER_TEST'] = mutant_name
10731073
setproctitle(f'mutmut: {mutant_name}')
@@ -1082,13 +1082,14 @@ def _test_mutation(runner: TestRunner, m: SourceFileMutationData, mutant_name: s
10821082

10831083
with CatchOutput():
10841084
result = runner.run_tests(mutant_name=mutant_name, tests=tests)
1085-
os._exit(result)
1085+
1086+
return result
1087+
# os._exit(result)
10861088
except Exception as e:
10871089
with open(f'error.{mutant_name}.log', 'w') as log:
10881090
log.write(str(e))
1089-
os._exit(-1)
1090-
1091-
1091+
log.flush()
1092+
return -24
10921093

10931094
def tests_for_mutant_names(mutant_names):
10941095
tests = set()

mutmut/custom_process_pool.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
from __future__ import annotations
2+
3+
from typing import Generic, Union, Any, Callable, Iterable, TypeVar
4+
from typing_extensions import ParamSpec
5+
from dataclasses import dataclass
6+
7+
import multiprocessing.connection
8+
from multiprocessing import Queue, Process
9+
import queue
10+
import os
11+
12+
13+
TaskArgs = ParamSpec('TaskArgs')
14+
TaskResult = TypeVar('TaskResult')
15+
16+
@dataclass
class Task:
    """A single unit of work handed to a pool worker."""

    # Identifier that is copied onto the FinishedTask reported for this task.
    id: str
    # Payload for the job; the job callable receives the whole Task and
    # unpacks ``args`` itself.
    args: tuple[Any, ...]
    # Wall-clock budget in seconds (real time, not process CPU time).
    timeout_seconds: int
22+
23+
@dataclass
class TaskError:
    """Pairs a task id with the exception associated with that task."""

    # Identifier of the task the error belongs to.
    id: str
    # The exception instance itself.
    error: Exception
27+
28+
@dataclass
class FinishedTask(Generic[TaskResult]):
    """Outcome of one task as reported back by a worker process."""

    # Identifier copied from the Task that produced this outcome.
    id: str
    # Return value of the job; None when the job raised.
    result: Union[TaskResult, None]
    # Exception raised by the job; None when the job returned normally.
    error: Union[Exception, None]
33+
34+
class JobTimeoutException(Exception):
    """Exception type reserved for jobs that exceed their timeout."""
36+
37+
class CustomProcessPool(Generic[TaskArgs, TaskResult]):
    """Run a fixed list of tasks on a bounded number of worker processes.

    All tasks are placed on a shared task queue. Each worker process takes
    tasks from that queue one at a time, calls ``job(task)`` and posts a
    ``FinishedTask`` on a result queue. ``run()`` yields those results in the
    order they arrive.

    Tasks, results and errors travel through multiprocessing queues, so they
    have to be picklable.

    NOTE(review): ``Task.timeout_seconds`` is not read anywhere in this class,
    so no per-task timeout is enforced yet.
    NOTE(review): a worker that dies without posting a result (for example
    when it is killed by a signal) leaves its task unaccounted for; ``run()``
    then never reaches ``done()`` on its own.
    """

    def __init__(self, tasks: list[Task], job: Callable[TaskArgs, TaskResult], max_workers: Union[int, None]):
        """
        :param tasks: Tasks to execute; each one yields exactly one FinishedTask.
        :param job: Callable executed inside a worker process as ``job(task)``.
        :param max_workers: Upper bound of concurrently running workers.
            ``None`` falls back to ``os.cpu_count()``.
        :raises ValueError: If ``max_workers`` is smaller than 1.
        """
        if max_workers is None:
            max_workers = os.cpu_count() or 1
        if max_workers < 1:
            # With zero workers no task would ever finish and run() would spin forever.
            raise ValueError(f'max_workers must be at least 1, got {max_workers}')

        self._tasks = tasks
        self._job = job
        self._remaining_tasks_queue: Queue[Task] = Queue()
        # Counts tasks without a received result (queued and in flight).
        self._remaining_tasks_count = len(tasks)
        self._results: Queue[FinishedTask[TaskResult]] = Queue()
        self._max_workers = max_workers
        self._workers: set[Process] = set()
        self._killed_workers = 0
        self._shutdown = False

    def run(self) -> Iterable[FinishedTask]:
        """Execute all tasks and yield one FinishedTask per task as it completes.

        The pool is shut down when all results were yielded, when
        ``shutdown()`` was called from outside, when an exception escapes, or
        when the consumer stops iterating and the generator is closed.
        """
        for task in self._tasks:
            self._remaining_tasks_queue.put(task)

        try:
            self._start_missing_workers()

            while not self.done() and not self._shutdown:
                self._remove_stopped_workers()
                self._start_missing_workers()

                yield from self._get_new_results(timeout=1)
        finally:
            # Also reached on GeneratorExit and on exceptions, so that no
            # worker process is left running.
            self.shutdown()

    def shutdown(self):
        """Kill all live workers, wait for them and close both queues.

        Safe to call more than once.
        """
        # Set first so that a suspended run() loop stops even if the cleanup
        # below gets interrupted.
        self._shutdown = True

        # TODO: is this a good way to shutdown processes?
        for p in self._workers:
            if p.is_alive():
                p.kill()
        for p in self._workers:
            p.join()

        # After an early shutdown nobody reads the pending tasks anymore. Do
        # not let the queue's feeder thread block interpreter exit while it
        # waits to flush them.
        self._remaining_tasks_queue.cancel_join_thread()
        self._remaining_tasks_queue.close()
        self._results.close()

    def _start_missing_workers(self):
        """Start workers until ``min(max_workers, remaining tasks)`` are alive."""
        self._workers = {p for p in self._workers if p.is_alive()}

        desired_workers = min(self._max_workers, self._remaining_tasks_count)
        missing_workers = desired_workers - len(self._workers)

        for _ in range(missing_workers):
            self._start_worker()

    def _remove_stopped_workers(self):
        """Drop workers that are no longer alive and report their exit code.

        Replacements are started by ``_start_missing_workers()``.
        """
        stopped_workers = {p for p in self._workers if not p.is_alive()}
        self._workers -= stopped_workers

        for worker in stopped_workers:
            print(f'Worker {worker.pid} stopped with exitcode {worker.exitcode}')

    def _get_new_results(self, timeout: int) -> Iterable[FinishedTask]:
        """Yield at most one result, waiting up to ``timeout`` seconds for it."""
        try:
            result = self._results.get(timeout=timeout)
        except queue.Empty:
            return

        self._remaining_tasks_count -= 1
        yield result

    def _start_worker(self):
        """Spawn one worker process and register it."""
        p = Process(target=CustomProcessPool._pool_job_executor, args=(self._job, self._remaining_tasks_queue, self._results))
        p.start()
        self._workers.add(p)

    def done(self) -> bool:
        """Return True once a result was received for every task."""
        return self._remaining_tasks_count == 0

    @staticmethod
    def _pool_job_executor(job: Callable[..., TaskResult], task_queue: Queue[Task], results: Queue[FinishedTask[TaskResult]]):
        """Worker main loop: run tasks until the task queue stays empty for one second.

        Exactly one FinishedTask is posted per task taken from the queue.
        """
        while True:
            try:
                task = task_queue.get(timeout=1)
            except queue.Empty:
                # os._exit() skips the normal interpreter cleanup, so flush
                # results that are still buffered by the queue's feeder thread.
                results.close()
                results.join_thread()
                os._exit(0)

            try:
                finished_task: FinishedTask[TaskResult] = FinishedTask(id=task.id, result=job(task), error=None)
            except BaseException as e:
                # Report the failure for *this* task. Posting from a `finally`
                # clause would either hit an unbound name or re-send the
                # previous task's result when the job raises something that is
                # not an Exception (e.g. SystemExit).
                results.put(FinishedTask(id=task.id, result=None, error=e))
                if not isinstance(e, Exception):
                    # Do not swallow SystemExit / KeyboardInterrupt.
                    raise
            else:
                results.put(finished_task)
129+
130+
131+

mutmut/file_mutation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,8 @@ def combine_mutations_to_source(module: cst.Module, mutations: Sequence[Mutation
176176
:param mutations: Mutations that should be applied.
177177
:return: Mutated code and list of mutation names"""
178178

179+
# mutations = mutations[0:10]
180+
179181
# copy start of the module (in particular __future__ imports)
180182
result: list[MODULE_STATEMENT] = get_statements_until_func_or_class(module.body)
181183
mutation_names: list[str] = []

tests/e2e/test_e2e_result_snapshots.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,6 @@ def test_my_lib_result_snapshot():
7878
def test_config_result_snapshot():
7979
mutmut._reset_globals()
8080
asserts_results_did_not_change("config")
81+
82+
if __name__ == '__main__':
83+
test_my_lib_result_snapshot()

tests/test_custom_pool.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from mutmut.__main__ import CustomProcessPool, Task
2+
import pytest
3+
import time
4+
5+
def test_custom_process_pool():
    """The pool returns one result per task and reports job exceptions as errors."""
    tasks = [
        Task(id='a-small', args=(1, 2), timeout_seconds=1000),
        Task(id='b-medium', args=(30, 20), timeout_seconds=1000),
        Task(id='c-neg', args=(-2, -2), timeout_seconds=1000),
        Task(id='d-div-by-zero', args=(-2, 0), timeout_seconds=1000),
    ]
    pool = CustomProcessPool(tasks, _divide, max_workers=2)

    results = []
    for result in pool.run():
        print(result)
        results.append(result)

    assert len(results) == 4

    # Results arrive in completion order; sort by id for stable assertions.
    results = sorted(results, key=lambda result: result.id)
    assert results[0].result == pytest.approx(0.5)
    assert results[1].result == pytest.approx(1.5)
    assert results[2].result == pytest.approx(1)
    # Identity check: a failed task carries no result at all.
    assert results[3].result is None
    assert isinstance(results[3].error, ZeroDivisionError)
27+
28+
def _divide(task: Task):
    """Job used by the pool test: divide the first task argument by the second."""
    numerator, denominator = task.args
    quotient = numerator / denominator
    return quotient

0 commit comments

Comments
 (0)