Skip to content

Commit f72aa09

Browse files
committed
[fix] Add filelock for prompt embeds
1 parent c4121ac commit f72aa09

File tree

1 file changed

+21
-18
lines changed

1 file changed

+21
-18
lines changed

src/cogkit/datasets/utils.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import cv2
77
import numpy as np
88
import torch
9+
from filelock import FileLock
910
from PIL import Image
1011
from safetensors.torch import load_file, save_file
1112
from torchvision.io import VideoReader
@@ -187,24 +188,26 @@ def get_prompt_embedding(
187188

188189
prompt_hash = str(hashlib.sha256(prompt.encode()).hexdigest())
189190
prompt_embedding_path = prompt_embeddings_dir / (prompt_hash + ".safetensors")
190-
191-
if prompt_embedding_path.exists():
192-
prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
193-
logger.debug(
194-
f"Loaded prompt embedding from {prompt_embedding_path}",
195-
main_process_only=False,
196-
)
197-
else:
198-
prompt_embedding = encode_fn(prompt)
199-
assert prompt_embedding.ndim == 2
200-
# shape of prompt_embedding: [seq_len, hidden_size]
201-
202-
prompt_embedding = prompt_embedding.to("cpu")
203-
save_file({"prompt_embedding": prompt_embedding}, prompt_embedding_path)
204-
logger.info(
205-
f"Saved prompt embedding to {prompt_embedding_path}",
206-
main_process_only=False,
207-
)
191+
lock = FileLock(str(prompt_embedding_path) + ".lock")
192+
193+
with lock:
194+
if prompt_embedding_path.exists():
195+
prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
196+
logger.debug(
197+
f"Loaded prompt embedding from {prompt_embedding_path}",
198+
main_process_only=False,
199+
)
200+
else:
201+
prompt_embedding = encode_fn(prompt)
202+
assert prompt_embedding.ndim == 2
203+
# shape of prompt_embedding: [seq_len, hidden_size]
204+
205+
prompt_embedding = prompt_embedding.to("cpu")
206+
save_file({"prompt_embedding": prompt_embedding}, prompt_embedding_path)
207+
logger.info(
208+
f"Saved prompt embedding to {prompt_embedding_path}",
209+
main_process_only=False,
210+
)
208211

209212
return prompt_embedding
210213

0 commit comments

Comments
 (0)