-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataset_parser.py
More file actions
77 lines (66 loc) · 2.4 KB
/
Copy pathdataset_parser.py
File metadata and controls
77 lines (66 loc) · 2.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import struct
from typing import Tuple, Optional
import numpy as np
# Default "R" values for the MNIST and SIFT datasets.
# NOTE(review): presumably a search radius / distance threshold consumed by
# callers elsewhere -- nothing in this part of the file reads them; confirm
# the meaning and units against the code that uses them.
R_MNIST_DEFAULT = 2000.0
R_SIFT_DEFAULT = 2800.0
def load_mnist_images(path: str) -> np.ndarray:
    """
    Read idx3-ubyte MNIST images into numpy array [n, d] (float32).

    The file begins with a 16-byte big-endian header of four unsigned
    32-bit integers (magic number 2051, image count, rows, columns),
    followed by one unsigned byte per pixel. Every image is flattened
    into a single row of d = rows * columns values.

    Args:
        path: Location of the idx3-ubyte file.

    Returns:
        Array of shape [n, d] and dtype float32. Pixel bytes are converted
        to float as-is (no rescaling).

    Raises:
        ValueError: If the header is shorter than 16 bytes, the magic
            number is not 2051, the image size is zero, the payload is
            not a whole number of images, or the number of images in the
            payload differs from the count declared in the header.
    """
    with open(path, "rb") as f:
        header = f.read(16)
        if len(header) < 16:
            raise ValueError(f"Bad MNIST file: {path}")
        magic, nimg, nrows, ncols = struct.unpack(">IIII", header)
        if magic != 2051:
            raise ValueError(
                f"Bad MNIST magic number (expected 2051, got {magic}) in {path}"
            )
        d = nrows * ncols
        # A zero-sized image would make the modulo below divide by zero;
        # report it like any other malformed header instead.
        if d == 0:
            raise ValueError(f"Bad MNIST image size {nrows}x{ncols} in {path}")
        data = f.read()

    arr = np.frombuffer(data, dtype=np.uint8)
    if arr.size % d != 0:
        raise ValueError(f"MNIST data size not divisible by {d} in {path}")

    # The header states how many images follow; a different count means the
    # file is truncated or has trailing data, so refuse it rather than
    # silently returning a differently sized dataset.
    found = arr.size // d
    if found != nimg:
        raise ValueError(
            f"MNIST image count mismatch (header says {nimg}, "
            f"found {found}) in {path}"
        )

    # astype() copies, so the result is writable and independent of `data`.
    return arr.reshape(nimg, d).astype(np.float32)
def load_sift_fvecs(path: str) -> np.ndarray:
    """
    Read SIFT .fvecs file into numpy array [n, d] (float32).

    Format: int32 d, followed by d float32 (little endian), repeated.
    All records must share the dimension given by the first record.

    Args:
        path: Location of the .fvecs file.

    Returns:
        Array of shape [n, d] holding little-endian float32 values, one
        row per record. The array owns its data (it is a copy).

    Raises:
        ValueError: If the file is empty, ends in the middle of a record,
            declares a negative dimension, or contains records whose
            dimension differs from the first record's.
    """
    with open(path, "rb") as f:
        raw = f.read()

    if not raw:
        raise ValueError(f"No vectors read from {path}")
    if len(raw) < 4:
        raise ValueError(f"Truncated fvecs file (dim) in {path}")

    (d,) = struct.unpack_from("<i", raw, 0)
    # A negative count is never valid; passing it on to a read would
    # consume the rest of the file and yield garbage vectors.
    if d < 0:
        raise ValueError(f"Invalid fvecs dimension {d} in {path}")

    # Each record is one int32 header plus d float32 values.
    record_bytes = 4 * (d + 1)
    leftover = len(raw) % record_bytes
    if leftover:
        # Fewer than 4 stray bytes: the next header was cut off; otherwise
        # the header is present but its vector payload is incomplete.
        part = "dim" if leftover < 4 else "vector"
        raise ValueError(f"Truncated fvecs file ({part}) in {path}")

    # Interpret the same buffer twice: as int32 to check every record's
    # header, and as float32 to extract the payload columns.
    dims = np.frombuffer(raw, dtype="<i4").reshape(-1, d + 1)[:, 0]
    if not np.all(dims == d):
        raise ValueError(f"Inconsistent fvecs dimensions in {path}")

    values = np.frombuffer(raw, dtype="<f4").reshape(-1, d + 1)[:, 1:]
    # copy() gives a contiguous, writable array detached from `raw`.
    return values.copy()
def load_dataset(
    data_path: str,
    query_path: Optional[str],
    dtype: str,
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """
    Load a dataset and, optionally, its query set.

    Returns (data, query) as numpy arrays [n, d] (float32); query is None
    when query_path is None.
    dtype: 'mnist' or 'sift' (case-insensitive).
    Raises ValueError for any other dtype.
    """
    kind = dtype.lower()

    # Reject unknown dataset types up front, before touching any file.
    if kind not in ("mnist", "sift"):
        raise ValueError("dtype must be 'mnist' or 'sift'")

    # Both files of a dataset use the same on-disk format, so one reader
    # serves the data file and the (optional) query file.
    reader = load_mnist_images if kind == "mnist" else load_sift_fvecs

    data = reader(data_path)
    if query_path is None:
        queries = None
    else:
        queries = reader(query_path)
    return data, queries