diff --git a/cratedb_mcp/util/sql.py b/cratedb_mcp/util/sql.py index be7092a..c0c11da 100644 --- a/cratedb_mcp/util/sql.py +++ b/cratedb_mcp/util/sql.py @@ -2,6 +2,7 @@ import logging import typing as t +import cratedb_sqlparse import sqlparse from sqlparse.tokens import Keyword @@ -40,6 +41,7 @@ class SqlStatementClassifier: expression: str permit_all: bool = False + _parsed_cratedb: t.Any = dataclasses.field(init=False, default=None) _parsed_sqlparse: t.Any = dataclasses.field(init=False, default=None) def __post_init__(self) -> None: @@ -48,6 +50,14 @@ def __post_init__(self) -> None: if self.expression: self.expression = self.expression.strip() + def parse_cratedb(self): + """ + Parse expression using `cratedb-sqlparse` library. + """ + if self._parsed_cratedb is None: + self._parsed_cratedb = cratedb_sqlparse.sqlparse(self.expression) + return self._parsed_cratedb + def parse_sqlparse(self) -> t.List[sqlparse.sql.Statement]: """ Parse expression using traditional `sqlparse` library. @@ -85,8 +95,8 @@ def operation(self) -> str: """ The SQL operation: SELECT, INSERT, UPDATE, DELETE, CREATE, etc. """ - parsed = self.parse_sqlparse() - return parsed[0].get_type().upper() + parsed = self.parse_cratedb() + return parsed[0].type.upper() @property def is_camouflage(self) -> bool: diff --git a/pyproject.toml b/pyproject.toml index 667f96c..984bd07 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dependencies = [ "attrs", "cachetools<6", "cratedb-about==0.0.4", + "cratedb-sqlparse==0.0.14", "hishel<0.2", "mcp[cli]>=1.5.0", "sqlparse<0.6",