|
| 1 | +""" |
| 2 | +Fix Embedding Model Cache Path Mismatch |
| 3 | +
|
| 4 | +CRITICAL: This script fixes the cache path mismatch issue where: |
| 5 | +- Model is downloaded in HuggingFace format: /app/hf_cache/models--sentence-transformers--paraphrase-multilingual-MiniLM-L12-v2 |
| 6 | +- But SentenceTransformer looks for: /app/hf_cache/sentence_transformers/paraphrase-multilingual-MiniLM-L12-v2 |
| 7 | +
|
| 8 | +Solution: Create symlink or copy from HuggingFace format to sentence-transformers format |
| 9 | +""" |
| 10 | + |
| 11 | +import os |
| 12 | +import shutil |
| 13 | +import logging |
| 14 | +from pathlib import Path |
| 15 | + |
| 16 | +logger = logging.getLogger(__name__) |
| 17 | + |
| 18 | + |
| 19 | +def fix_embedding_model_cache(model_name: str = "paraphrase-multilingual-MiniLM-L12-v2") -> bool: |
| 20 | + """ |
| 21 | + Fix embedding model cache path mismatch. |
| 22 | + |
| 23 | + Args: |
| 24 | + model_name: Name of the model to fix |
| 25 | + |
| 26 | + Returns: |
| 27 | + True if fix was successful or not needed, False if failed |
| 28 | + """ |
| 29 | + try: |
| 30 | + cache_base = Path("/app/hf_cache") |
| 31 | + if not cache_base.exists(): |
| 32 | + logger.warning(f"⚠️ Cache base directory does not exist: {cache_base}") |
| 33 | + return False |
| 34 | + |
| 35 | + # Model name variations |
| 36 | + model_name_safe = model_name.replace("/", "_") |
| 37 | + model_name_hf = model_name.replace("/", "--") |
| 38 | + |
| 39 | + # HuggingFace format path (where model might be downloaded) |
| 40 | + hf_paths = [ |
| 41 | + cache_base / f"models--sentence-transformers--{model_name_hf}", |
| 42 | + cache_base / "hub" / f"models--sentence-transformers--{model_name_hf}", |
| 43 | + cache_base / f"models--{model_name_hf}", |
| 44 | + cache_base / "hub" / f"models--{model_name_hf}", |
| 45 | + ] |
| 46 | + |
| 47 | + # Sentence-transformers format path (where SentenceTransformer looks) |
| 48 | + st_path = cache_base / "sentence_transformers" / model_name_safe |
| 49 | + |
| 50 | + # Find HuggingFace format cache |
| 51 | + hf_source = None |
| 52 | + for hf_path in hf_paths: |
| 53 | + if hf_path.exists(): |
| 54 | + # Verify it has model files |
| 55 | + if any(hf_path.rglob("*.json")) or any(hf_path.rglob("*.bin")) or any(hf_path.rglob("*.safetensors")): |
| 56 | + hf_source = hf_path |
| 57 | + logger.info(f"✅ Found HuggingFace format cache: {hf_source}") |
| 58 | + break |
| 59 | + |
| 60 | + if not hf_source: |
| 61 | + logger.info("ℹ️ HuggingFace format cache not found - model may not be downloaded yet") |
| 62 | + return True # Not an error, just not downloaded yet |
| 63 | + |
| 64 | + # Check if sentence-transformers format already exists |
| 65 | + if st_path.exists(): |
| 66 | + logger.info(f"✅ Sentence-transformers format cache already exists: {st_path}") |
| 67 | + return True |
| 68 | + |
| 69 | + # Create sentence_transformers directory |
| 70 | + st_path.parent.mkdir(parents=True, exist_ok=True) |
| 71 | + |
| 72 | + # CRITICAL: For HuggingFace format, we need to extract the actual model files |
| 73 | + # HuggingFace cache structure: models--{name}/snapshots/{hash}/model files |
| 74 | + # Sentence-transformers expects: sentence_transformers/{name}/model files directly |
| 75 | + |
| 76 | + # Check if it's HuggingFace format with snapshots |
| 77 | + snapshots_dir = hf_source / "snapshots" |
| 78 | + if snapshots_dir.exists(): |
| 79 | + # Find the latest snapshot |
| 80 | + snapshots = sorted(snapshots_dir.iterdir(), key=lambda p: p.stat().st_mtime, reverse=True) |
| 81 | + if snapshots: |
| 82 | + latest_snapshot = snapshots[0] |
| 83 | + logger.info(f"📦 Found HuggingFace snapshot: {latest_snapshot}") |
| 84 | + |
| 85 | + # Copy model files from snapshot to sentence-transformers format |
| 86 | + logger.info(f"📦 Copying model files from HuggingFace format to sentence-transformers format...") |
| 87 | + logger.info(f" Source: {latest_snapshot}") |
| 88 | + logger.info(f" Destination: {st_path}") |
| 89 | + |
| 90 | + # Copy all files from snapshot |
| 91 | + shutil.copytree(latest_snapshot, st_path, dirs_exist_ok=True) |
| 92 | + logger.info(f"✅ Successfully copied model files to: {st_path}") |
| 93 | + return True |
| 94 | + else: |
| 95 | + logger.warning(f"⚠️ No snapshots found in: {snapshots_dir}") |
| 96 | + else: |
| 97 | + # Direct model files (not HuggingFace snapshot format) |
| 98 | + # Try to copy or symlink |
| 99 | + logger.info(f"📦 Model files are in direct format, creating symlink...") |
| 100 | + logger.info(f" Source: {hf_source}") |
| 101 | + logger.info(f" Destination: {st_path}") |
| 102 | + |
| 103 | + try: |
| 104 | + # Try symlink first (more efficient) |
| 105 | + if not st_path.exists(): |
| 106 | + os.symlink(hf_source, st_path) |
| 107 | + logger.info(f"✅ Created symlink: {st_path} -> {hf_source}") |
| 108 | + return True |
| 109 | + except OSError as e: |
| 110 | + # Symlink failed (might not be supported on all systems) |
| 111 | + logger.warning(f"⚠️ Symlink failed: {e}, trying copy instead...") |
| 112 | + shutil.copytree(hf_source, st_path, dirs_exist_ok=True) |
| 113 | + logger.info(f"✅ Copied model files to: {st_path}") |
| 114 | + return True |
| 115 | + |
| 116 | + return False |
| 117 | + |
| 118 | + except Exception as e: |
| 119 | + logger.error(f"❌ Failed to fix embedding cache: {e}") |
| 120 | + return False |
| 121 | + |
| 122 | + |
| 123 | +if __name__ == "__main__": |
| 124 | + # Setup logging |
| 125 | + logging.basicConfig(level=logging.INFO) |
| 126 | + |
| 127 | + # Fix cache |
| 128 | + success = fix_embedding_model_cache() |
| 129 | + if success: |
| 130 | + print("✅ Embedding cache fix completed successfully") |
| 131 | + else: |
| 132 | + print("❌ Embedding cache fix failed - check logs for details") |
| 133 | + |
0 commit comments