Skip to content

Enable use of generators without engine #323

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ test = [
"sqlacodegen[sqlmodel]",
"pytest >= 7.4",
"coverage >= 7",
"psycopg2-binary",
"mysql-connector-python",
]
sqlmodel = ["sqlmodel >= 0.0.22"]
citext = ["sqlalchemy-citext >= 1.7.0"]
Expand Down
2 changes: 1 addition & 1 deletion src/sqlacodegen/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def main() -> None:

# Instantiate the generator
generator_class = generators[args.generator].load()
generator = generator_class(metadata, engine, options)
generator = generator_class(metadata, engine.dialect, options)

if not generator.views_supported:
name = generator_class.__name__
Expand Down
34 changes: 16 additions & 18 deletions src/sqlacodegen/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Computed,
Constraint,
DefaultClause,
Dialect,
Enum,
Float,
ForeignKey,
Expand All @@ -39,7 +40,6 @@
UniqueConstraint,
)
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.exc import CompileError
from sqlalchemy.sql.elements import TextClause

Expand Down Expand Up @@ -94,11 +94,9 @@ class Base:
class CodeGenerator(metaclass=ABCMeta):
valid_options: ClassVar[set[str]] = set()

def __init__(
self, metadata: MetaData, bind: Connection | Engine, options: Sequence[str]
):
def __init__(self, metadata: MetaData, dialect: Dialect, options: Sequence[str]):
self.metadata: MetaData = metadata
self.bind: Connection | Engine = bind
self.dialect: Dialect = dialect
self.options: set[str] = set(options)

# Validate options
Expand Down Expand Up @@ -129,12 +127,12 @@ class TablesGenerator(CodeGenerator):
def __init__(
self,
metadata: MetaData,
bind: Connection | Engine,
dialect: Dialect,
options: Sequence[str],
*,
indentation: str = " ",
):
super().__init__(metadata, bind, options)
super().__init__(metadata, dialect, options)
self.indentation: str = indentation
self.imports: dict[str, set[str]] = defaultdict(set)
self.module_imports: set[str] = set()
Expand Down Expand Up @@ -574,7 +572,7 @@ def add_fk_options(*opts: Any) -> None:
]
add_fk_options(local_columns, remote_columns)
elif isinstance(constraint, CheckConstraint):
args.append(repr(get_compiled_expression(constraint.sqltext, self.bind)))
args.append(repr(get_compiled_expression(constraint.sqltext, self.dialect)))
elif isinstance(constraint, (UniqueConstraint, PrimaryKeyConstraint)):
args.extend(repr(col.name) for col in constraint.columns)
else:
Expand Down Expand Up @@ -620,7 +618,7 @@ def fix_column_types(self, table: Table) -> None:
# Detect check constraints for boolean and enum columns
for constraint in table.constraints.copy():
if isinstance(constraint, CheckConstraint):
sqltext = get_compiled_expression(constraint.sqltext, self.bind)
sqltext = get_compiled_expression(constraint.sqltext, self.dialect)

# Turn any integer-like column with a CheckConstraint like
# "column IN (0, 1)" into a Boolean
Expand Down Expand Up @@ -658,7 +656,7 @@ def fix_column_types(self, table: Table) -> None:
pass

# PostgreSQL specific fix: detect sequences from server_default
if column.server_default and self.bind.dialect.name == "postgresql":
if column.server_default and self.dialect.name == "postgresql":
if isinstance(column.server_default, DefaultClause) and isinstance(
column.server_default.arg, TextClause
):
Expand All @@ -673,7 +671,7 @@ def fix_column_types(self, table: Table) -> None:
column.server_default = None

def get_adapted_type(self, coltype: Any) -> Any:
compiled_type = coltype.compile(self.bind.engine.dialect)
compiled_type = coltype.compile(self.dialect)
for supercls in coltype.__class__.__mro__:
if not supercls.__name__.startswith("_") and hasattr(
supercls, "__visit_name__"
Expand All @@ -699,7 +697,7 @@ def get_adapted_type(self, coltype: Any) -> Any:
try:
# If the adapted column type does not render the same as the
# original, don't substitute it
if new_coltype.compile(self.bind.engine.dialect) != compiled_type:
if new_coltype.compile(self.dialect) != compiled_type:
# Make an exception to the rule for Float and arrays of Float,
# since at least on PostgreSQL, Float can accurately represent
# both REAL and DOUBLE_PRECISION
Expand Down Expand Up @@ -730,13 +728,13 @@ class DeclarativeGenerator(TablesGenerator):
def __init__(
self,
metadata: MetaData,
bind: Connection | Engine,
dialect: Dialect,
options: Sequence[str],
*,
indentation: str = " ",
base_class_name: str = "Base",
):
super().__init__(metadata, bind, options, indentation=indentation)
super().__init__(metadata, dialect, options, indentation=indentation)
self.base_class_name: str = base_class_name
self.inflect_engine = inflect.engine()

Expand Down Expand Up @@ -1317,7 +1315,7 @@ class DataclassGenerator(DeclarativeGenerator):
def __init__(
self,
metadata: MetaData,
bind: Connection | Engine,
dialect: Dialect,
options: Sequence[str],
*,
indentation: str = " ",
Expand All @@ -1327,7 +1325,7 @@ def __init__(
):
super().__init__(
metadata,
bind,
dialect,
options,
indentation=indentation,
base_class_name=base_class_name,
Expand All @@ -1353,15 +1351,15 @@ class SQLModelGenerator(DeclarativeGenerator):
def __init__(
self,
metadata: MetaData,
bind: Connection | Engine,
dialect: Dialect,
options: Sequence[str],
*,
indentation: str = " ",
base_class_name: str = "SQLModel",
):
super().__init__(
metadata,
bind,
dialect,
options,
indentation=indentation,
base_class_name=base_class_name,
Expand Down
9 changes: 5 additions & 4 deletions src/sqlacodegen/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
from collections.abc import Mapping
from typing import Any

from sqlalchemy import PrimaryKeyConstraint, UniqueConstraint
from sqlalchemy.engine import Connection, Engine
from sqlalchemy import Dialect, PrimaryKeyConstraint, UniqueConstraint
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.schema import (
Expand Down Expand Up @@ -34,9 +33,11 @@ def get_constraint_sort_key(constraint: Constraint) -> str:
return str(constraint)


def get_compiled_expression(statement: ClauseElement, bind: Engine | Connection) -> str:
def get_compiled_expression(statement: ClauseElement, dialect: Dialect) -> str:
"""Return the statement in a form where any placeholders have been filled in."""
return str(statement.compile(bind, compile_kwargs={"literal_binds": True}))
return str(
statement.compile(dialect=dialect, compile_kwargs={"literal_binds": True})
)


def get_common_fk_constraints(
Expand Down
17 changes: 9 additions & 8 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,21 @@

import pytest
from pytest import FixtureRequest
from sqlalchemy.engine import Engine, create_engine
from sqlalchemy import Dialect
from sqlalchemy.dialects import mysql, postgresql, sqlite
from sqlalchemy.orm import clear_mappers, configure_mappers
from sqlalchemy.schema import MetaData


@pytest.fixture
def engine(request: FixtureRequest) -> Engine:
dialect = getattr(request, "param", None)
if dialect == "postgresql":
return create_engine("postgresql:///testdb")
elif dialect == "mysql":
return create_engine("mysql+mysqlconnector://testdb")
def dialect(request: FixtureRequest) -> Dialect:
dialect_name = getattr(request, "param", None)
if dialect_name == "postgresql":
return postgresql.dialect()
elif dialect_name == "mysql":
return mysql.mysqlconnector.dialect()
else:
return create_engine("sqlite:///:memory:")
return sqlite.dialect()


@pytest.fixture
Expand Down
6 changes: 3 additions & 3 deletions tests/test_generator_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import pytest
from _pytest.fixtures import FixtureRequest
from sqlalchemy import Dialect
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.engine import Engine
from sqlalchemy.schema import Column, ForeignKeyConstraint, MetaData, Table
from sqlalchemy.sql.expression import text
from sqlalchemy.types import INTEGER, VARCHAR
Expand All @@ -15,10 +15,10 @@

@pytest.fixture
def generator(
request: FixtureRequest, metadata: MetaData, engine: Engine
request: FixtureRequest, metadata: MetaData, dialect: Dialect
) -> CodeGenerator:
options = getattr(request, "param", [])
return DataclassGenerator(metadata, engine, options)
return DataclassGenerator(metadata, dialect, options)


def test_basic_class(generator: CodeGenerator) -> None:
Expand Down
7 changes: 3 additions & 4 deletions tests/test_generator_declarative.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

import pytest
from _pytest.fixtures import FixtureRequest
from sqlalchemy import PrimaryKeyConstraint
from sqlalchemy.engine import Engine
from sqlalchemy import Dialect, PrimaryKeyConstraint
from sqlalchemy.schema import (
CheckConstraint,
Column,
Expand All @@ -24,10 +23,10 @@

@pytest.fixture
def generator(
request: FixtureRequest, metadata: MetaData, engine: Engine
request: FixtureRequest, metadata: MetaData, dialect: Dialect
) -> CodeGenerator:
options = getattr(request, "param", [])
return DeclarativeGenerator(metadata, engine, options)
return DeclarativeGenerator(metadata, dialect, options)


def test_indexes(generator: CodeGenerator) -> None:
Expand Down
6 changes: 3 additions & 3 deletions tests/test_generator_sqlmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest
from _pytest.fixtures import FixtureRequest
from sqlalchemy.engine import Engine
from sqlalchemy import Dialect
from sqlalchemy.schema import (
CheckConstraint,
Column,
Expand All @@ -21,10 +21,10 @@

@pytest.fixture
def generator(
request: FixtureRequest, metadata: MetaData, engine: Engine
request: FixtureRequest, metadata: MetaData, dialect: Dialect
) -> CodeGenerator:
options = getattr(request, "param", [])
return SQLModelGenerator(metadata, engine, options)
return SQLModelGenerator(metadata, dialect, options)


def test_indexes(generator: CodeGenerator) -> None:
Expand Down
Loading
Loading