diff --git a/benches/codegen.rs b/benches/codegen.rs index 4eabc30f..8bca128d 100644 --- a/benches/codegen.rs +++ b/benches/codegen.rs @@ -17,6 +17,7 @@ fn bench(c: &mut Criterion) { gen_sync: true, gen_async: false, derive_ser: true, + config: Default::default(), }, ) .unwrap() @@ -32,6 +33,7 @@ fn bench(c: &mut Criterion) { gen_sync: true, gen_async: false, derive_ser: true, + config: Default::default(), }, ) .unwrap() diff --git a/benches/execution/diesel_benches.rs b/benches/execution/diesel_benches.rs index 5c7407e1..b5f2ee66 100644 --- a/benches/execution/diesel_benches.rs +++ b/benches/execution/diesel_benches.rs @@ -164,10 +164,7 @@ pub fn bench_insert(b: &mut Bencher, conn: &mut PgConnection, size: usize) { }; let insert = &insert; - b.iter(|| { - let insert = insert; - insert(conn) - }) + b.iter(|| insert(conn)) } pub fn loading_associations_sequentially(b: &mut Bencher, conn: &mut PgConnection) { diff --git a/crates/cornucopia/Cargo.toml b/crates/cornucopia/Cargo.toml index 04773127..f7e21b0d 100644 --- a/crates/cornucopia/Cargo.toml +++ b/crates/cornucopia/Cargo.toml @@ -33,3 +33,7 @@ heck = "0.4.0" # Order-preserving map to work around borrowing issues indexmap = "2.0.2" + +# Config handling +serde = { version = "1.0.203", features = ["derive"] } +toml = "0.8.14" diff --git a/crates/cornucopia/src/cli.rs b/crates/cornucopia/src/cli.rs index 73692260..e85502f4 100644 --- a/crates/cornucopia/src/cli.rs +++ b/crates/cornucopia/src/cli.rs @@ -1,8 +1,12 @@ -use std::path::PathBuf; +use miette::Diagnostic; +use std::{fs, path::PathBuf}; +use thiserror::Error as ThisError; use clap::{Parser, Subcommand}; -use crate::{conn, container, error::Error, generate_live, generate_managed, CodegenSettings}; +use crate::{ + config::Config, conn, container, error::Error, generate_live, generate_managed, CodegenSettings, +}; /// Command line interface to interact with Cornucopia SQL. #[derive(Parser, Debug)] @@ -28,6 +32,13 @@ struct Args { /// Derive serde's `Serialize` trait for generated types. #[clap(long)] serialize: bool, + /// The location of the configuration file. + #[clap(short, long, default_value = default_config_path())] + config: PathBuf, +} + +const fn default_config_path() -> &'static str { + "cornucopia.toml" } #[derive(Debug, Subcommand)] @@ -44,8 +55,26 @@ enum Action { }, } +/// Enumeration of the errors reported by the CLI. +#[derive(ThisError, Debug, Diagnostic)] +pub enum CliError { + /// An error occurred while loading the configuration file. + #[error("Could not load config `{path}`: ({err})")] + MissingConfig { path: String, err: std::io::Error }, + /// An error occurred while parsing the configuration file. + #[error("Could not parse config `{path}`: ({err})")] + ConfigContents { + path: String, + err: Box, + }, + /// An error occurred while running the CLI. + #[error(transparent)] + #[diagnostic(transparent)] + Internal(#[from] Error), +} + // Main entrypoint of the CLI. Parses the args and calls the appropriate routines. -pub fn run() -> Result<(), Error> { +pub fn run() -> Result<(), CliError> { let Args { podman, queries_path, @@ -54,17 +83,40 @@ pub fn run() -> Result<(), Error> { sync, r#async, serialize, + config, } = Args::parse(); + let config = match fs::read_to_string(config.as_path()) { + Ok(contents) => match toml::from_str(&contents) { + Ok(config) => config, + Err(err) => { + return Err(CliError::ConfigContents { + path: config.to_string_lossy().into_owned(), + err: err.into(), + }); + } + }, + Err(err) => { + if config.as_path().as_os_str() != default_config_path() { + return Err(CliError::MissingConfig { + path: config.to_string_lossy().into_owned(), + err, + }); + } else { + Config::default() + } + } + }; let settings = CodegenSettings { gen_async: r#async || !sync, gen_sync: sync, derive_ser: serialize, + config, }; match action { Action::Live { url } => { - let mut client = conn::from_url(&url)?; + let mut client = conn::from_url(&url).map_err(|e| CliError::Internal(e.into()))?; generate_live(&mut client, &queries_path, Some(&destination), settings)?; } Action::Schema { schema_files } => { @@ -77,7 +129,7 @@ pub fn run() -> Result<(), Error> { settings, ) { container::cleanup(podman).ok(); - return Err(e); + return Err(CliError::Internal(e)); } } }; diff --git a/crates/cornucopia/src/config.rs b/crates/cornucopia/src/config.rs new file mode 100644 index 00000000..92419e0f --- /dev/null +++ b/crates/cornucopia/src/config.rs @@ -0,0 +1,12 @@ +//! Configuration for Cornucopia. + +use std::collections::HashMap; + +use serde::Deserialize; + +/// Configuration for Cornucopia. +#[derive(Clone, Deserialize, Default, Debug)] +pub struct Config { + /// Contains a map of what given type should map to. + pub custom_type_map: HashMap, +} diff --git a/crates/cornucopia/src/lib.rs b/crates/cornucopia/src/lib.rs index b879ee28..7da88f9d 100644 --- a/crates/cornucopia/src/lib.rs +++ b/crates/cornucopia/src/lib.rs @@ -1,5 +1,6 @@ mod cli; mod codegen; +mod config; mod error; mod load_schema; mod parser; @@ -16,6 +17,7 @@ pub mod container; use std::path::Path; +use config::Config; use postgres::Client; use codegen::generate as generate_internal; @@ -31,11 +33,12 @@ pub use error::Error; pub use load_schema::load_schema; /// Struct containing the settings for code generation. -#[derive(Clone, Copy)] +#[derive(Clone)] pub struct CodegenSettings { pub gen_async: bool, pub gen_sync: bool, pub derive_ser: bool, + pub config: Config, } /// Generates Rust queries from PostgreSQL queries located at `queries_path`, @@ -54,7 +57,7 @@ pub fn generate_live>( .map(parse_query_module) .collect::>()?; // Generate - let prepared_modules = prepare(client, modules)?; + let prepared_modules = prepare(client, modules, settings.clone())?; let generated_code = generate_internal(prepared_modules, settings); // Write if let Some(d) = destination { @@ -86,7 +89,7 @@ pub fn generate_managed>( container::setup(podman)?; let mut client = conn::cornucopia_conn()?; load_schema(&mut client, schema_files)?; - let prepared_modules = prepare(&mut client, modules)?; + let prepared_modules = prepare(&mut client, modules, settings.clone())?; let generated_code = generate_internal(prepared_modules, settings); container::cleanup(podman)?; diff --git a/crates/cornucopia/src/prepare_queries.rs b/crates/cornucopia/src/prepare_queries.rs index 9edf6eee..bd2afbff 100644 --- a/crates/cornucopia/src/prepare_queries.rs +++ b/crates/cornucopia/src/prepare_queries.rs @@ -9,10 +9,9 @@ use crate::{ codegen::GenCtx, parser::{Module, NullableIdent, Query, Span, TypeAnnotation}, read_queries::ModuleInfo, - type_registrar::CornucopiaType, - type_registrar::TypeRegistrar, + type_registrar::{CornucopiaType, TypeRegistrar}, utils::KEYWORD, - validation, + validation, CodegenSettings, }; use self::error::Error; @@ -226,8 +225,12 @@ impl PreparedModule { } /// Prepares all modules -pub(crate) fn prepare(client: &mut Client, modules: Vec) -> Result { - let mut registrar = TypeRegistrar::default(); +pub(crate) fn prepare( + client: &mut Client, + modules: Vec, + settings: CodegenSettings, +) -> Result { + let mut registrar = TypeRegistrar::new(settings.config.custom_type_map); let mut tmp = Preparation { modules: Vec::new(), types: IndexMap::new(), @@ -244,16 +247,12 @@ pub(crate) fn prepare(client: &mut Client, modules: Vec) -> Result { - entry.get_mut().push(ty); - } - Entry::Vacant(entry) => { - entry.insert(vec![ty]); - } - } + for (schema_key, ty) in registrar.types() { + if let Some(ty) = prepare_type(®istrar, schema_key.name, ty, &declared) { + tmp.types + .entry(schema_key.schema.to_owned()) + .or_default() + .push(ty); } } Ok(tmp) @@ -301,7 +300,9 @@ fn prepare_type( }) .collect(), ), - _ => unreachable!(), + _ => { + return None; + } }; Some(PreparedType { name: name.to_string(), diff --git a/crates/cornucopia/src/type_registrar.rs b/crates/cornucopia/src/type_registrar.rs index 1b7ca70f..aefc9cb7 100644 --- a/crates/cornucopia/src/type_registrar.rs +++ b/crates/cornucopia/src/type_registrar.rs @@ -1,4 +1,4 @@ -use std::rc::Rc; +use std::{collections::HashMap, rc::Rc}; use heck::ToUpperCamelCase; use indexmap::{map::Entry, IndexMap}; @@ -18,7 +18,7 @@ use self::error::Error; pub(crate) enum CornucopiaType { Simple { pg_ty: Type, - rust_name: &'static str, + rust_name: String, is_copy: bool, }, Array { @@ -33,6 +33,7 @@ pub(crate) enum CornucopiaType { struct_name: String, is_copy: bool, is_params: bool, + is_mapped: bool, }, } @@ -168,6 +169,11 @@ impl CornucopiaType { } } CornucopiaType::Domain { inner, .. } => inner.own_ty(false, ctx), + CornucopiaType::Custom { + is_mapped, + struct_name, + .. + } if *is_mapped => struct_name.to_string(), CornucopiaType::Custom { struct_name, pg_ty, .. } => custom_ty_path(pg_ty.schema(), struct_name, ctx), @@ -287,9 +293,14 @@ impl CornucopiaType { is_copy, pg_ty, struct_name, + is_mapped, .. } => { - let path = custom_ty_path(pg_ty.schema(), struct_name, ctx); + let path = if *is_mapped { + struct_name.to_string() + } else { + custom_ty_path(pg_ty.schema(), struct_name, ctx) + }; if *is_copy { path } else { @@ -311,12 +322,27 @@ pub fn custom_ty_path(schema: &str, struct_name: &str, ctx: &GenCtx) -> String { } /// Data structure holding all types known to this particular run of Cornucopia. -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone)] pub(crate) struct TypeRegistrar { - pub types: IndexMap<(String, String), Rc>, + types: IndexMap<(String, String), Rc>, + type_mappings: HashMap, } impl TypeRegistrar { + /// Create a new type registrar using the specified type mappings. + pub(crate) fn new(type_mappings: HashMap) -> Self { + Self { + types: IndexMap::new(), + type_mappings, + } + } + + pub(crate) fn types(&self) -> impl Iterator { + self.types + .iter() + .map(|((schema, name), ty)| (SchemaKey::new(schema, name), ty.as_ref())) + } + pub(crate) fn register( &mut self, name: &str, @@ -331,6 +357,7 @@ impl TypeRegistrar { struct_name: rust_ty_name, is_copy, is_params, + is_mapped: false, } } @@ -345,6 +372,23 @@ impl TypeRegistrar { return Ok(&self.types[idx]); } + if let Some(mapped_type) = self.type_mappings.get(ty.name()).cloned() { + if matches!(mapped_type.as_str(), "String" | "str") { + return Ok(self.insert(ty, move || CornucopiaType::Simple { + pg_ty: Type::VARCHAR, + rust_name: "String".to_string(), + is_copy: false, + })); + } + return Ok(self.insert(ty, move || CornucopiaType::Custom { + pg_ty: ty.clone(), + is_copy: false, + struct_name: mapped_type, + is_params: true, + is_mapped: true, + })); + } + Ok(match ty.kind() { Kind::Enum(_) => self.insert(ty, || custom(ty, true, true)), Kind::Array(inner_ty) => { @@ -397,12 +441,12 @@ impl TypeRegistrar { query: query_name.span, col_name: name.to_string(), col_ty: ty.to_string(), - }) + }); } }; self.insert(ty, || CornucopiaType::Simple { pg_ty: ty.clone(), - rust_name, + rust_name: rust_name.to_owned(), is_copy, }) } @@ -412,7 +456,7 @@ impl TypeRegistrar { query: query_name.span, col_name: name.to_string(), col_ty: ty.to_string(), - }) + }); } }) } @@ -424,7 +468,7 @@ impl TypeRegistrar { .clone() } - fn insert(&mut self, ty: &Type, call: impl Fn() -> CornucopiaType) -> &Rc { + fn insert(&mut self, ty: &Type, call: impl FnOnce() -> CornucopiaType) -> &Rc { let index = match self .types .entry((ty.schema().to_owned(), ty.name().to_owned())) diff --git a/crates/cornucopia/src/utils.rs b/crates/cornucopia/src/utils.rs index e9748da3..87442abb 100644 --- a/crates/cornucopia/src/utils.rs +++ b/crates/cornucopia/src/utils.rs @@ -5,8 +5,17 @@ use postgres_types::Type; /// Allows us to query a map using type schema as key without having to own the key strings #[derive(PartialEq, Eq, Hash)] pub struct SchemaKey<'a> { - schema: &'a str, - name: &'a str, + /// The schema of this type. + pub schema: &'a str, + /// The name of this type. + pub name: &'a str, +} + +impl<'a> SchemaKey<'a> { + /// Creates a new [`SchemaKey`] from the specified components. + pub fn new(schema: &'a str, name: &'a str) -> Self { + SchemaKey { schema, name } + } } impl<'a> From<&'a Type> for SchemaKey<'a> { diff --git a/test_integration/src/fixtures.rs b/test_integration/src/fixtures.rs index 308c81c2..7769e0d0 100644 --- a/test_integration/src/fixtures.rs +++ b/test_integration/src/fixtures.rs @@ -77,6 +77,7 @@ impl From<&CodegenTest> for CodegenSettings { gen_async: codegen_test.r#async || !codegen_test.sync, gen_sync: codegen_test.sync, derive_ser: codegen_test.derive_ser, + config: Default::default(), } } } @@ -96,6 +97,7 @@ impl From<&ErrorTest> for CodegenSettings { derive_ser: false, gen_async: false, gen_sync: true, + config: Default::default(), } } }