Skip to content

Pseudo atomic annotation file write #3953

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
import warnings
import platform
import sys

import os
import copy
import json
import tempfile
from pathlib import Path
import pickle # nosec B403 # disable import-pickle check
from argparse import ArgumentParser
Expand Down Expand Up @@ -327,7 +328,7 @@ def save_annotation(annotation, meta, annotation_file, meta_file, dataset_config
annotation_dir = annotation_file.parent
if not annotation_dir.exists():
annotation_dir.mkdir(parents=True)
with annotation_file.open('wb') as file:
with AtomicWriteFileHandle(annotation_file,'wb') as file:
if conversion_meta:
pickle.dump(conversion_meta, file)
for representation in annotation:
Expand All @@ -337,7 +338,7 @@ def save_annotation(annotation, meta, annotation_file, meta_file, dataset_config
meta_dir = meta_file.parent
if not meta_dir.exists():
meta_dir.mkdir(parents=True)
with meta_file.open('wt') as file:
with AtomicWriteFileHandle(meta_file, 'wt') as file:
json.dump(meta, file)


Expand Down Expand Up @@ -409,3 +410,38 @@ def analyze_dataset(annotations, metadata):
else:
metadata = {'data_analysis': data_analysis}
return metadata

class AtomicWriteFileHandle:
"""Ensure the file is written once in case of multi processes or threads."""

def __init__(self, file_path, open_mode):
self.target_path = file_path
self.mode = open_mode

self.temp_fd, self.temp_path = tempfile.mkstemp(dir=os.path.dirname(file_path))
self.temp_file = os.fdopen(self.temp_fd, open_mode)

def write(self, data):
self.temp_file.write(data)

def writelines(self, lines):
self.temp_file.writelines(lines)

def close(self):
if not self.temp_file.closed:
self.temp_file.close()
if not os.path.exists(self.target_path):
os.rename(self.temp_path, self.target_path)
else:
os.remove(self.temp_path)

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.close()

# Mimic other file object methods as needed
def __getattr__(self, item):
"""Delegate attribute access to the underlying temporary file object."""
return getattr(self.temp_file, item)
67 changes: 67 additions & 0 deletions tools/accuracy_checker/tests/test_convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""
Copyright (c) 2018-2024 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import threading
import warnings
from accuracy_checker.annotation_converters.convert import AtomicWriteFileHandle

def thread_access_file(file_path, data_dict, thread_id, write_lines):
if os.path.exists(file_path):
with open(file_path, 'r') as file:
read_lines = len(file.readlines())
# when a new thread reads a file, all lines must already be written
if read_lines != write_lines:
warn_message = f"Thread {thread_id}: Incorrect number of lines read from {file_path} ({read_lines} != {write_lines})"
warnings.warn(warn_message)
data_dict['assert'] = warn_message
else:
with AtomicWriteFileHandle(file_path, 'wt') as file:
for i in range(write_lines):
file.write(f"Thread {thread_id}:Line{i} {data_dict[thread_id]}\n")

class TestAtomicWriteFileHandle:

def test_multithreaded_atomic_file_write(self):
target_file_path = "test_atomic_file.txt"
threads = []
num_threads = 10
write_lines = 10
data_chunks = [f"Data chunk {i}" for i in range(num_threads)]
threads_dict = {i: data_chunks[i] for i in range(len(data_chunks))}

if os.path.exists(target_file_path):
os.remove(target_file_path)

for i in range(num_threads):
thread = threading.Thread(target=thread_access_file, args=(target_file_path, threads_dict, i, write_lines))
threads.append(thread)

for thread in threads:
thread.start()

for thread in threads:
thread.join()

with open(target_file_path, 'r') as file:
lines = file.readlines()

os.remove(target_file_path)

# check asserts passed from threads
assert 'assert' not in threads_dict.keys() , threads_dict['assert']

assert sum(1 for line in lines for data_chunk in data_chunks if data_chunk in line) == write_lines, f"data_chunks data not found in the {target_file_path} file"
Loading