diff --git a/pyhive/presto.py b/pyhive/presto.py index a38cd891..99f96bf2 100644 --- a/pyhive/presto.py +++ b/pyhive/presto.py @@ -18,6 +18,7 @@ import datetime import logging import requests +import datetime from requests.auth import HTTPBasicAuth import os @@ -41,6 +42,20 @@ def escape_datetime(self, item, format): formatted = super(PrestoParamEscaper, self).escape_datetime(item, format, 3) return "{} {}".format(_type, formatted) + def escape_item(self, item): + if isinstance(item, datetime.datetime): + return self.escape_datetime(item) + elif isinstance(item, datetime.date): + return self.escape_date(item) + else: + return super(PrestoParamEscaper, self).escape_item(item) + + def escape_date(self, item): + return "date '{}'".format(item) + + def escape_datetime(self, item): + return "timestamp '{}'".format(item) + _escaper = PrestoParamEscaper() diff --git a/pyhive/sqlalchemy_presto.py b/pyhive/sqlalchemy_presto.py index a199ebe1..1830b932 100644 --- a/pyhive/sqlalchemy_presto.py +++ b/pyhive/sqlalchemy_presto.py @@ -12,11 +12,13 @@ from sqlalchemy import exc from sqlalchemy import types from sqlalchemy import util + # TODO shouldn't use mysql type from sqlalchemy.databases import mysql from sqlalchemy.engine import default from sqlalchemy.sql import compiler from sqlalchemy.sql.compiler import SQLCompiler +from sqlalchemy.sql.expression import Alias from pyhive import presto from pyhive.common import UniversalSet @@ -46,6 +48,37 @@ class PrestoCompiler(SQLCompiler): def visit_char_length_func(self, fn, **kw): return 'length{}'.format(self.function_argspec(fn, **kw)) + def visit_column(self, column, add_to_result_map=None, include_table=True, **kwargs): + sql = super(PrestoCompiler, self).visit_column( + column, add_to_result_map, include_table, **kwargs + ) + table = column.table + return self.__add_catalog(sql, table) + + def visit_table(self, table, asfrom=False, iscrud=False, ashint=False, + fromhints=None, use_schema=True, **kwargs): + sql = super(PrestoCompiler, self).visit_table( + table, asfrom, iscrud, ashint, fromhints, use_schema, **kwargs + ) + return self.__add_catalog(sql, table) + + def __add_catalog(self, sql, table): + if table is None: + return sql + + if isinstance(table, Alias): + return sql + + if ( + "presto" not in table.dialect_options + or "catalog" not in table.dialect_options["presto"]._non_defaults + ): + return sql + + catalog = table.dialect_options["presto"]._non_defaults["catalog"] + sql = "\"{catalog}\".{sql}".format(catalog=catalog, sql=sql) + return sql + class PrestoTypeCompiler(compiler.GenericTypeCompiler): def visit_CLOB(self, type_, **kw):