Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions src/sqlfmt/cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import List, Union
from typing import List, Optional, Union

import click

Expand Down Expand Up @@ -152,14 +152,29 @@
"case sensitivity in function, field, and alias names."
),
)
@click.option(
"--config",
"config_path",
envvar="SQLFMT_CONFIG",
type=click.Path(
exists=True, dir_okay=False, allow_dash=False, resolve_path=True, path_type=Path
),
help=(
"A path to a `pyproject.toml` file. Options passed at the command line will "
"override settings in this file."
),
)
@click.argument(
"files",
nargs=-1,
type=click.Path(exists=True, allow_dash=True, resolve_path=True, path_type=Path),
)
@click.pass_context
def sqlfmt(
ctx: click.Context, files: List[Path], **kwargs: Union[bool, int, List[str], str]
ctx: click.Context,
files: List[Path],
config_path: Optional[Path] = None,
**kwargs: Union[bool, int, List[str], str],
) -> None:
"""
sqlfmt formats your dbt SQL files so you don't have to.
Expand All @@ -173,7 +188,7 @@ def sqlfmt(
https://sqlfmt.com for documentation and more information.
"""
if files:
config = load_config_file(files)
config = load_config_file(files, config_path)
non_default_options = {
k: v
for k, v in kwargs.items()
Expand Down
11 changes: 8 additions & 3 deletions src/sqlfmt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,19 @@
Config = Dict[str, Union[bool, int, List[str], str, Path]]


def load_config_file(files: List[Path]) -> Config:
def load_config_file(files: List[Path], config_path: Optional[Path]) -> Config:
"""
files is a list of resolved, absolute paths (like the ones passed from the
Click CLI). This finds a pyproject.toml file in the common parent directory
of files (or in the common parent's parents).

If config_path is provided, searching for a config via files will be skipped
entirely.
"""
common_parents = _get_common_parents(files)
config_path = _find_config_file(common_parents)
if config_path is None:
common_parents = _get_common_parents(files)
config_path = _find_config_file(common_parents)

config = _load_config_from_path(config_path)
return config

Expand Down
64 changes: 64 additions & 0 deletions tests/unit_tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,3 +209,67 @@ def test_preformatted_inherit_encoding(
assert results.stderr.startswith("1 file had errors")
assert "006_has_bom.sql" in results.stderr
assert "Could not parse SQL at position 1" in results.stderr


def test_config_option(sqlfmt_runner: CliRunner, preformatted_dir: Path) -> None:
copy_config_file_to_dst("valid_sqlfmt_config.toml", preformatted_dir)
args = (
f"{preformatted_dir.as_posix()} "
f"--config {(preformatted_dir / 'pyproject.toml').as_posix()} "
"--check"
)
results = sqlfmt_runner.invoke(sqlfmt_main, args=args)
# 3 files should fail formatting with longer line length in config
assert results.exit_code == 1
assert results.stderr.startswith("3 files failed formatting check")

args = f"{preformatted_dir.as_posix()} --check"
results = sqlfmt_runner.invoke(
sqlfmt_main,
args=args,
env={"SQLFMT_CONFIG": f"{preformatted_dir.as_posix()}/pyproject.toml"},
)
assert results.exit_code == 1
assert results.stderr.startswith("3 files failed formatting check")

# supply CLI args to override config file so checks pass
args = (
f"{preformatted_dir.as_posix()} "
f"--config {(preformatted_dir / 'pyproject.toml').as_posix()} "
"--line-length 88 --check"
)
results = sqlfmt_runner.invoke(sqlfmt_main, args=args)
assert results.exit_code == 0

# supply CLI args to override config file so checks pass
args = f"{preformatted_dir.as_posix()} --line-length 88 --check"
results = sqlfmt_runner.invoke(
sqlfmt_main,
args=args,
env={"SQLFMT_CONFIG": f"{preformatted_dir.as_posix()}/pyproject.toml"},
)
assert results.exit_code == 0


def test_config_does_not_exist(
sqlfmt_runner: CliRunner, preformatted_dir: Path
) -> None:
# make sure sqlfmt fails fast if the passed config doesn't exist
args = (
f"--config {preformatted_dir.as_posix()}/does_not_exist.toml "
f"{preformatted_dir.as_posix()}"
)
results = sqlfmt_runner.invoke(sqlfmt_main, args=args)
assert results.exit_code == 2
assert "Error: Invalid value for '--config'" in results.stderr
assert "does not exist" in results.stderr

args = f"{preformatted_dir.as_posix()}"
results = sqlfmt_runner.invoke(
sqlfmt_main,
args=args,
env={"SQLFMT_CONFIG": f"{preformatted_dir.as_posix()}/does_not_exist.toml"},
)
assert results.exit_code == 2
assert "Error: Invalid value for '--config'" in results.stderr
assert "does not exist" in results.stderr
Loading