-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathlib.rs
More file actions
265 lines (232 loc) · 8.48 KB
/
lib.rs
File metadata and controls
265 lines (232 loc) · 8.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
/*
Copyright 2024-2025 The Spice.ai OSS Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
//! Checkpoint object store — upload / download checkpoint artefacts to S3.
//!
//! ## Layout
//!
//! ```text
//! s3://{bucket}/{prefix}/{scenario}/checkpoints/{checkpoint_idx}/{query_idx}.parquet
//! s3://{bucket}/{prefix}/checkpoints.json ← manifest
//! ```
//!
//! The manifest (`checkpoints.json`) contains metadata for every scenario
//! that has been checkpointed under the given prefix.
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use object_store::aws::AmazonS3Builder;
use object_store::path::Path as ObjectPath;
use object_store::{ObjectStore, PutPayload};
use serde::{Deserialize, Serialize};
/// Root document of the checkpoint manifest, stored under the key
/// `{prefix}/checkpoints.json`.
///
/// The `Default` value is a manifest that lists no scenarios.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct CheckpointManifest {
    /// Checkpoint metadata, keyed by scenario name.
    pub scenarios: HashMap<String, ScenarioCheckpoint>,
}
/// Manifest entry describing the checkpoints recorded for one scenario.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ScenarioCheckpoint {
    /// How many checkpoint snapshots were stored for the scenario.
    pub num_checkpoints: usize,
    /// How many query results each checkpoint snapshot holds.
    pub num_queries: usize,
}
/// Uploads checkpoint artefacts to an S3 bucket and downloads them back.
///
/// Every object key is placed under `prefix`; an empty prefix means the
/// bucket root.
pub struct CheckpointStore {
    /// Object-store client, already configured for the target bucket.
    store: Arc<dyn ObjectStore>,
    /// Bucket name as given to the constructor. `store` already targets this
    /// bucket, so keys are never built from it; `allow(dead_code)` keeps the
    /// compiler quiet when nothing else reads the field.
    #[allow(dead_code)]
    bucket: String,
    /// Key prefix shared by all scenarios (for example `"run-42"`).
    prefix: String,
}
impl CheckpointStore {
/// Build a new `CheckpointStore` from raw S3 connection parameters.
///
/// `prefix` may be empty; it is the key‑prefix shared by all scenarios
/// (e.g. `"run-42"`).
pub fn new(
bucket: &str,
prefix: &str,
region: Option<&str>,
endpoint: Option<&str>,
) -> anyhow::Result<Self> {
let mut builder = AmazonS3Builder::from_env().with_bucket_name(bucket);
if let Some(region) = region {
builder = builder.with_region(region);
}
if let Some(endpoint) = endpoint
&& !endpoint.is_empty()
{
builder = builder.with_endpoint(endpoint);
if endpoint.starts_with("http://") {
builder = builder.with_allow_http(true);
}
}
let store = Arc::new(builder.build()?);
Ok(Self {
store,
bucket: bucket.to_owned(),
prefix: prefix.to_owned(),
})
}
fn object_path(&self, suffix: &str) -> ObjectPath {
if self.prefix.is_empty() {
ObjectPath::from(suffix.to_owned())
} else {
ObjectPath::from(format!("{}/{suffix}", self.prefix))
}
}
fn manifest_path(&self) -> ObjectPath {
self.object_path("checkpoints.json")
}
fn checkpoint_parquet_path(
&self,
scenario: &str,
checkpoint_idx: usize,
query_idx: usize,
) -> ObjectPath {
self.object_path(&format!(
"{scenario}/checkpoints/{checkpoint_idx}/{query_idx}.parquet"
))
}
/// Upload all checkpoint parquet files from `local_checkpoint_dir` to S3,
/// then update (merge into) the manifest at `{prefix}/checkpoints.json`.
///
/// The local directory is expected to have the layout produced by the
/// checkpointer binary:
///
/// ```text
/// {local_checkpoint_dir}/
/// 0/
/// 0.parquet
/// 1.parquet
/// 1/
/// 0.parquet
/// ...
/// ```
pub async fn upload_checkpoints(
&self,
scenario: &str,
local_checkpoint_dir: &Path,
) -> anyhow::Result<()> {
if !local_checkpoint_dir.is_dir() {
anyhow::bail!(
"Checkpoint directory does not exist: {}",
local_checkpoint_dir.display()
);
}
let mut num_checkpoints: usize = 0;
let mut num_queries: usize = 0;
// Iterate over checkpoint index directories (0, 1, 2, …).
let mut checkpoint_dirs: Vec<_> = std::fs::read_dir(local_checkpoint_dir)?
.filter_map(Result::ok)
.filter(|e| e.path().is_dir())
.collect();
checkpoint_dirs.sort_by_key(|e| e.file_name());
for checkpoint_entry in &checkpoint_dirs {
let checkpoint_idx: usize = checkpoint_entry
.file_name()
.to_string_lossy()
.parse()
.unwrap_or(0);
let mut query_files: Vec<_> = std::fs::read_dir(checkpoint_entry.path())?
.filter_map(Result::ok)
.filter(|e| e.path().extension().is_some_and(|ext| ext == "parquet"))
.collect();
query_files.sort_by_key(|e| e.file_name());
for qf in &query_files {
let q_idx: usize = qf
.path()
.file_stem()
.and_then(|s| s.to_str())
.and_then(|s| s.parse().ok())
.unwrap_or(0);
let bytes = std::fs::read(qf.path())?;
let dest = self.checkpoint_parquet_path(scenario, checkpoint_idx, q_idx);
tracing::info!(
scenario,
checkpoint = checkpoint_idx,
query = q_idx,
dest = %dest,
"Uploading checkpoint parquet"
);
self.store.put(&dest, PutPayload::from(bytes)).await?;
if q_idx >= num_queries {
num_queries = q_idx + 1;
}
}
num_checkpoints = checkpoint_idx + 1;
}
// Merge into manifest.
let mut manifest = self.download_manifest().await.unwrap_or_default();
manifest.scenarios.insert(
scenario.to_owned(),
ScenarioCheckpoint {
num_checkpoints,
num_queries,
},
);
self.put_manifest(&manifest).await?;
tracing::info!(
scenario,
num_checkpoints,
num_queries,
"Checkpoint upload complete"
);
Ok(())
}
/// Upload (overwrite) the manifest JSON.
async fn put_manifest(&self, manifest: &CheckpointManifest) -> anyhow::Result<()> {
let json = serde_json::to_vec_pretty(manifest)?;
self.store
.put(&self.manifest_path(), PutPayload::from(json))
.await?;
Ok(())
}
/// Download and deserialise the manifest from S3.
///
/// Returns `Ok(manifest)` or an error if the manifest does not exist or
/// cannot be parsed.
pub async fn download_manifest(&self) -> anyhow::Result<CheckpointManifest> {
let data = self.store.get(&self.manifest_path()).await?.bytes().await?;
let manifest: CheckpointManifest = serde_json::from_slice(&data)?;
Ok(manifest)
}
/// Download all checkpoint parquet files for `scenario` into
/// `local_dir/{checkpoint_idx}/{query_idx}.parquet`.
pub async fn download_checkpoints(
&self,
scenario: &str,
info: &ScenarioCheckpoint,
local_dir: &Path,
) -> anyhow::Result<()> {
for checkpoint_idx in 0..info.num_checkpoints {
let checkpoint_dir = local_dir.join(checkpoint_idx.to_string());
std::fs::create_dir_all(&checkpoint_dir)?;
for q_idx in 0..info.num_queries {
let remote = self.checkpoint_parquet_path(scenario, checkpoint_idx, q_idx);
let data = self.store.get(&remote).await?.bytes().await?;
let local_path = checkpoint_dir.join(format!("{q_idx}.parquet"));
std::fs::write(&local_path, &data)?;
tracing::info!(
scenario,
checkpoint = checkpoint_idx,
query = q_idx,
path = %local_path.display(),
"Downloaded checkpoint parquet"
);
}
}
Ok(())
}
}