diff --git a/velox/expression/CastExpr.cpp b/velox/expression/CastExpr.cpp index a954ea72fba..1fa4dd7e14c 100644 --- a/velox/expression/CastExpr.cpp +++ b/velox/expression/CastExpr.cpp @@ -26,6 +26,7 @@ #include "velox/expression/ScopedVarSetter.h" #include "velox/external/tzdb/time_zone.h" #include "velox/functions/lib/RowsTranslationUtil.h" +#include "velox/type/CastRegistry.h" #include "velox/type/Type.h" #include "velox/type/tz/TimeZoneMap.h" #include "velox/vector/ComplexVector.h" @@ -782,13 +783,12 @@ void CastExpr::applyPeeled( const TypePtr& toType, VectorPtr& result) { auto castFromOperator = getCastOperator(fromType); - if (castFromOperator && !castFromOperator->isSupportedToType(toType)) { - VELOX_USER_FAIL( - "Cannot cast {} to {}.", fromType->toString(), toType->toString()); - } - auto castToOperator = getCastOperator(toType); - if (castToOperator && !castToOperator->isSupportedFromType(fromType)) { + + // CastRulesRegistry is the source of truth for all custom type cast + // validation, including container types (e.g., ARRAY → JSON). + if ((castFromOperator || castToOperator) && + !CastRulesRegistry::instance().canCast(fromType, toType)) { VELOX_USER_FAIL( "Cannot cast {} to {}.", fromType->toString(), toType->toString()); } diff --git a/velox/expression/CastExpr.h b/velox/expression/CastExpr.h index 09c1c03baff..8743645313e 100644 --- a/velox/expression/CastExpr.h +++ b/velox/expression/CastExpr.h @@ -28,13 +28,19 @@ class CastOperator { public: virtual ~CastOperator() = default; - /// Determines whether the cast operator supports casting to the custom type - /// from the other type. - virtual bool isSupportedFromType(const TypePtr& other) const = 0; + /// Deprecated: cast validation is now handled entirely by + /// CastRulesRegistry. This method is no longer called by CastExpr. + /// Register cast rules via registerCastRules() instead of overriding. 
+ virtual bool isSupportedFromType(const TypePtr&) const { + return true; + } - /// Determines whether the cast operator supports casting from the custom type - /// to the other type. - virtual bool isSupportedToType(const TypePtr& other) const = 0; + /// Deprecated: cast validation is now handled entirely by + /// CastRulesRegistry. This method is no longer called by CastExpr. + /// Register cast rules via registerCastRules() instead of overriding. + virtual bool isSupportedToType(const TypePtr&) const { + return true; + } /// Casts an input vector to the custom type. This function should not throw /// when processing input rows, but report errors via context.setError(). diff --git a/velox/expression/SignatureBinder.cpp b/velox/expression/SignatureBinder.cpp index 9ea71fe2e73..b44afc659ed 100644 --- a/velox/expression/SignatureBinder.cpp +++ b/velox/expression/SignatureBinder.cpp @@ -410,7 +410,7 @@ bool SignatureBinderBase::tryBind( const auto& baseName = typeSignature.baseName(); auto typeName = boost::algorithm::to_upper_copy(baseName); if (!boost::algorithm::iequals(typeName, actualType->name())) { - if (allowCoercion) { + if (allowCoercion && typeSignature.parameters().empty()) { if (auto availableCoercion = TypeCoercer::coerceTypeBase(actualType, typeName)) { coercion = availableCoercion.value(); diff --git a/velox/functions/prestosql/types/BigintEnumRegistration.cpp b/velox/functions/prestosql/types/BigintEnumRegistration.cpp index 9cf15122631..c0ee98b36dd 100644 --- a/velox/functions/prestosql/types/BigintEnumRegistration.cpp +++ b/velox/functions/prestosql/types/BigintEnumRegistration.cpp @@ -17,6 +17,7 @@ #include "velox/functions/prestosql/types/BigintEnumRegistration.h" #include "velox/expression/CastExpr.h" #include "velox/functions/prestosql/types/BigintEnumType.h" +#include "velox/type/CastRegistry.h" namespace facebook::velox { namespace { @@ -29,17 +30,6 @@ class BigintEnumCastOperator : public exec::CastOperator { return kInstance; } - // Casting 
is supported from all integer types. - bool isSupportedFromType(const TypePtr& other) const override { - return BIGINT()->equivalent(*other) || TINYINT()->equivalent(*other) || - SMALLINT()->equivalent(*other) || INTEGER()->equivalent(*other); - } - - // Casting is only supported to BIGINT type. - bool isSupportedToType(const TypePtr& other) const override { - return BIGINT()->equivalent(*other); - } - void castTo( const BaseVector& input, exec::EvalCtx& context, @@ -133,5 +123,27 @@ class BigintEnumTypeFactory : public CustomTypeFactory { void registerBigintEnumType() { registerCustomType( "bigint_enum", std::make_unique()); + registerCastRules({ + {.fromType = "TINYINT", + .toType = "BIGINT_ENUM", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "SMALLINT", + .toType = "BIGINT_ENUM", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "INTEGER", + .toType = "BIGINT_ENUM", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "BIGINT", + .toType = "BIGINT_ENUM", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "BIGINT_ENUM", + .toType = "BIGINT", + .implicitAllowed = false, + .validator = {}}, + }); } } // namespace facebook::velox diff --git a/velox/functions/prestosql/types/BingTileRegistration.cpp b/velox/functions/prestosql/types/BingTileRegistration.cpp index e5199935633..7a2f5753bef 100644 --- a/velox/functions/prestosql/types/BingTileRegistration.cpp +++ b/velox/functions/prestosql/types/BingTileRegistration.cpp @@ -19,6 +19,7 @@ #include "velox/common/fuzzer/ConstrainedGenerators.h" #include "velox/expression/CastExpr.h" #include "velox/functions/prestosql/types/BingTileType.h" +#include "velox/type/CastRegistry.h" namespace facebook::velox { @@ -33,24 +34,6 @@ class BingTileCastOperator final : public exec::CastOperator { return {std::shared_ptr{}, &kInstance}; } - bool isSupportedFromType(const TypePtr& other) const override { - switch (other->kind()) { - case TypeKind::BIGINT: - return true; - default: - 
return false; - } - } - - bool isSupportedToType(const TypePtr& other) const override { - switch (other->kind()) { - case TypeKind::BIGINT: - return true; - default: - return false; - } - } - void castTo( const BaseVector& input, exec::EvalCtx& context, @@ -138,6 +121,16 @@ class BingTileTypeFactory : public CustomTypeFactory { void registerBingTileType() { registerCustomType("bingtile", std::make_unique()); + registerCastRules({ + {.fromType = "BIGINT", + .toType = "BINGTILE", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "BINGTILE", + .toType = "BIGINT", + .implicitAllowed = false, + .validator = {}}, + }); } } // namespace facebook::velox diff --git a/velox/functions/prestosql/types/IPAddressRegistration.cpp b/velox/functions/prestosql/types/IPAddressRegistration.cpp index 262e76ab630..ed58baad120 100644 --- a/velox/functions/prestosql/types/IPAddressRegistration.cpp +++ b/velox/functions/prestosql/types/IPAddressRegistration.cpp @@ -22,41 +22,12 @@ #include "velox/expression/CastExpr.h" #include "velox/functions/prestosql/types/IPAddressType.h" #include "velox/functions/prestosql/types/IPPrefixType.h" +#include "velox/type/CastRegistry.h" namespace facebook::velox { namespace { class IPAddressCastOperator : public exec::CastOperator { public: - bool isSupportedFromType(const TypePtr& other) const override { - switch (other->kind()) { - case TypeKind::VARBINARY: - case TypeKind::VARCHAR: - return true; - case TypeKind::ROW: - if (isIPPrefixType(other)) { - return true; - } - [[fallthrough]]; - default: - return false; - } - } - - bool isSupportedToType(const TypePtr& other) const override { - switch (other->kind()) { - case TypeKind::VARBINARY: - case TypeKind::VARCHAR: - return true; - case TypeKind::ROW: - if (isIPPrefixType(other)) { - return true; - } - [[fallthrough]]; - default: - return false; - } - } - void castTo( const BaseVector& input, exec::EvalCtx& context, @@ -274,5 +245,31 @@ class IPAddressTypeFactory : public CustomTypeFactory 
{ void registerIPAddressType() { registerCustomType( "ipaddress", std::make_unique()); + registerCastRules({ + {.fromType = "VARCHAR", + .toType = "IPADDRESS", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "VARBINARY", + .toType = "IPADDRESS", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "IPPREFIX", + .toType = "IPADDRESS", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "IPADDRESS", + .toType = "VARCHAR", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "IPADDRESS", + .toType = "VARBINARY", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "IPADDRESS", + .toType = "IPPREFIX", + .implicitAllowed = false, + .validator = {}}, + }); } } // namespace facebook::velox diff --git a/velox/functions/prestosql/types/IPPrefixRegistration.cpp b/velox/functions/prestosql/types/IPPrefixRegistration.cpp index 387566a6171..b8f593ba52e 100644 --- a/velox/functions/prestosql/types/IPPrefixRegistration.cpp +++ b/velox/functions/prestosql/types/IPPrefixRegistration.cpp @@ -21,33 +21,12 @@ #include "velox/common/fuzzer/ConstrainedGenerators.h" #include "velox/expression/CastExpr.h" #include "velox/functions/prestosql/types/IPPrefixType.h" +#include "velox/type/CastRegistry.h" namespace facebook::velox { namespace { class IPPrefixCastOperator : public exec::CastOperator { public: - bool isSupportedFromType(const TypePtr& other) const override { - switch (other->kind()) { - case TypeKind::VARCHAR: - return true; - case TypeKind::HUGEINT: - return isIPAddressType(other); - default: - return false; - } - } - - bool isSupportedToType(const TypePtr& other) const override { - switch (other->kind()) { - case TypeKind::VARCHAR: - return true; - case TypeKind::HUGEINT: - return isIPAddressType(other); - default: - return false; - } - } - void castTo( const BaseVector& input, exec::EvalCtx& context, @@ -192,5 +171,23 @@ class IPPrefixTypeFactory : public CustomTypeFactory { void registerIPPrefixType() { 
registerCustomType("ipprefix", std::make_unique()); + registerCastRules({ + {.fromType = "VARCHAR", + .toType = "IPPREFIX", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "IPADDRESS", + .toType = "IPPREFIX", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "IPPREFIX", + .toType = "VARCHAR", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "IPPREFIX", + .toType = "IPADDRESS", + .implicitAllowed = false, + .validator = {}}, + }); } } // namespace facebook::velox diff --git a/velox/functions/prestosql/types/JsonCastOperator.cpp b/velox/functions/prestosql/types/JsonCastOperator.cpp index f6193ed1275..03c3bd4a92c 100644 --- a/velox/functions/prestosql/types/JsonCastOperator.cpp +++ b/velox/functions/prestosql/types/JsonCastOperator.cpp @@ -1169,58 +1169,8 @@ simdjson::error_code castFromJsonOneRow( return simdjson::SUCCESS; } -bool isSupportedBasicType(const TypePtr& type) { - switch (type->kind()) { - case TypeKind::BOOLEAN: - case TypeKind::BIGINT: - case TypeKind::INTEGER: - case TypeKind::SMALLINT: - case TypeKind::TINYINT: - case TypeKind::DOUBLE: - case TypeKind::REAL: - case TypeKind::VARCHAR: - return true; - default: - return false; - } -} } // namespace -bool JsonCastOperator::isSupportedFromType(const TypePtr& other) const { - if (isSupportedBasicType(other)) { - return true; - } - - switch (other->kind()) { - case TypeKind::UNKNOWN: - case TypeKind::TIMESTAMP: - return true; - case TypeKind::ARRAY: - return isSupportedFromType(other->childAt(0)); - case TypeKind::ROW: - for (const auto& child : other->as().children()) { - if (!isSupportedFromType(child)) { - return false; - } - } - return true; - case TypeKind::MAP: - if (other->childAt(1)->isUnknown()) { - if (other->childAt(0)->isUnknown()) { - return true; - } - return isSupportedBasicType(other->childAt(0)) && - !isJsonType(other->childAt(0)); - } - - return ( - isSupportedBasicType(other->childAt(0)) && - isSupportedFromType(other->childAt(1))); - default: - 
return false; - } -} - template void JsonCastOperator::castFromJson( const BaseVector& input, @@ -1264,35 +1214,6 @@ void JsonCastOperator::castFromJson( writer.finish(); } -bool JsonCastOperator::isSupportedToType(const TypePtr& other) const { - if (other->isDate()) { - return false; - } - - if (isSupportedBasicType(other)) { - return true; - } - - switch (other->kind()) { - case TypeKind::ARRAY: - return isSupportedToType(other->childAt(0)); - case TypeKind::ROW: - for (const auto& child : other->as().children()) { - if (!isSupportedToType(child)) { - return false; - } - } - return true; - case TypeKind::MAP: - return ( - isSupportedBasicType(other->childAt(0)) && - isSupportedToType(other->childAt(1)) && - !isJsonType(other->childAt(0))); - default: - return false; - } -} - /// Converts an input vector of a supported type to Json type. The /// implementation follows the structure below. /// JsonOperator::castTo: type dispatch for castToJson diff --git a/velox/functions/prestosql/types/JsonCastOperator.h b/velox/functions/prestosql/types/JsonCastOperator.h index 047a4d0bca7..22ef9aab7b4 100644 --- a/velox/functions/prestosql/types/JsonCastOperator.h +++ b/velox/functions/prestosql/types/JsonCastOperator.h @@ -23,10 +23,6 @@ namespace facebook::velox { /// Custom operator for casts from and to Json type. 
class JsonCastOperator : public exec::CastOperator { public: - bool isSupportedFromType(const TypePtr& other) const override; - - bool isSupportedToType(const TypePtr& other) const override; - void castTo( const BaseVector& input, exec::EvalCtx& context, diff --git a/velox/functions/prestosql/types/JsonRegistration.cpp b/velox/functions/prestosql/types/JsonRegistration.cpp index 4952aa0360e..572b273b710 100644 --- a/velox/functions/prestosql/types/JsonRegistration.cpp +++ b/velox/functions/prestosql/types/JsonRegistration.cpp @@ -19,6 +19,7 @@ #include "velox/common/fuzzer/ConstrainedGenerators.h" #include "velox/functions/prestosql/types/JsonCastOperator.h" #include "velox/functions/prestosql/types/JsonType.h" +#include "velox/type/CastRegistry.h" #include "velox/type/Type.h" namespace facebook::velox { @@ -60,9 +61,205 @@ class JsonTypeFactory : public CustomTypeFactory { false); } }; + +// Returns true if 'type' is a primitive scalar type usable as a MAP key in +// JSON casts. JSON itself has VARCHAR kind so it passes here; callers that +// need to reject JSON keys (e.g. JSON → MAP) check isJsonType separately. +bool isValidJsonMapKey(const TypePtr& type) { + switch (type->kind()) { + case TypeKind::BOOLEAN: + case TypeKind::BIGINT: + case TypeKind::INTEGER: + case TypeKind::SMALLINT: + case TypeKind::TINYINT: + case TypeKind::DOUBLE: + case TypeKind::REAL: + case TypeKind::VARCHAR: + return true; + default: + return false; + } +} + +// Validator for DECIMAL → JSON: only short decimals (precision ≤ 18) are +// supported. Short decimals use TypeKind::BIGINT; long decimals use HUGEINT. +bool canCastDecimalToJson(const TypePtr& from, const TypePtr& /*to*/) { + return from->kind() == TypeKind::BIGINT; +} + +// Validator for ARRAY → JSON: element type must be castable to JSON. 
+bool canCastArrayToJson(const TypePtr& from, const TypePtr& to) { + return CastRulesRegistry::instance().canCast(from->childAt(0), to); +} + +// Validator for ROW → JSON: all child types must be castable to JSON. +bool canCastRowToJson(const TypePtr& from, const TypePtr& to) { + for (auto i = 0; i < from->size(); ++i) { + if (!CastRulesRegistry::instance().canCast(from->childAt(i), to)) { + return false; + } + } + return true; +} + +// Validator for MAP → JSON: key must be a valid JSON map key, value must be +// castable to JSON. Special case: MAP is allowed. +// Note: JSON keys are allowed here (JSON has VARCHAR kind and passes +// isValidJsonMapKey). The old isSupportedFromType had an inconsistency where +// JSON keys were rejected only in the UNKNOWN-value branch but allowed +// otherwise; we use the more permissive behavior consistently. +bool canCastMapToJson(const TypePtr& from, const TypePtr& to) { + const auto& keyType = from->childAt(0); + const auto& valueType = from->childAt(1); + if (keyType->isUnknown() && valueType->isUnknown()) { + return true; + } + if (!isValidJsonMapKey(keyType)) { + return false; + } + return CastRulesRegistry::instance().canCast(valueType, to); +} + +// Validator for JSON → ARRAY: element type must be castable from JSON. +bool canCastJsonToArray(const TypePtr& from, const TypePtr& to) { + return CastRulesRegistry::instance().canCast(from, to->childAt(0)); +} + +// Validator for JSON → ROW: all child types must be castable from JSON. +bool canCastJsonToRow(const TypePtr& from, const TypePtr& to) { + for (auto i = 0; i < to->size(); ++i) { + if (!CastRulesRegistry::instance().canCast(from, to->childAt(i))) { + return false; + } + } + return true; +} + +// Validator for JSON → MAP: key must be a valid JSON map key and not JSON +// itself, value must be castable from JSON. 
+bool canCastJsonToMap(const TypePtr& from, const TypePtr& to) { + const auto& keyType = to->childAt(0); + const auto& valueType = to->childAt(1); + if (!isValidJsonMapKey(keyType) || isJsonType(keyType)) { + return false; + } + return CastRulesRegistry::instance().canCast(from, valueType); +} + } // namespace void registerJsonType() { registerCustomType("json", std::make_unique()); + registerCastRules({ + // TO JSON (from primitive types). + {.fromType = "UNKNOWN", + .toType = "JSON", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "BOOLEAN", + .toType = "JSON", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "TINYINT", + .toType = "JSON", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "SMALLINT", + .toType = "JSON", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "INTEGER", + .toType = "JSON", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "BIGINT", + .toType = "JSON", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "REAL", + .toType = "JSON", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "DOUBLE", + .toType = "JSON", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "VARCHAR", + .toType = "JSON", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "TIMESTAMP", + .toType = "JSON", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "DATE", + .toType = "JSON", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "DECIMAL", + .toType = "JSON", + .implicitAllowed = false, + .validator = canCastDecimalToJson}, + // TO JSON (from container types with recursive validation). 
+ {.fromType = "ARRAY", + .toType = "JSON", + .implicitAllowed = false, + .validator = canCastArrayToJson}, + {.fromType = "ROW", + .toType = "JSON", + .implicitAllowed = false, + .validator = canCastRowToJson}, + {.fromType = "MAP", + .toType = "JSON", + .implicitAllowed = false, + .validator = canCastMapToJson}, + // FROM JSON (to primitive types). + // Note: JSON -> TIMESTAMP is not supported in Presto. + {.fromType = "JSON", + .toType = "BOOLEAN", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "JSON", + .toType = "TINYINT", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "JSON", + .toType = "SMALLINT", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "JSON", + .toType = "INTEGER", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "JSON", + .toType = "BIGINT", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "JSON", + .toType = "REAL", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "JSON", + .toType = "DOUBLE", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "JSON", + .toType = "VARCHAR", + .implicitAllowed = false, + .validator = {}}, + // FROM JSON (to container types with recursive validation). 
+ {.fromType = "JSON", + .toType = "ARRAY", + .implicitAllowed = false, + .validator = canCastJsonToArray}, + {.fromType = "JSON", + .toType = "ROW", + .implicitAllowed = false, + .validator = canCastJsonToRow}, + {.fromType = "JSON", + .toType = "MAP", + .implicitAllowed = false, + .validator = canCastJsonToMap}, + }); } } // namespace facebook::velox diff --git a/velox/functions/prestosql/types/P4HyperLogLogRegistration.cpp b/velox/functions/prestosql/types/P4HyperLogLogRegistration.cpp index 34212fcc1d5..e09f98d92b0 100644 --- a/velox/functions/prestosql/types/P4HyperLogLogRegistration.cpp +++ b/velox/functions/prestosql/types/P4HyperLogLogRegistration.cpp @@ -22,20 +22,13 @@ #include "velox/functions/prestosql/types/HyperLogLogType.h" #include "velox/functions/prestosql/types/P4HyperLogLogType.h" #include "velox/functions/prestosql/types/fuzzer_utils/P4HyperLogLogInputGenerator.h" +#include "velox/type/CastRegistry.h" namespace facebook::velox { namespace { class P4HyperLogLogCastOperator : public exec::CastOperator { public: - bool isSupportedFromType(const TypePtr& other) const override { - return other->equivalent(*HYPERLOGLOG()); - } - - bool isSupportedToType(const TypePtr& other) const override { - return other->equivalent(*HYPERLOGLOG()); - } - void castTo( const BaseVector& input, exec::EvalCtx& context, @@ -161,5 +154,15 @@ class P4HyperLogLogTypeFactory : public CustomTypeFactory { void registerP4HyperLogLogType() { registerCustomType( "p4hyperloglog", std::make_unique()); + registerCastRules({ + {.fromType = "HYPERLOGLOG", + .toType = "P4HYPERLOGLOG", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "P4HYPERLOGLOG", + .toType = "HYPERLOGLOG", + .implicitAllowed = false, + .validator = {}}, + }); } } // namespace facebook::velox diff --git a/velox/functions/prestosql/types/TimeWithTimezoneRegistration.cpp b/velox/functions/prestosql/types/TimeWithTimezoneRegistration.cpp index 05cf028bb2c..33fef253e48 100644 --- 
a/velox/functions/prestosql/types/TimeWithTimezoneRegistration.cpp +++ b/velox/functions/prestosql/types/TimeWithTimezoneRegistration.cpp @@ -158,28 +158,6 @@ class TimeWithTimeZoneCastOperator final : public exec::CastOperator { return {std::shared_ptr{}, &kInstance}; } - bool isSupportedFromType(const TypePtr& other) const override { - switch (other->kind()) { - case TypeKind::BIGINT: - return other->equivalent(*TIME()); - case TypeKind::VARCHAR: - return true; - default: - return false; - } - } - - bool isSupportedToType(const TypePtr& other) const override { - switch (other->kind()) { - case TypeKind::BIGINT: - return other->equivalent(*TIME()); - case TypeKind::VARCHAR: - return true; - default: - return false; - } - } - void castTo( const BaseVector& input, exec::EvalCtx& context, diff --git a/velox/functions/prestosql/types/TimestampWithTimeZoneRegistration.cpp b/velox/functions/prestosql/types/TimestampWithTimeZoneRegistration.cpp index 2c2017d700a..825f4d262cd 100644 --- a/velox/functions/prestosql/types/TimestampWithTimeZoneRegistration.cpp +++ b/velox/functions/prestosql/types/TimestampWithTimeZoneRegistration.cpp @@ -248,36 +248,6 @@ class TimestampWithTimeZoneCastOperator final : public exec::CastOperator { return {std::shared_ptr{}, &kInstance}; } - bool isSupportedFromType(const TypePtr& other) const override { - switch (other->kind()) { - case TypeKind::TIMESTAMP: - return true; - case TypeKind::VARCHAR: - return true; - case TypeKind::INTEGER: - return other->isDate(); - case TypeKind::BIGINT: - return other->equivalent(*TIME()); - default: - return false; - } - } - - bool isSupportedToType(const TypePtr& other) const override { - switch (other->kind()) { - case TypeKind::TIMESTAMP: - return true; - case TypeKind::VARCHAR: - return true; - case TypeKind::INTEGER: - return other->isDate(); - case TypeKind::BIGINT: - return other->equivalent(*TIME()); - default: - return false; - } - } - void castTo( const BaseVector& input, exec::EvalCtx& context, 
diff --git a/velox/functions/prestosql/types/UuidRegistration.cpp b/velox/functions/prestosql/types/UuidRegistration.cpp index 00778521156..2ca20df32eb 100644 --- a/velox/functions/prestosql/types/UuidRegistration.cpp +++ b/velox/functions/prestosql/types/UuidRegistration.cpp @@ -22,6 +22,7 @@ #include "velox/expression/CastExpr.h" #include "velox/functions/prestosql/types/UuidType.h" +#include "velox/type/CastRegistry.h" namespace facebook::velox { namespace { @@ -51,14 +52,6 @@ struct UuidParser { class UuidCastOperator : public exec::CastOperator { public: - bool isSupportedFromType(const TypePtr& other) const override { - return VARCHAR()->equivalent(*other) || VARBINARY()->equivalent(*other); - } - - bool isSupportedToType(const TypePtr& other) const override { - return VARCHAR()->equivalent(*other) || VARBINARY()->equivalent(*other); - } - void castTo( const BaseVector& input, exec::EvalCtx& context, @@ -161,5 +154,23 @@ class UuidTypeFactory : public CustomTypeFactory { void registerUuidType() { registerCustomType("uuid", std::make_unique()); + registerCastRules({ + {.fromType = "VARCHAR", + .toType = "UUID", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "VARBINARY", + .toType = "UUID", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "UUID", + .toType = "VARCHAR", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "UUID", + .toType = "VARBINARY", + .implicitAllowed = false, + .validator = {}}, + }); } } // namespace facebook::velox diff --git a/velox/functions/prestosql/types/VarcharEnumRegistration.cpp b/velox/functions/prestosql/types/VarcharEnumRegistration.cpp index 251450e0fa3..debe7485d2b 100644 --- a/velox/functions/prestosql/types/VarcharEnumRegistration.cpp +++ b/velox/functions/prestosql/types/VarcharEnumRegistration.cpp @@ -17,6 +17,7 @@ #include "velox/functions/prestosql/types/VarcharEnumRegistration.h" #include "velox/expression/CastExpr.h" #include "velox/functions/prestosql/types/VarcharEnumType.h" 
+#include "velox/type/CastRegistry.h" namespace facebook::velox { namespace { @@ -29,16 +30,6 @@ class VarcharEnumCastOperator : public exec::CastOperator { return kInstance; } - // Casting is only supported from VARCHAR type. - bool isSupportedFromType(const TypePtr& other) const override { - return VARCHAR()->equivalent(*other); - } - - // Casting is only supported to VARCHAR type. - bool isSupportedToType(const TypePtr& other) const override { - return VARCHAR()->equivalent(*other); - } - void castTo( const BaseVector& input, exec::EvalCtx& context, @@ -103,5 +94,15 @@ class VarcharEnumTypeFactory : public CustomTypeFactory { void registerVarcharEnumType() { registerCustomType( "varchar_enum", std::make_unique()); + registerCastRules({ + {.fromType = "VARCHAR", + .toType = "VARCHAR_ENUM", + .implicitAllowed = false, + .validator = {}}, + {.fromType = "VARCHAR_ENUM", + .toType = "VARCHAR", + .implicitAllowed = false, + .validator = {}}, + }); } } // namespace facebook::velox diff --git a/velox/functions/prestosql/types/tests/IPPrefixTypeTest.cpp b/velox/functions/prestosql/types/tests/IPPrefixTypeTest.cpp index 432b100cd75..a27bf777828 100644 --- a/velox/functions/prestosql/types/tests/IPPrefixTypeTest.cpp +++ b/velox/functions/prestosql/types/tests/IPPrefixTypeTest.cpp @@ -29,7 +29,6 @@ class IPPrefixTypeTest : public testing::Test, public TypeTestBase { TEST_F(IPPrefixTypeTest, basic) { ASSERT_STREQ(IPPREFIX()->name(), "IPPREFIX"); ASSERT_STREQ(IPPREFIX()->kindName(), "ROW"); - ASSERT_EQ(IPPREFIX()->name(), "IPPREFIX"); ASSERT_TRUE(IPPREFIX()->parameters().empty()); ASSERT_TRUE(hasType("IPPREFIX")); diff --git a/velox/functions/sparksql/types/TimestampNTZRegistration.cpp b/velox/functions/sparksql/types/TimestampNTZRegistration.cpp index 6dd7a764d2f..77b7b7c89af 100644 --- a/velox/functions/sparksql/types/TimestampNTZRegistration.cpp +++ b/velox/functions/sparksql/types/TimestampNTZRegistration.cpp @@ -19,6 +19,7 @@ #include "velox/expression/CastExpr.h" 
#include "velox/functions/sparksql/types/TimestampNTZCastUtil.h" #include "velox/functions/sparksql/types/TimestampNTZType.h" +#include "velox/type/CastRegistry.h" namespace facebook::velox::functions::sparksql { namespace { @@ -32,21 +33,6 @@ class TimestampNTZCastOperator final : public exec::CastOperator { return {std::shared_ptr{}, &kInstance}; } - // Returns true if casting from other type to TIMESTAMP_NTZ type is supported. - bool isSupportedFromType(const TypePtr& other) const override { - switch (other->kind()) { - case TypeKind::VARCHAR: - return true; - default: - return false; - } - } - - // Return true if casting from TIMESTAMP_NTZ type to other type is supported. - bool isSupportedToType(const TypePtr& other) const override { - return false; - } - // Casts the input vector to the TIMESTAMP_NTZ type. void castTo( const BaseVector& input, @@ -106,6 +92,12 @@ class TimestampNTZTypeFactory : public CustomTypeFactory { void registerTimestampNTZType() { registerCustomType( "timestamp_ntz", std::make_unique()); + registerCastRules({ + {.fromType = "VARCHAR", + .toType = "TIMESTAMP_NTZ", + .implicitAllowed = false, + .validator = {}}, + }); } } // namespace facebook::velox::functions::sparksql diff --git a/velox/type/CastRegistry.cpp b/velox/type/CastRegistry.cpp index 78c71a067be..309ffafecea 100644 --- a/velox/type/CastRegistry.cpp +++ b/velox/type/CastRegistry.cpp @@ -84,12 +84,10 @@ std::optional CastRulesRegistry::castCostImpl( const auto fromName = fromType->name(); const auto toName = toType->name(); - // For leaf types (primitives and custom types), look up rule directly. - if (fromType->size() == 0 && toType->size() == 0) { - auto rule = findRule(fromName, toName); - if (!rule) { - return std::nullopt; - } + // Try direct rule lookup first. This handles primitives and custom types + // including those with children like IPPREFIX (which extends RowType). 
+ auto rule = findRule(fromName, toName); + if (rule) { if (requireImplicit && !rule->implicitAllowed) { return std::nullopt; } @@ -99,9 +97,9 @@ std::optional CastRulesRegistry::castCostImpl( return requireImplicit ? rule->cost : 0; } - // For container types (ARRAY, MAP, ROW) with the same base type, - // recursively check children and sum costs. - if (fromName == toName) { + // No explicit rule. For container types (ARRAY, MAP, ROW) with the same + // base type, recursively check children and sum costs. + if (fromName == toName && fromType->size() > 0) { if (fromType->size() != toType->size()) { return std::nullopt; } @@ -138,13 +136,17 @@ void CastRulesRegistry::clear() { void registerCastRules(const std::vector& rules) { for (const auto& rule : rules) { - // At least one side must be a registered custom type with a CastOperator, - // otherwise the cast has no execution path. + // Validate that both type names are known. For custom types, at least one + // side must have a CastOperator to provide the cast execution path. + // Built-in type pairs (e.g. VARCHAR→DOUBLE) are handled by Velox's native + // cast machinery in CastExpr, so they don't need a CastOperator. 
+ bool hasCustomCast = getCustomTypeCastOperator(rule.fromType) != nullptr || + getCustomTypeCastOperator(rule.toType) != nullptr; + bool bothKnown = hasType(rule.fromType) && hasType(rule.toType); VELOX_CHECK( - getCustomTypeCastOperator(rule.fromType) != nullptr || - getCustomTypeCastOperator(rule.toType) != nullptr, - "CastRule {} -> {} requires at least one side to be a registered " - "custom type with a CastOperator", + hasCustomCast || bothKnown, + "CastRule {} -> {} requires either a CastOperator on at least one " + "side, or both types to be known built-in types", rule.fromType, rule.toType); } diff --git a/velox/type/CastRegistry.h b/velox/type/CastRegistry.h index 9289717a593..adfbc8213e6 100644 --- a/velox/type/CastRegistry.h +++ b/velox/type/CastRegistry.h @@ -92,9 +92,10 @@ class CastRulesRegistry { rules_; }; -/// Register cast rules for custom types. Call this after registerCustomType() -/// in register*Type() functions. Validates that at least one type in each rule -/// has a registered CastOperator. +/// Register cast rules. For custom types, call this after +/// registerCustomType() in register*Type() functions. Validates that either +/// at least one type has a registered CastOperator, or both types are known +/// built-in types (handled by Velox's native cast machinery in CastExpr). 
/// /// Example: /// registerCastRules({ diff --git a/velox/type/Type.cpp b/velox/type/Type.cpp index d086641e86a..b66e5c9b6f7 100644 --- a/velox/type/Type.cpp +++ b/velox/type/Type.cpp @@ -1482,6 +1482,15 @@ bool hasType(const std::string& name) { return false; } +TypePtr getScalarType(const std::string& name) { + auto it = singletonBuiltInTypes().find(name); + if (it != singletonBuiltInTypes().end()) { + return it->second; + } + + return getCustomType(name, {}); +} + TypePtr getType( const std::string& name, const std::vector& parameters) { diff --git a/velox/type/Type.h b/velox/type/Type.h index 98b7e2c395a..0627d0b6b0c 100644 --- a/velox/type/Type.h +++ b/velox/type/Type.h @@ -2233,6 +2233,12 @@ TypePtr createType(TypeKind kind, std::vector&& children); /// Returns true built-in or custom type with specified name exists. bool hasType(const std::string& name); +/// Returns a non-parametric type by name: built-in scalar singletons (BOOLEAN, +/// INTEGER, VARCHAR, etc.) and custom types (JSON, UUID, etc.). Returns nullptr +/// if the name is unknown or refers to a parametric type (ARRAY, MAP, ROW, +/// DECIMAL). May throw for parametric custom types (e.g. BIGINT_ENUM). +TypePtr getScalarType(const std::string& name); + /// Returns built-in or custom type with specified name and child types. /// Returns nullptr if type with specified name doesn't exist. TypePtr getType( diff --git a/velox/type/TypeCoercer.cpp b/velox/type/TypeCoercer.cpp index b289ea02732..b26174b24c7 100644 --- a/velox/type/TypeCoercer.cpp +++ b/velox/type/TypeCoercer.cpp @@ -32,59 +32,95 @@ int64_t Coercion::overallCost(const std::vector& coercions) { namespace { -std::unordered_map, Coercion> -allowedCoercions() { - std::unordered_map, Coercion> coercions; - - auto add = [&](const TypePtr& from, const std::vector& to) { +// Registers implicit coercion rules for built-in types in the +// CastRulesRegistry. Idempotent — safe to call from multiple overloads. 
+void registerBuiltInCoercions() { + static bool registered = false; + if (registered) { + return; + } + registered = true; + + auto add = [](const std::string& from, + const std::vector& toTypes) { + std::vector rules; int32_t cost = 0; - for (const auto& toType : to) { - coercions.emplace( - std::make_pair( - from->name(), toType->name()), - Coercion{.type = toType, .cost = ++cost}); + for (const auto& to : toTypes) { + rules.push_back( + {from, + to, + /*implicitAllowed=*/true, + /*cost=*/++cost, + /*validator=*/nullptr}); } + registerCastRules(rules); }; - add(TINYINT(), {SMALLINT(), INTEGER(), BIGINT(), REAL(), DOUBLE()}); - add(SMALLINT(), {INTEGER(), BIGINT(), REAL(), DOUBLE()}); - add(INTEGER(), {BIGINT(), REAL(), DOUBLE()}); - add(BIGINT(), {DOUBLE()}); - add(REAL(), {DOUBLE()}); - add(DATE(), {TIMESTAMP()}); - add(UNKNOWN(), - {TINYINT(), - BOOLEAN(), - SMALLINT(), - INTEGER(), - BIGINT(), - REAL(), - DOUBLE(), - VARCHAR(), - VARBINARY()}); - - return coercions; + add("TINYINT", {"SMALLINT", "INTEGER", "BIGINT", "REAL", "DOUBLE"}); + add("SMALLINT", {"INTEGER", "BIGINT", "REAL", "DOUBLE"}); + add("INTEGER", {"BIGINT", "REAL", "DOUBLE"}); + add("BIGINT", {"DOUBLE"}); + add("REAL", {"DOUBLE"}); + add("DATE", {"TIMESTAMP"}); + add("UNKNOWN", + {"TINYINT", + "BOOLEAN", + "SMALLINT", + "INTEGER", + "BIGINT", + "REAL", + "DOUBLE", + "VARCHAR", + "VARBINARY"}); } + } // namespace // static std::optional TypeCoercer::coerceTypeBase( const TypePtr& fromType, - const std::string& toTypeName) { - static const auto kAllowedCoercions = allowedCoercions(); - if (fromType->name() == toTypeName) { + const TypePtr& toType) { + static const bool kRegistered = [] { + registerBuiltInCoercions(); + return true; + }(); + (void)kRegistered; + + if (fromType->name() == toType->name()) { return Coercion{.type = fromType, .cost = 0}; } - // Check built-in coercions first. 
- auto it = kAllowedCoercions.find({fromType->name(), toTypeName}); - if (it != kAllowedCoercions.end()) { - return it->second; + // CastRulesRegistry is the single source of truth for all coercions + // (built-in and custom). + if (auto cost = CastRulesRegistry::instance().canCoerce(fromType, toType)) { + return Coercion{.type = toType, .cost = *cost}; + } + + return std::nullopt; +} + +// static +std::optional TypeCoercer::coerceTypeBase( + const TypePtr& fromType, + const std::string& toTypeName) { + static const bool kRegistered = [] { + registerBuiltInCoercions(); + return true; + }(); + (void)kRegistered; + + if (fromType->name() == toTypeName) { + return Coercion{.type = fromType, .cost = 0}; } - // Fall back to CastRulesRegistry for custom type coercions. + // Look up coercion from the CastRulesRegistry (covers both built-in and + // custom type coercions). getScalarType() resolves built-in singletons and + // non-parametric custom types; returns nullptr for parametric types (ARRAY, + // MAP, ROW) and unknown names. SignatureBinder guards against parametric + // custom types (e.g., BIGINT_ENUM) by checking + // typeSignature.parameters().empty(). 
if (fromType->size() == 0) { - auto toType = getCustomType(toTypeName, {}); + auto toType = getScalarType(toTypeName); if (toType != nullptr) { if (auto cost = CastRulesRegistry::instance().canCoerce(fromType, toType)) { @@ -108,7 +144,7 @@ std::optional TypeCoercer::coercible( } if (fromType->size() == 0) { - if (auto coercion = TypeCoercer::coerceTypeBase(fromType, toType->name())) { + if (auto coercion = TypeCoercer::coerceTypeBase(fromType, toType)) { return coercion->cost; } @@ -179,11 +215,11 @@ TypePtr TypeCoercer::leastCommonSuperType(const TypePtr& a, const TypePtr& b) { } if (a->size() == 0) { - if (TypeCoercer::coerceTypeBase(a, b->name())) { + if (TypeCoercer::coerceTypeBase(a, b)) { return b; } - if (TypeCoercer::coerceTypeBase(b, a->name())) { + if (TypeCoercer::coerceTypeBase(b, a)) { return a; } diff --git a/velox/type/TypeCoercer.h b/velox/type/TypeCoercer.h index 53052fd2e7f..ee756e8a602 100644 --- a/velox/type/TypeCoercer.h +++ b/velox/type/TypeCoercer.h @@ -73,8 +73,19 @@ struct Coercion { class TypeCoercer { public: + /// Checks if 'fromType' can be implicitly converted to 'toType'. + /// + /// Prefer this over the string overload when the target type is available, + /// as it avoids reconstructing the type from a name (which can throw for + /// parametric custom types like BigintEnum). + /// + /// @return "to" type and cost if conversion is possible. + static std::optional coerceTypeBase( + const TypePtr& fromType, + const TypePtr& toType); + /// Checks if the base of 'fromType' can be implicitly converted to a type - /// with the given name. + /// with the given name. Used by SignatureBinder which only has a type name. /// /// @return "to" type and cost if conversion is possible. 
static std::optional coerceTypeBase( diff --git a/velox/type/tests/CastRegistryTest.cpp b/velox/type/tests/CastRegistryTest.cpp index 46b9b76b265..91bc29d0760 100644 --- a/velox/type/tests/CastRegistryTest.cpp +++ b/velox/type/tests/CastRegistryTest.cpp @@ -405,9 +405,9 @@ TEST_F(CastRulesRegistryTest, standaloneRegistration) { })); } -TEST_F(CastRulesRegistryTest, rejectsRulesWithoutCastOperator) { - // The free function registerCastRules() validates that at least one side - // has a registered custom type with a CastOperator. +TEST_F(CastRulesRegistryTest, rejectsRulesWithUnknownTypes) { + // Rules where one side is unknown and neither has a CastOperator are + // rejected. EXPECT_THROW( registerCastRules({ {.fromType = "BIGINT", @@ -418,5 +418,17 @@ TEST_F(CastRulesRegistryTest, rejectsRulesWithoutCastOperator) { VeloxException); } +TEST_F(CastRulesRegistryTest, acceptsBuiltInToBuiltInRules) { + // Rules between known built-in types are accepted even without a + // CastOperator — Velox's native cast machinery handles them. + EXPECT_NO_THROW(registerCastRules({ + {.fromType = "VARCHAR", + .toType = "DOUBLE", + .implicitAllowed = true, + .cost = 5, + .validator = {}}, + })); +} + } // namespace } // namespace facebook::velox diff --git a/velox/type/tests/TypeCoercerTest.cpp b/velox/type/tests/TypeCoercerTest.cpp index 0712c202d6d..8cff410af56 100644 --- a/velox/type/tests/TypeCoercerTest.cpp +++ b/velox/type/tests/TypeCoercerTest.cpp @@ -219,9 +219,35 @@ TEST(TypeCoercerTest, leastCommonSuperType) { MAP(INTEGER(), REAL()), ROW({INTEGER(), REAL()})) == nullptr); } -TEST(TypeCoercerTest, parametricBuiltinTargetDoesNotThrow) { - // Parametric built-in factories throw on empty params. Verify graceful - // handling. +TEST(TypeCoercerTest, customRuleCostPropagation) { + // Verify that CastRulesRegistry costs propagate through TypeCoercer. 
+ // clear() wipes all rules (including built-in coercions registered once via + // function-local static), so only the custom rule below is active. + // VARCHAR→DOUBLE is not a built-in coercion, ensuring this tests the + // custom-rule path. + CastRulesRegistry::instance().clear(); + registerCastRules({ + {.fromType = "VARCHAR", + .toType = "DOUBLE", + .implicitAllowed = true, + .cost = 7, + .validator = {}}, + }); + + auto coercion = TypeCoercer::coerceTypeBase(VARCHAR(), "DOUBLE"); + ASSERT_TRUE(coercion.has_value()); + EXPECT_EQ(coercion->cost, 7); + EXPECT_EQ(*coercion->type, *DOUBLE()); + + auto cost = TypeCoercer::coercible(VARCHAR(), DOUBLE()); + ASSERT_TRUE(cost.has_value()); + EXPECT_EQ(cost.value(), 7); + + CastRulesRegistry::instance().clear(); +} + +TEST(TypeCoercerTest, parametricBuiltinTargetReturnsNullopt) { + // Parametric types (ARRAY, MAP, ROW) are not valid leaf coercion targets. EXPECT_EQ(TypeCoercer::coerceTypeBase(BIGINT(), "ARRAY"), std::nullopt); }