diff --git a/tools/accuracy_checker/accuracy_checker/annotation_converters/convert.py b/tools/accuracy_checker/accuracy_checker/annotation_converters/convert.py index 6d508fb0c83..7066bc3b219 100644 --- a/tools/accuracy_checker/accuracy_checker/annotation_converters/convert.py +++ b/tools/accuracy_checker/accuracy_checker/annotation_converters/convert.py @@ -16,9 +16,10 @@ import warnings import platform import sys - +import os import copy import json +import tempfile from pathlib import Path import pickle # nosec B403 # disable import-pickle check from argparse import ArgumentParser @@ -327,7 +328,7 @@ def save_annotation(annotation, meta, annotation_file, meta_file, dataset_config annotation_dir = annotation_file.parent if not annotation_dir.exists(): annotation_dir.mkdir(parents=True) - with annotation_file.open('wb') as file: + with AtomicWriteFileHandle(annotation_file,'wb') as file: if conversion_meta: pickle.dump(conversion_meta, file) for representation in annotation: @@ -337,7 +338,7 @@ def save_annotation(annotation, meta, annotation_file, meta_file, dataset_config meta_dir = meta_file.parent if not meta_dir.exists(): meta_dir.mkdir(parents=True) - with meta_file.open('wt') as file: + with AtomicWriteFileHandle(meta_file, 'wt') as file: json.dump(meta, file) @@ -409,3 +410,38 @@ def analyze_dataset(annotations, metadata): else: metadata = {'data_analysis': data_analysis} return metadata + +class AtomicWriteFileHandle: + """Ensure the file is written once in case of multi processes or threads.""" + + def __init__(self, file_path, open_mode): + self.target_path = file_path + self.mode = open_mode + + self.temp_fd, self.temp_path = tempfile.mkstemp(dir=os.path.dirname(file_path)) + self.temp_file = os.fdopen(self.temp_fd, open_mode) + + def write(self, data): + self.temp_file.write(data) + + def writelines(self, lines): + self.temp_file.writelines(lines) + + def close(self): + if not self.temp_file.closed: + self.temp_file.close() + if not os.path.exists(self.target_path): + os.rename(self.temp_path, self.target_path) + else: + os.remove(self.temp_path) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + # Mimic other file object methods as needed + def __getattr__(self, item): + """Delegate attribute access to the underlying temporary file object.""" + return getattr(self.temp_file, item) diff --git a/tools/accuracy_checker/tests/test_convert.py b/tools/accuracy_checker/tests/test_convert.py new file mode 100644 index 00000000000..d34f4959b53 --- /dev/null +++ b/tools/accuracy_checker/tests/test_convert.py @@ -0,0 +1,67 @@ +""" +Copyright (c) 2018-2024 Intel Corporation + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import os +import threading +import warnings +from accuracy_checker.annotation_converters.convert import AtomicWriteFileHandle + +def thread_access_file(file_path, data_dict, thread_id, write_lines): + if os.path.exists(file_path): + with open(file_path, 'r') as file: + read_lines = len(file.readlines()) + # when a new thread reads a file, all lines must already be written + if read_lines != write_lines: + warn_message = f"Thread {thread_id}: Incorrect number of lines read from {file_path} ({read_lines} != {write_lines})" + warnings.warn(warn_message) + data_dict['assert'] = warn_message + else: + with AtomicWriteFileHandle(file_path, 'wt') as file: + for i in range(write_lines): + file.write(f"Thread {thread_id}:Line{i} {data_dict[thread_id]}\n") + +class TestAtomicWriteFileHandle: + + def test_multithreaded_atomic_file_write(self): + target_file_path = "test_atomic_file.txt" + threads = [] + num_threads = 10 + write_lines = 10 + data_chunks = [f"Data chunk {i}" for i in range(num_threads)] + threads_dict = {i: data_chunks[i] for i in range(len(data_chunks))} + + if os.path.exists(target_file_path): + os.remove(target_file_path) + + for i in range(num_threads): + thread = threading.Thread(target=thread_access_file, args=(target_file_path, threads_dict, i, write_lines)) + threads.append(thread) + + for thread in threads: + thread.start() + + for thread in threads: + thread.join() + + with open(target_file_path, 'r') as file: + lines = file.readlines() + + os.remove(target_file_path) + + # check asserts passed from threads + assert 'assert' not in threads_dict.keys() , threads_dict['assert'] + + assert sum(1 for line in lines for data_chunk in data_chunks if data_chunk in line) == write_lines, f"data_chunks data not found in the {target_file_path} file"