Skip to content

Commit b6e9a61

Browse files
committed
Parallelize codegen
1 parent 4197ccb commit b6e9a61

File tree

1 file changed

+74
-42
lines changed

1 file changed

+74
-42
lines changed

codegen/main.rs

Lines changed: 74 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ use std::{
66
io::Write,
77
path::Path,
88
process::{Command, Stdio},
9+
sync::mpsc,
10+
thread,
911
};
1012

1113
use crate::{
@@ -164,52 +166,82 @@ fn main() -> eyre::Result<()> {
164166
let mut python_files = Vec::with_capacity(declarations.declarations.len() + 1);
165167
python_files.push(String::from("mod.rs"));
166168

169+
let items: Vec<_> = declarations.declarations.keys().collect();
170+
167171
// generate custom code
168-
for (path, item) in &declarations.declarations {
169-
let item_name = path.0.last().unwrap().as_str();
170-
if item_name == "Float" {
171-
// Special case for Float (we always inline Float into Py<PyFloat>)
172-
continue;
172+
thread::scope(|s| {
173+
let (tx, rx) = mpsc::channel();
174+
175+
let num_codegen_threads = thread::available_parallelism().unwrap().get();
176+
let num_items_per_thread = items.len().div_ceil(num_codegen_threads).max(8);
177+
let num_threads = items.len().div_ceil(num_items_per_thread);
178+
179+
for i in 0..num_threads {
180+
let start_num = num_items_per_thread * i;
181+
let end_num = (start_num + num_items_per_thread).min(items.len());
182+
183+
let declarations = &declarations.declarations;
184+
let items = &items[start_num..end_num];
185+
let tx = tx.clone();
186+
187+
s.spawn(move || {
188+
for &path in items {
189+
let item = &declarations[path];
190+
let item_name = path.0.last().unwrap().as_str();
191+
if item_name == "Float" {
192+
// Special case for Float (we always inline Float into Py<PyFloat>)
193+
continue;
194+
}
195+
196+
let mut file_name = camel_to_snake(item_name);
197+
198+
let file_contents = match &item.kind {
199+
DeclarationKind::Table(info) => {
200+
let bind_gen =
201+
TableBindGenerator::new(item_name, &info.fields, declarations);
202+
bind_gen.generate_binds()
203+
}
204+
DeclarationKind::Struct(info) => {
205+
let bind_gen =
206+
StructBindGenerator::new(item_name, &info.fields, declarations);
207+
bind_gen.generate_binds()
208+
}
209+
DeclarationKind::Enum(info) => {
210+
let bind_gen = EnumBindGenerator::new(item_name, &info.variants);
211+
bind_gen.generate_binds()
212+
}
213+
DeclarationKind::Union(info) => {
214+
let bind_gen = UnionBindGenerator::new(item_name, &info.variants);
215+
bind_gen.generate_binds()
216+
}
217+
DeclarationKind::RpcService(_) => unimplemented!(),
218+
};
219+
220+
let mod_lines: String =
221+
["mod ", &file_name, ";\n", "pub use ", &file_name, "::*;\n"]
222+
.into_iter()
223+
.collect();
224+
file_name.push_str(".rs");
225+
226+
fs::write(
227+
python_folder.join(&file_name),
228+
format_string(&file_contents.join("\n")).unwrap(),
229+
)
230+
.unwrap();
231+
232+
tx.send((item_name, mod_lines, file_name)).unwrap();
233+
}
234+
});
173235
}
174236

175-
let mut file_name = camel_to_snake(item_name);
176-
177-
let file_contents = match &item.kind {
178-
DeclarationKind::Table(info) => {
179-
let bind_gen =
180-
TableBindGenerator::new(item_name, &info.fields, &declarations.declarations);
181-
bind_gen.generate_binds()
182-
}
183-
DeclarationKind::Struct(info) => {
184-
let bind_gen =
185-
StructBindGenerator::new(item_name, &info.fields, &declarations.declarations);
186-
bind_gen.generate_binds()
187-
}
188-
DeclarationKind::Enum(info) => {
189-
let bind_gen = EnumBindGenerator::new(item_name, &info.variants);
190-
bind_gen.generate_binds()
191-
}
192-
DeclarationKind::Union(info) => {
193-
let bind_gen = UnionBindGenerator::new(item_name, &info.variants);
194-
bind_gen.generate_binds()
195-
}
196-
DeclarationKind::RpcService(_) => unimplemented!(),
197-
};
198-
199-
class_names.push(item_name);
200-
python_mod.push(
201-
["mod ", &file_name, ";\n", "pub use ", &file_name, "::*;\n"]
202-
.into_iter()
203-
.collect(),
204-
);
205-
file_name.push_str(".rs");
237+
drop(tx);
206238

207-
fs::write(
208-
python_folder.join(&file_name),
209-
format_string(&file_contents.join("\n"))?,
210-
)?;
211-
python_files.push(file_name);
212-
}
239+
for (class_name, mod_lines, file_name) in rx.iter() {
240+
class_names.push(class_name);
241+
python_mod.push(mod_lines);
242+
python_files.push(file_name);
243+
}
244+
});
213245

214246
// remove old files for types that don't exist anymore
215247
for item in fs::read_dir(python_folder)?.flatten() {

0 commit comments

Comments
 (0)