Skip to content

Commit 4d93698

Browse files
abhinavmuk04meta-codesync[bot]
authored andcommitted
Add the dot_product UDF (facebookincubator#15971)
Summary: Pull Request resolved: facebookincubator#15971 Add a DOT_PRODUCT function to compute the dot product between two arrays or maps, enabling efficient vector operations for ML and analytics workloads. ## Function Signatures - `dot_product(array(T), array(T)) -> bigint/double` - `dot_product(map(K, V), map(K, V)) -> bigint/double` ## Key Features - **Array dot product**: Computes sum of element-wise products of corresponding elements - **Map dot product**: Multiplies values with matching keys and sums the results - **Equal length requirement**: Arrays must have the same length (throws error otherwise) - **Null handling**: Null elements are treated as zero, null arguments return null - **Type support**: int8, int16, int32, int64, float, double for arrays; integer/varchar keys with integer/double values for maps - **Overflow protection**: Uses checked arithmetic for integer operations ## Behavior Examples ```sql SELECT dot_product(ARRAY[1, 2, 3], ARRAY[4, 5, 6]); -- 32 (1*4 + 2*5 + 3*6) SELECT dot_product(ARRAY[1.0, 2.0], ARRAY[3.0, 4.0]); -- 11.0 SELECT dot_product(ARRAY[1, NULL, 3], ARRAY[4, 5, 6]); -- 22 (nulls treated as 0) SELECT dot_product(MAP(ARRAY[1, 2], ARRAY[10, 20]), MAP(ARRAY[1, 2], ARRAY[3, 4])); -- 110 ``` Reviewed By: zacw7 Differential Revision: D90476184 fbshipit-source-id: 8bffe8a70c52748d62d9fb36e039ad89291006db
1 parent 1038972 commit 4d93698

File tree

6 files changed

+831
-0
lines changed

6 files changed

+831
-0
lines changed

velox/docs/functions/presto/array.rst

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,30 @@ Array Functions
272272
SELECT l2_norm(MAP(ARRAY[1, 2], ARRAY[3.0, 4.0])); -- 5.0
273273
SELECT l2_norm(MAP(ARRAY[], ARRAY[])); -- 0.0
274274

275+
.. function:: dot_product(array(T), array(T)) -> bigint/double
276+
277+
Computes the dot product of two arrays. The dot product is the sum of element-wise
278+
products of corresponding elements. Both arrays must have the same length.
279+
If either array is null, returns null. If arrays have different lengths, throws an error.
280+
Null elements in arrays are treated as zero.
281+
Returns bigint for integer arrays, double for floating-point arrays. Empty integer arrays return 0, whereas the floating-point overloads return NaN for empty arrays. ::
282+
283+
SELECT dot_product(ARRAY[1, 2, 3], ARRAY[4, 5, 6]); -- 32 (1*4 + 2*5 + 3*6)
284+
SELECT dot_product(ARRAY[1.0, 2.0], ARRAY[3.0, 4.0]); -- 11.0 (1.0*3.0 + 2.0*4.0)
285+
SELECT dot_product(ARRAY[1, NULL, 3], ARRAY[4, 5, 6]); -- 22 (1*4 + 0*5 + 3*6)
286+
SELECT dot_product(ARRAY[], ARRAY[]); -- 0
287+
288+
.. function:: dot_product(map(K, V), map(K, V)) -> bigint/double
289+
290+
Computes the dot product of two maps. For maps, the dot product is computed by
291+
multiplying values with matching keys and summing the results. Keys present in only
292+
one map contribute zero to the result. If either map is null, returns null.
293+
Null values in maps are treated as zero.
294+
Returns bigint for integer value maps, double for floating-point value maps. ::
295+
296+
SELECT dot_product(MAP(ARRAY[1, 2], ARRAY[10, 20]), MAP(ARRAY[1, 2], ARRAY[3, 4])); -- 110 (10*3 + 20*4)
297+
SELECT dot_product(MAP(ARRAY['a', 'b'], ARRAY[1.0, 2.0]), MAP(ARRAY['a', 'c'], ARRAY[3.0, 4.0])); -- 3.0 (only 'a' matches)
298+
275299
.. function:: array_sum(array(T)) -> bigint/double
276300

277301
Returns the sum of all non-null elements of the array. If there are no non-null elements, returns 0. The behaviour is similar to the aggregation function sum().

velox/expression/fuzzer/ExpressionFuzzerTest.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,7 @@ std::unordered_set<std::string> skipFunctionsSOT = {
302302
"array_subset", // Velox-only function, not available in Presto
303303
"map_values_in_range", // Velox-only function, not available in Presto
304304
"transform_with_index", // Velox-only function, not available in Presto
305+
"dot_product", // Velox-only function, not available in Presto
305306
"remap_keys", // Velox-only function, not available in Presto
306307
"map_intersect", // Velox-only function, not available in Presto
307308
"map_keys_overlap", // Velox-only function, not available in Presto
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#pragma once
17+
18+
#include <cstdint>
19+
20+
#include <folly/CPortability.h>
21+
#include <folly/container/F14Map.h>
22+
23+
#include "velox/common/base/Exceptions.h"
24+
#include "velox/functions/Macros.h"
25+
#include "velox/functions/lib/CheckedArithmetic.h"
26+
#include "velox/type/SimpleFunctionApi.h"
27+
28+
namespace facebook::velox::functions {
29+
30+
/// Computes the dot product of two arrays.
31+
/// The dot product is the sum of element-wise products of corresponding
32+
/// elements. Both arrays must have the same length. If either array is null,
33+
/// returns null. If arrays have different lengths, throws an error.
34+
/// Null elements in arrays are treated as zero.
35+
template <typename TExec, typename T>
36+
struct DotProductFunction {
37+
VELOX_DEFINE_FUNCTION_TYPES(TExec);
38+
39+
template <typename TOutput>
40+
FOLLY_ALWAYS_INLINE bool call(
41+
TOutput& out,
42+
const arg_type<Array<T>>& array1,
43+
const arg_type<Array<T>>& array2) {
44+
const auto size1 = array1.size();
45+
const auto size2 = array2.size();
46+
47+
VELOX_USER_CHECK_EQ(
48+
size1,
49+
size2,
50+
"dot_product requires arrays of equal length, but got {} and {}",
51+
size1,
52+
size2);
53+
54+
TOutput sum = 0;
55+
for (vector_size_t i = 0; i < size1; ++i) {
56+
const auto& val1 = array1[i];
57+
const auto& val2 = array2[i];
58+
59+
if (val1.has_value() && val2.has_value()) {
60+
if constexpr (std::is_same_v<TOutput, int64_t>) {
61+
auto product = checkedMultiply<TOutput>(
62+
static_cast<TOutput>(val1.value()),
63+
static_cast<TOutput>(val2.value()));
64+
sum = checkedPlus<TOutput>(sum, product);
65+
} else {
66+
sum += static_cast<TOutput>(val1.value()) *
67+
static_cast<TOutput>(val2.value());
68+
}
69+
}
70+
}
71+
out = sum;
72+
return true;
73+
}
74+
};
75+
76+
/// Computes the dot product of two maps.
77+
/// For maps, the dot product is computed by multiplying values with matching
78+
/// keys and summing the results. Keys present in only one map contribute zero
79+
/// to the result. If either map is null, returns null.
80+
/// Null values in maps are treated as zero.
81+
template <typename TExec, typename K, typename V>
82+
struct MapDotProductFunction {
83+
VELOX_DEFINE_FUNCTION_TYPES(TExec);
84+
85+
template <typename TOutput>
86+
FOLLY_ALWAYS_INLINE bool call(
87+
TOutput& out,
88+
const arg_type<Map<K, V>>& map1,
89+
const arg_type<Map<K, V>>& map2) {
90+
TOutput sum = 0;
91+
92+
// Build a lookup map from map2 for O(1) key lookup.
93+
// This reduces complexity from O(n*m) to O(n+m).
94+
folly::F14FastMap<arg_type<K>, arg_type<V>> map2Lookup;
95+
map2Lookup.reserve(map2.size());
96+
for (const auto& [key, val] : map2) {
97+
if (val.has_value()) {
98+
map2Lookup.emplace(key, val.value());
99+
}
100+
}
101+
102+
// Iterate through map1 and look up matching keys in map2.
103+
for (const auto& [key1, val1] : map1) {
104+
if (!val1.has_value()) {
105+
continue;
106+
}
107+
108+
auto it = map2Lookup.find(key1);
109+
if (it != map2Lookup.end()) {
110+
if constexpr (std::is_same_v<TOutput, int64_t>) {
111+
auto product = checkedMultiply<TOutput>(
112+
static_cast<TOutput>(val1.value()),
113+
static_cast<TOutput>(it->second));
114+
sum = checkedPlus<TOutput>(sum, product);
115+
} else {
116+
sum += static_cast<TOutput>(val1.value()) *
117+
static_cast<TOutput>(it->second);
118+
}
119+
}
120+
}
121+
122+
out = sum;
123+
return true;
124+
}
125+
};
126+
127+
} // namespace facebook::velox::functions

velox/functions/prestosql/registration/ArrayFunctionsRegistration.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "velox/functions/prestosql/ArrayFunctions.h"
2727
#include "velox/functions/prestosql/ArraySort.h"
2828
#include "velox/functions/prestosql/ArraySubset.h"
29+
#include "velox/functions/prestosql/DotProduct.h"
2930
#include "velox/functions/prestosql/L2Norm.h"
3031
#include "velox/functions/prestosql/WidthBucketArray.h"
3132
#include "velox/functions/prestosql/types/JsonRegistration.h"
@@ -448,5 +449,66 @@ void registerArrayFunctions(const std::string& prefix) {
448449
ParameterBinder<MapL2NormFunction, int64_t, double>,
449450
double,
450451
Map<int64_t, double>>({prefix + "l2_norm"});
452+
453+
// Register dot_product for integer arrays only.
454+
// Float and double array versions already exist in
455+
// MathematicalFunctionsRegistration.cpp (DotProductArray,
456+
// DotProductFloatArray) with different semantics: they return NaN for empty
457+
// arrays to maintain compatibility with cosine_similarity and other distance
458+
// functions there. Integer versions here return 0 for empty arrays.
459+
registerFunction<
460+
ParameterBinder<DotProductFunction, int8_t>,
461+
int64_t,
462+
Array<int8_t>,
463+
Array<int8_t>>({prefix + "dot_product"});
464+
registerFunction<
465+
ParameterBinder<DotProductFunction, int16_t>,
466+
int64_t,
467+
Array<int16_t>,
468+
Array<int16_t>>({prefix + "dot_product"});
469+
registerFunction<
470+
ParameterBinder<DotProductFunction, int32_t>,
471+
int64_t,
472+
Array<int32_t>,
473+
Array<int32_t>>({prefix + "dot_product"});
474+
registerFunction<
475+
ParameterBinder<DotProductFunction, int64_t>,
476+
int64_t,
477+
Array<int64_t>,
478+
Array<int64_t>>({prefix + "dot_product"});
479+
480+
// Register dot_product for maps with integer keys
481+
registerFunction<
482+
ParameterBinder<MapDotProductFunction, int32_t, int64_t>,
483+
int64_t,
484+
Map<int32_t, int64_t>,
485+
Map<int32_t, int64_t>>({prefix + "dot_product"});
486+
registerFunction<
487+
ParameterBinder<MapDotProductFunction, int64_t, int64_t>,
488+
int64_t,
489+
Map<int64_t, int64_t>,
490+
Map<int64_t, int64_t>>({prefix + "dot_product"});
491+
registerFunction<
492+
ParameterBinder<MapDotProductFunction, int32_t, double>,
493+
double,
494+
Map<int32_t, double>,
495+
Map<int32_t, double>>({prefix + "dot_product"});
496+
registerFunction<
497+
ParameterBinder<MapDotProductFunction, int64_t, double>,
498+
double,
499+
Map<int64_t, double>,
500+
Map<int64_t, double>>({prefix + "dot_product"});
501+
502+
// Register dot_product for maps with varchar keys
503+
registerFunction<
504+
ParameterBinder<MapDotProductFunction, Varchar, int64_t>,
505+
int64_t,
506+
Map<Varchar, int64_t>,
507+
Map<Varchar, int64_t>>({prefix + "dot_product"});
508+
registerFunction<
509+
ParameterBinder<MapDotProductFunction, Varchar, double>,
510+
double,
511+
Map<Varchar, double>,
512+
Map<Varchar, double>>({prefix + "dot_product"});
451513
}
452514
} // namespace facebook::velox::functions

velox/functions/prestosql/tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ add_executable(
5151
ArraySumTest.cpp
5252
ArrayTrimTest.cpp
5353
ArrayUnionTest.cpp
54+
DotProductTest.cpp
5455
ArgTypesGeneratorTest.cpp
5556
BinaryFunctionsTest.cpp
5657
BingTileCastTest.cpp

0 commit comments

Comments
 (0)