-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Expand file tree
/
Copy pathwikipedia_splade.rs
More file actions
208 lines (169 loc) · 6.74 KB
/
wikipedia_splade.rs
File metadata and controls
208 lines (169 loc) · 6.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
use std::path::PathBuf;
use anyhow::{Context, Result};
use arrow::{
array::AsArray,
datatypes::{Float32Type, Int32Type},
record_batch::RecordBatch,
};
use futures::{stream, Stream, StreamExt, TryStreamExt};
use hf_hub::api::tokio::Api;
use parquet::arrow::ParquetRecordBatchStreamBuilder;
use sprs::CsVec;
use tokio::fs::File;
/// Dimension passed to every `CsVec` built in this file, i.e. the exclusive
/// upper bound on sparse embedding indices.
// NOTE(review): 30522 presumably matches the tokenizer vocabulary of the
// SPLADE model that produced the dataset — confirm against the dataset card.
const SPLADE_VOCAB_SIZE: usize = 30_522;
/// Local file locations of the Wikipedia SPLADE dataset.
///
/// Holds the paths of the training shards (documents) and of the single test
/// file (queries) that the accessor methods read from.
pub struct WikipediaSplade {
    /// Parquet shards containing the documents, in shard order.
    pub train_paths: Vec<PathBuf>,
    /// Parquet file containing the test queries.
    pub test_path: PathBuf,
}
impl WikipediaSplade {
pub async fn init() -> Result<Self> {
let api = Api::new()?;
let dataset = api.dataset("Sicheng-Chroma/wikipedia-en-splade-bge".to_string());
let mut train_paths = Vec::new();
for shard_idx in 0..7 {
let train_shard = format!("train/train-{shard_idx:05}-of-00007.parquet");
let train_path = dataset.get(&train_shard).await?;
train_paths.push(train_path);
}
// Download test queries
let test_path = dataset.get("test/test-00000-of-00001.parquet").await?;
Ok(Self {
train_paths,
test_path,
})
}
pub async fn documents(&self) -> Result<impl Stream<Item = Result<SparseDocument>> + '_> {
let mut shard_streams = Vec::new();
for shard_path in &self.train_paths {
let file = File::open(shard_path).await?;
let shard_stream = ParquetRecordBatchStreamBuilder::new(file).await?.build()?;
shard_streams.push(shard_stream);
}
Ok(stream::iter(shard_streams)
.flatten()
.map(|res| {
res.map_err(Into::into).and_then(|batch| {
Self::batch_to_documents(batch)
.map(|docs| stream::iter(docs.into_iter().map(Ok)))
})
})
.try_flatten())
}
// Helper to convert a batch to documents
// Returns Vec for now since we need to access columns multiple times
fn batch_to_documents(batch: RecordBatch) -> Result<Vec<SparseDocument>> {
let texts = batch
.column_by_name("text")
.context("Missing text column")?
.as_string::<i32>();
let titles = batch
.column_by_name("title")
.context("Missing title column")?
.as_string::<i32>();
let urls = batch
.column_by_name("url")
.context("Missing url column")?
.as_string::<i32>();
let sparse_indices = batch
.column_by_name("sparse_embedding_indices")
.context("Missing sparse_embedding_indices column")?
.as_list::<i32>();
let sparse_values = batch
.column_by_name("sparse_embedding_values")
.context("Missing sparse_embedding_values column")?
.as_list::<i32>();
let mut documents = Vec::with_capacity(batch.num_rows());
for i in 0..batch.num_rows() {
let text = texts.value(i).to_string();
let title = titles.value(i).to_string();
let url = urls.value(i).to_string();
let indices = sparse_indices.value(i);
let values = sparse_values.value(i);
let indices_array = indices.as_primitive::<Int32Type>();
let values_array = values.as_primitive::<Float32Type>();
let mut sparse_indices_vec = Vec::with_capacity(indices_array.len());
let mut sparse_values_vec = Vec::with_capacity(values_array.len());
for j in 0..indices_array.len() {
sparse_indices_vec.push(indices_array.value(j) as usize);
sparse_values_vec.push(values_array.value(j));
}
let sparse_vector =
CsVec::new(SPLADE_VOCAB_SIZE, sparse_indices_vec, sparse_values_vec);
documents.push(SparseDocument {
doc_id: url.clone(),
url,
title,
body: text,
sparse_vector,
});
}
Ok(documents)
}
pub async fn queries(&self) -> Result<Vec<SparseQuery>> {
// Use the already downloaded test file
let file = File::open(&self.test_path).await?;
let stream = ParquetRecordBatchStreamBuilder::new(file).await?.build()?;
let batches = stream.try_collect::<Vec<_>>().await?;
let mut queries = Vec::new();
for batch in batches {
// Try to get query_id if it exists, otherwise generate it
let query_ids = batch
.column_by_name("query_id")
.map(|col| col.as_string::<i32>());
let query_texts = batch
.column_by_name("text")
.or_else(|| batch.column_by_name("query"))
.context("Missing text/query column")?
.as_string::<i32>();
let sparse_indices = batch
.column_by_name("sparse_embedding_indices")
.context("Missing sparse_embedding_indices column")?
.as_list::<i32>();
let sparse_values = batch
.column_by_name("sparse_embedding_values")
.context("Missing sparse_embedding_values column")?
.as_list::<i32>();
for i in 0..batch.num_rows() {
let query_id = if let Some(ids) = query_ids {
ids.value(i).to_string()
} else {
format!("query_{}", i)
};
let query_text = query_texts.value(i).to_string();
// Extract sparse vector indices and values
let indices_array = sparse_indices.value(i);
let values_array = sparse_values.value(i);
let indices = indices_array.as_primitive::<Int32Type>();
let values = values_array.as_primitive::<Float32Type>();
// Build sparse vector
let mut sparse_indices_vec = Vec::new();
let mut sparse_values_vec = Vec::new();
for j in 0..indices.len() {
sparse_indices_vec.push(indices.value(j) as usize);
sparse_values_vec.push(values.value(j));
}
let sparse_vector = CsVec::new(30522, sparse_indices_vec, sparse_values_vec);
queries.push(SparseQuery {
query_id,
text: query_text,
sparse_vector,
});
}
}
Ok(queries)
}
}
/// One document of the dataset together with its sparse embedding.
#[derive(Clone, Debug)]
pub struct SparseDocument {
    /// Identifier of the document.
    pub doc_id: String,
    /// Source URL of the document.
    pub url: String,
    /// Title of the document.
    pub title: String,
    /// Text content of the document.
    pub body: String,
    /// Sparse embedding of the document.
    pub sparse_vector: CsVec<f32>,
}
/// One test query together with its sparse embedding.
#[derive(Clone, Debug)]
pub struct SparseQuery {
    /// Identifier of the query.
    pub query_id: String,
    /// Text of the query.
    pub text: String,
    /// Sparse embedding of the query.
    pub sparse_vector: CsVec<f32>,
}