Skip to content

Commit d2ddfc2

Browse files
committed
refactor: Make --part and --parts Option, better error handling
1 parent 3f59e0c commit d2ddfc2

3 files changed

Lines changed: 225 additions & 70 deletions

File tree

tpchgen-cli/src/main.rs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,15 @@ struct Cli {
9090
#[arg(short = 'T', long = "tables", value_delimiter = ',', value_parser = TableValueParser)]
9191
tables: Option<Vec<Table>>,
9292

93-
/// Number of parts to generate (manual parallel generation)
94-
#[arg(short, long, default_value_t = -1)]
95-
parts: i32,
93+
/// Number of partitions to generate (manual parallel generation)
94+
#[arg(short, long)]
95+
parts: Option<i32>,
9696

97-
/// Which part to generate (1-based, only relevant if parts > 1)
98-
#[arg(long, default_value_t = -1)]
99-
part: i32,
97+
/// Which partition to generate (1-based)
98+
///
99+
/// If not specified, generates all parts
100+
#[arg(long)]
101+
part: Option<i32>,
100102

101103
/// Output format: tbl, csv, parquet (default: tbl)
102104
#[arg(short, long, default_value = "tbl")]
@@ -254,14 +256,15 @@ macro_rules! define_generate {
254256
($FUN_NAME:ident, $TABLE:expr, $GENERATOR:ident, $TBL_SOURCE:ty, $CSV_SOURCE:ty, $PARQUET_SOURCE:ty) => {
255257
async fn $FUN_NAME(&self) -> io::Result<()> {
256258
let filename = self.output_filename($TABLE);
257-
let plan = GenerationPlan::new(
259+
let plan = GenerationPlan::try_new(
258260
&$TABLE,
259261
self.format,
260262
self.scale_factor,
261263
self.part,
262264
self.parts,
263265
self.num_threads,
264-
);
266+
)
267+
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
265268
let scale_factor = self.scale_factor;
266269
info!("Writing table {} (SF={scale_factor}) to {filename}", $TABLE);
267270
debug!("Plan: {plan}");

tpchgen-cli/src/plan.rs

Lines changed: 121 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -60,59 +60,98 @@ impl GenerationPlan {
6060
/// Returns a new `GenerationPlan` describing which parts to generate
6161
///
6262
/// cli_part and cli_part_count are passed from the CLI arguments
63-
/// to specify a particular part or number of parts to generate.
64-
pub fn new(
63+
/// to specify a particular part or number of partitions to generate.
64+
pub fn try_new(
6565
table: &Table,
6666
format: OutputFormat,
6767
scale_factor: f64,
68-
cli_part: i32,
69-
cli_part_count: i32,
68+
cli_part: Option<i32>,
69+
cli_part_count: Option<i32>,
7070
num_threads: usize,
71-
) -> Self {
71+
) -> Result<Self, String> {
7272
// If a single part is specified, split it into chunks to enable parallel generation.
73-
if cli_part != -1 || cli_part_count != -1 {
74-
// These tables are small not parameterized by part count,
75-
// so we must create only a single part.
76-
if table == &Table::Nation || table == &Table::Region {
77-
return Self {
78-
part_count: 1,
79-
part_list: vec![1],
80-
};
73+
match (cli_part, cli_part_count) {
74+
(Some(_part), None) => Err(String::from(
75+
"The --part option requires the --parts option to be set",
76+
)),
77+
(None, Some(_part_count)) => {
78+
// TODO automatically create multiple files if part_count > 1
79+
// and part is not specified
80+
Err(String::from(
81+
"The --part_count option requires the --part option to be set",
82+
))
8183
}
82-
83-
// sanity check arguments (TODO: real Errors)
84-
if cli_part < 1 || cli_part_count < 1 || cli_part > cli_part_count {
85-
panic!(
86-
"Invalid CLI part or part count. \
87-
Expect greater than 1 and cli_part <= cli_part_count. \
88-
Got: cli_part={cli_part}, cli_part_count={cli_part_count}",
89-
);
84+
(Some(part), Some(part_count)) => {
85+
Self::try_new_with_parts(table, format, scale_factor, part, part_count, num_threads)
9086
}
87+
(None, None) => Self::try_new_without_parts(table, format, scale_factor),
88+
}
89+
}
90+
91+
/// Returns a new `GenerationPlan` when partitioning is specified on the command line
92+
fn try_new_with_parts(
93+
table: &Table,
94+
// todo take into account the output format (e.g. parquet row group size)
95+
_format: OutputFormat,
96+
scale_factor: f64,
97+
cli_part: i32,
98+
cli_part_count: i32,
99+
num_threads: usize,
100+
) -> Result<Self, String> {
101+
if cli_part < 1 {
102+
return Err(format!(
103+
"Invalid --part. Expected a number greater than zero, got {cli_part}"
104+
));
105+
}
106+
if cli_part_count < 1 {
107+
return Err(format!(
108+
"Invalid --part_count. Expected a number greater than zero, got {cli_part_count}"
109+
));
110+
}
111+
if cli_part > cli_part_count {
112+
return Err(format!(
113+
"Invalid --part. Expected at most the value of --parts ({cli_part_count}), got {cli_part}"));
114+
}
91115

92-
let num_chunks = num_threads as i32;
93-
94-
// The new total number of parts is the original number of parts multiplied by the number of chunks.
95-
let new_total_parts = cli_part_count * num_chunks;
96-
97-
// The new part numbers to generate are the chunks that make up the original part.
98-
let start_part = (cli_part - 1) * num_chunks + 1;
99-
let end_part = cli_part * num_chunks;
100-
let new_parts_to_generate = (start_part..=end_part).collect();
101-
debug!(
102-
"Generating {} parts for table {:?} with scale factor {}",
103-
new_total_parts, table, scale_factor
104-
);
105-
debug!(
106-
"CLI part: {}, CLI part count: {}, num_threads: {}",
107-
cli_part, cli_part_count, num_threads
108-
);
109-
debug!("New parts to generate: {:?}", new_parts_to_generate);
110-
return Self {
111-
part_count: new_total_parts,
112-
part_list: new_parts_to_generate,
113-
};
116+
// These tables are small and are not parameterized by part count, so
117+
// we must create only a single part.
118+
if table == &Table::Nation || table == &Table::Region {
119+
return Ok(Self {
120+
part_count: 1,
121+
part_list: vec![1],
122+
});
114123
}
115124

125+
let num_chunks = num_threads as i32;
126+
127+
// The new total number of parts is the original number of parts multiplied by the number of chunks.
128+
let new_total_parts = cli_part_count * num_chunks;
129+
130+
// The new part numbers to generate are the chunks that make up the original part.
131+
let start_part = (cli_part - 1) * num_chunks + 1;
132+
let end_part = cli_part * num_chunks;
133+
let new_parts_to_generate = (start_part..=end_part).collect();
134+
debug!(
135+
"Generating {} parts for table {:?} with scale factor {}",
136+
new_total_parts, table, scale_factor
137+
);
138+
debug!(
139+
"CLI part: {}, CLI part count: {}, num_threads: {}",
140+
cli_part, cli_part_count, num_threads
141+
);
142+
debug!("New parts to generate: {:?}", new_parts_to_generate);
143+
Ok(Self {
144+
part_count: new_total_parts,
145+
part_list: new_parts_to_generate,
146+
})
147+
}
148+
149+
/// Returns a new `GenerationPlan` when no partitioning is specified on the command line
150+
fn try_new_without_parts(
151+
table: &Table,
152+
format: OutputFormat,
153+
scale_factor: f64,
154+
) -> Result<Self, String> {
116155
// Note use part=1, part_count=1 to calculate the total row count
117156
// for the table
118157
//
@@ -157,10 +196,10 @@ impl GenerationPlan {
157196
let num_parts = num_parts.try_into().unwrap();
158197
// generating all the parts
159198

160-
Self {
199+
Ok(Self {
161200
part_count: num_parts,
162201
part_list: (1..=num_parts).collect(),
163-
}
202+
})
164203
}
165204
}
166205

@@ -324,31 +363,36 @@ mod tests {
324363
}
325364

326365
#[test]
327-
#[should_panic(
328-
expected = "Invalid CLI part or part count. Expect greater than 1 and cli_part <= cli_part_count. Got: cli_part=0, cli_part_count=10"
329-
)]
330-
fn sf1_lineitem_cli_parts_invalid_small() {
366+
fn sf1_lineitem_cli_invalid_part() {
331367
Test::new()
332368
.with_table(Table::Lineitem)
333369
.with_format(OutputFormat::Tbl)
334370
.with_scale_factor(1.0)
335371
.with_cli_part(0) // part 0 of 10 (invalid)
336372
.with_cli_part_count(10)
337-
.assert(40, [13, 14, 15, 16])
373+
.assert_err("Invalid --part. Expected a number greater than zero, got 0")
338374
}
339375

340376
#[test]
341-
#[should_panic(
342-
expected = "Invalid CLI part or part count. Expect greater than 1 and cli_part <= cli_part_count. Got: cli_part=11, cli_part_count=10"
343-
)]
344377
fn sf1_lineitem_cli_parts_invalid_big() {
345378
Test::new()
346379
.with_table(Table::Lineitem)
347380
.with_format(OutputFormat::Tbl)
348381
.with_scale_factor(1.0)
349382
.with_cli_part(11) // part 11 of 10 (invalid)
350383
.with_cli_part_count(10)
351-
.assert(40, [13, 14, 15, 16])
384+
.assert_err("Invalid --part. Expected at most the value of --parts (10), got 11");
385+
}
386+
387+
#[test]
388+
fn sf1_lineitem_cli_invalid_part_count() {
389+
Test::new()
390+
.with_table(Table::Lineitem)
391+
.with_format(OutputFormat::Tbl)
392+
.with_scale_factor(1.0)
393+
.with_cli_part(1) // part 1 of 0 (invalid)
394+
.with_cli_part_count(0)
395+
.assert_err("Invalid --part_count. Expected a number greater than zero, got 0");
352396
}
353397

354398
#[test]
@@ -393,8 +437,8 @@ mod tests {
393437
table: Table,
394438
format: OutputFormat,
395439
scale_factor: f64,
396-
cli_part: i32,
397-
cli_part_count: i32,
440+
cli_part: Option<i32>,
441+
cli_part_count: Option<i32>,
398442
num_cpus: usize,
399443
}
400444

@@ -410,19 +454,34 @@ mod tests {
410454
expected_part_count: i32,
411455
expected_part_numbers: impl IntoIterator<Item = i32>,
412456
) {
413-
let plan = GenerationPlan::new(
457+
let plan = GenerationPlan::try_new(
414458
&self.table,
415459
self.format,
416460
self.scale_factor,
417461
self.cli_part,
418462
self.cli_part_count,
419463
self.num_cpus,
420-
);
464+
)
465+
.unwrap();
421466
assert_eq!(plan.part_count, expected_part_count);
422467
let expected_part_numbers: Vec<i32> = expected_part_numbers.into_iter().collect();
423468
assert_eq!(plan.part_list, expected_part_numbers);
424469
}
425470

471+
/// Assert that creating a [`GenerationPlan`] returns the specified error
472+
fn assert_err(self, expected_error: &str) {
473+
let actual_error = GenerationPlan::try_new(
474+
&self.table,
475+
self.format,
476+
self.scale_factor,
477+
self.cli_part,
478+
self.cli_part_count,
479+
self.num_cpus,
480+
)
481+
.unwrap_err();
482+
assert_eq!(actual_error, expected_error);
483+
}
484+
426485
/// Set table
427486
fn with_table(mut self, table: Table) -> Self {
428487
self.table = table;
@@ -443,13 +502,13 @@ mod tests {
443502

444503
/// Set CLI part
445504
fn with_cli_part(mut self, cli_part: i32) -> Self {
446-
self.cli_part = cli_part;
505+
self.cli_part = Some(cli_part);
447506
self
448507
}
449508

450509
/// Set CLI part count
451510
fn with_cli_part_count(mut self, cli_part_count: i32) -> Self {
452-
self.cli_part_count = cli_part_count;
511+
self.cli_part_count = Some(cli_part_count);
453512
self
454513
}
455514
}
@@ -460,8 +519,8 @@ mod tests {
460519
table: Table::Orders,
461520
format: OutputFormat::Tbl,
462521
scale_factor: 1.0,
463-
cli_part: -1,
464-
cli_part_count: -1,
522+
cli_part: None,
523+
cli_part_count: None,
465524
num_cpus: 4, // hard code 4 cores for testing
466525
}
467526
}

0 commit comments

Comments
 (0)