Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions benches/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ fn bench(c: &mut Criterion) {
gen_sync: true,
gen_async: false,
derive_ser: true,
config: Default::default(),
},
)
.unwrap()
Expand All @@ -32,6 +33,7 @@ fn bench(c: &mut Criterion) {
gen_sync: true,
gen_async: false,
derive_ser: true,
config: Default::default(),
},
)
.unwrap()
Expand Down
5 changes: 1 addition & 4 deletions benches/execution/diesel_benches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,7 @@ pub fn bench_insert(b: &mut Bencher, conn: &mut PgConnection, size: usize) {
};
let insert = &insert;

b.iter(|| {
let insert = insert;
insert(conn)
})
b.iter(|| insert(conn))
}

pub fn loading_associations_sequentially(b: &mut Bencher, conn: &mut PgConnection) {
Expand Down
4 changes: 4 additions & 0 deletions crates/cornucopia/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,7 @@ heck = "0.4.0"

# Order-preserving map to work around borrowing issues
indexmap = "2.0.2"

# Config handling
serde = { version = "1.0.203", features = ["derive"] }
toml = "0.8.14"
62 changes: 57 additions & 5 deletions crates/cornucopia/src/cli.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
use std::path::PathBuf;
use miette::Diagnostic;
use std::{fs, path::PathBuf};
use thiserror::Error as ThisError;

use clap::{Parser, Subcommand};

use crate::{conn, container, error::Error, generate_live, generate_managed, CodegenSettings};
use crate::{
config::Config, conn, container, error::Error, generate_live, generate_managed, CodegenSettings,
};

/// Command line interface to interact with Cornucopia SQL.
#[derive(Parser, Debug)]
Expand All @@ -28,6 +32,13 @@ struct Args {
/// Derive serde's `Serialize` trait for generated types.
#[clap(long)]
serialize: bool,
/// The location of the configuration file.
#[clap(short, long, default_value = default_config_path())]
config: PathBuf,
}

const fn default_config_path() -> &'static str {
"cornucopia.toml"
}

#[derive(Debug, Subcommand)]
Expand All @@ -44,8 +55,26 @@ enum Action {
},
}

/// Enumeration of the errors reported by the CLI.
#[derive(ThisError, Debug, Diagnostic)]
pub enum CliError {
/// An error occurred while loading the configuration file.
#[error("Could not load config `{path}`: ({err})")]
MissingConfig { path: String, err: std::io::Error },
/// An error occurred while parsing the configuration file.
#[error("Could not parse config `{path}`: ({err})")]
ConfigContents {
path: String,
err: Box<dyn std::error::Error + Send + Sync>,
},
/// An error occurred while running the CLI.
#[error(transparent)]
#[diagnostic(transparent)]
Internal(#[from] Error),
}

// Main entrypoint of the CLI. Parses the args and calls the appropriate routines.
pub fn run() -> Result<(), Error> {
pub fn run() -> Result<(), CliError> {
let Args {
podman,
queries_path,
Expand All @@ -54,17 +83,40 @@ pub fn run() -> Result<(), Error> {
sync,
r#async,
serialize,
config,
} = Args::parse();

let config = match fs::read_to_string(config.as_path()) {
Ok(contents) => match toml::from_str(&contents) {
Ok(config) => config,
Err(err) => {
return Err(CliError::ConfigContents {
path: config.to_string_lossy().into_owned(),
err: err.into(),
});
}
},
Err(err) => {
if config.as_path().as_os_str() != default_config_path() {
return Err(CliError::MissingConfig {
path: config.to_string_lossy().into_owned(),
err,
});
} else {
Config::default()
}
}
};
let settings = CodegenSettings {
gen_async: r#async || !sync,
gen_sync: sync,
derive_ser: serialize,
config,
};

match action {
Action::Live { url } => {
let mut client = conn::from_url(&url)?;
let mut client = conn::from_url(&url).map_err(|e| CliError::Internal(e.into()))?;
generate_live(&mut client, &queries_path, Some(&destination), settings)?;
}
Action::Schema { schema_files } => {
Expand All @@ -77,7 +129,7 @@ pub fn run() -> Result<(), Error> {
settings,
) {
container::cleanup(podman).ok();
return Err(e);
return Err(CliError::Internal(e));
}
}
};
Expand Down
12 changes: 12 additions & 0 deletions crates/cornucopia/src/config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
//! Configuration for Cornucopia.

use std::collections::HashMap;

use serde::Deserialize;

/// Configuration for Cornucopia.
#[derive(Clone, Deserialize, Default, Debug)]
pub struct Config {
/// Contains a map of what given type should map to.
pub custom_type_map: HashMap<String, String>,
}
9 changes: 6 additions & 3 deletions crates/cornucopia/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod cli;
mod codegen;
mod config;
mod error;
mod load_schema;
mod parser;
Expand All @@ -16,6 +17,7 @@ pub mod container;

use std::path::Path;

use config::Config;
use postgres::Client;

use codegen::generate as generate_internal;
Expand All @@ -31,11 +33,12 @@ pub use error::Error;
pub use load_schema::load_schema;

/// Struct containing the settings for code generation.
#[derive(Clone, Copy)]
#[derive(Clone)]
pub struct CodegenSettings {
pub gen_async: bool,
pub gen_sync: bool,
pub derive_ser: bool,
pub config: Config,
}

/// Generates Rust queries from PostgreSQL queries located at `queries_path`,
Expand All @@ -54,7 +57,7 @@ pub fn generate_live<P: AsRef<Path>>(
.map(parse_query_module)
.collect::<Result<_, parser::error::Error>>()?;
// Generate
let prepared_modules = prepare(client, modules)?;
let prepared_modules = prepare(client, modules, settings.clone())?;
let generated_code = generate_internal(prepared_modules, settings);
// Write
if let Some(d) = destination {
Expand Down Expand Up @@ -86,7 +89,7 @@ pub fn generate_managed<P: AsRef<Path>>(
container::setup(podman)?;
let mut client = conn::cornucopia_conn()?;
load_schema(&mut client, schema_files)?;
let prepared_modules = prepare(&mut client, modules)?;
let prepared_modules = prepare(&mut client, modules, settings.clone())?;
let generated_code = generate_internal(prepared_modules, settings);
container::cleanup(podman)?;

Expand Down
33 changes: 17 additions & 16 deletions crates/cornucopia/src/prepare_queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@ use crate::{
codegen::GenCtx,
parser::{Module, NullableIdent, Query, Span, TypeAnnotation},
read_queries::ModuleInfo,
type_registrar::CornucopiaType,
type_registrar::TypeRegistrar,
type_registrar::{CornucopiaType, TypeRegistrar},
utils::KEYWORD,
validation,
validation, CodegenSettings,
};

use self::error::Error;
Expand Down Expand Up @@ -226,8 +225,12 @@ impl PreparedModule {
}

/// Prepares all modules
pub(crate) fn prepare(client: &mut Client, modules: Vec<Module>) -> Result<Preparation, Error> {
let mut registrar = TypeRegistrar::default();
pub(crate) fn prepare(
client: &mut Client,
modules: Vec<Module>,
settings: CodegenSettings,
) -> Result<Preparation, Error> {
let mut registrar = TypeRegistrar::new(settings.config.custom_type_map);
let mut tmp = Preparation {
modules: Vec::new(),
types: IndexMap::new(),
Expand All @@ -244,16 +247,12 @@ pub(crate) fn prepare(client: &mut Client, modules: Vec<Module>) -> Result<Prepa
}

// Prepare types grouped by schema
for ((schema, name), ty) in &registrar.types {
if let Some(ty) = prepare_type(&registrar, name, ty, &declared) {
match tmp.types.entry(schema.clone()) {
Entry::Occupied(mut entry) => {
entry.get_mut().push(ty);
}
Entry::Vacant(entry) => {
entry.insert(vec![ty]);
}
}
for (schema_key, ty) in registrar.types() {
if let Some(ty) = prepare_type(&registrar, schema_key.name, ty, &declared) {
tmp.types
.entry(schema_key.schema.to_owned())
.or_default()
.push(ty);
}
}
Ok(tmp)
Expand Down Expand Up @@ -301,7 +300,9 @@ fn prepare_type(
})
.collect(),
),
_ => unreachable!(),
_ => {
return None;
}
};
Some(PreparedType {
name: name.to_string(),
Expand Down
Loading