Skip to content

Commit f211841

Browse files
committed
Included a third-party library with LLM predictor
1 parent a434b79 commit f211841

File tree

19 files changed

+19664
-19491
lines changed

19 files changed

+19664
-19491
lines changed

CMakeLists.txt

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ option(BUILD_BENCHMARKS "Enable building of the benchmark suite." FALSE)
373373
option(BUILD_TPCE "Enable building of the TPC-E tool." FALSE)
374374
option(DISABLE_BUILTIN_EXTENSIONS "Disable linking extensions." FALSE)
375375
option(ENABLE_PREDICT "Build the ML Operator with DuckDB" FALSE)
376-
option(USE_TORCH "Build the ML Operator with DuckDB using the libtorch predictor" FALSE)
376+
option(PREDICTOR_IMPL "Build the ML Operator with DuckDB using specified predictor (torchscript, onnx, llm_api)" "")
377377
option(BUILD_PYTHON "Build the DuckDB Python extension" FALSE)
378378
option(USER_SPACE "Build the DuckDB Python in the user space" FALSE)
379379
option(FORCE_QUERY_LOG "If enabled, all queries will be logged to the specified path" OFF)
@@ -1214,19 +1214,24 @@ endif()
12141214

12151215
if (ENABLE_PREDICT)
12161216
add_definitions(-DENABLE_PREDICT)
1217-
if (USE_TORCH)
1217+
if (PREDICTOR_IMPL STREQUAL "torchscript")
12181218
if (NOT DEFINED ENV{TORCH_INSTALL_PREFIX})
12191219
message( FATAL_ERROR "TORCH_INSTALL_PREFIX environment variable is not defined with PREDICTOR_IMPL=torchscript option, CMake will exit." )
12201220
endif()
12211221
set(LIBTORCH_PATH "$ENV{TORCH_INSTALL_PREFIX}")
12221222
set(LIBTORCH_INCLUDE_DIRS "${LIBTORCH_PATH}/include")
12231223
message(STATUS "Torch Include Libs: ${LIBTORCH_INCLUDE_DIRS}")
1224-
add_definitions(-DUSE_TORCH)
1225-
else ()
1224+
add_definitions(-DPREDICTOR_IMPL=1)
1225+
endif ()
1226+
if (PREDICTOR_IMPL STREQUAL "onnx")
12261227
if (NOT DEFINED ENV{ONNX_INSTALL_PREFIX})
12271228
message( FATAL_ERROR "ONNX_INSTALL_PREFIX environment variable is not defined with PREDICTOR_IMPL=onnx option, CMake will exit." )
12281229
endif()
12291230
find_path(ONNX_RUNTIME_SESSION_INCLUDE_DIRS onnxruntime_cxx_api.h HINTS "$ENV{ONNX_INSTALL_PREFIX}/include")
1231+
add_definitions(-DPREDICTOR_IMPL=2)
1232+
endif()
1233+
if (PREDICTOR_IMPL STREQUAL "llm_api")
1234+
add_definitions(-DPREDICTOR_IMPL=3)
12301235
endif()
12311236
endif()
12321237

Makefile

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,10 +147,7 @@ ifeq (${BUILD_PYTHON}, 1)
147147
CMAKE_VARS:=${CMAKE_VARS} -DBUILD_PYTHON=1 -DDUCKDB_EXTENSION_CONFIGS="tools/pythonpkg/duckdb_extension_config.cmake"
148148
endif
149149
ifeq (${ENABLE_PREDICT}, 1)
150-
CMAKE_VARS:=${CMAKE_VARS} -DENABLE_PREDICT=1
151-
endif
152-
ifeq (${USE_TORCH}, 1)
153-
CMAKE_VARS:=${CMAKE_VARS} -DUSE_TORCH=1
150+
CMAKE_VARS:=${CMAKE_VARS} -DENABLE_PREDICT=1 -DPREDICTOR_IMPL="${PREDICTOR_IMPL}"
154151
endif
155152
ifeq (${PYTHON_USER_SPACE}, 1)
156153
CMAKE_VARS:=${CMAKE_VARS} -DUSER_SPACE=1

src/CMakeLists.txt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,12 @@ else()
111111
duckdb_yyjson)
112112

113113
if(ENABLE_PREDICT)
114-
if(USE_TORCH)
114+
# Pass the variable name rather than ${PREDICTOR_IMPL}: an empty value would
# otherwise expand to nothing and leave if() with a missing left operand.
if (PREDICTOR_IMPL STREQUAL "torchscript")
115115
set(DUCKDB_LINK_LIBS ${DUCKDB_LINK_LIBS} duckdb_torch)
116-
else()
116+
elseif (PREDICTOR_IMPL STREQUAL "onnx")
117117
set(DUCKDB_LINK_LIBS ${DUCKDB_LINK_LIBS} duckdb_onnx)
118+
elseif (PREDICTOR_IMPL STREQUAL "llm_api")
119+
set(DUCKDB_LINK_LIBS ${DUCKDB_LINK_LIBS} duckdb_llm_api)
118120
endif()
119121
endif()
120122

src/common/enum_util.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4614,6 +4614,8 @@ const char* EnumUtil::ToChars<ModelType>(ModelType value) {
46144614
return "LLM";
46154615
case ModelType::GNN:
46164616
return "GNN";
4617+
case ModelType::LLM_API:
4618+
return "LLM_API";
46174619
default:
46184620
throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value));
46194621
}
@@ -4630,6 +4632,9 @@ ModelType EnumUtil::FromString<ModelType>(const char *value) {
46304632
if (StringUtil::Equals(value, "GNN")) {
46314633
return ModelType::GNN;
46324634
}
4635+
if (StringUtil::Equals(value, "LLM_API")) {
4636+
return ModelType::LLM_API;
4637+
}
46334638
throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value));
46344639
}
46354640

src/execution/CMakeLists.txt

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,21 @@
11
if(ENABLE_PREDICT)
2-
if(USE_TORCH)
2+
if (PREDICTOR_IMPL STREQUAL "torchscript")
3+
message(
4+
STATUS "Added TorchScript include directories")
35
include_directories(${LIBTORCH_INCLUDE_DIRS})
46
include_directories(../../third_party/predictors/torchscript/include)
5-
else()
7+
endif ()
8+
if (PREDICTOR_IMPL STREQUAL "onnx")
69
message(
7-
STATUS "ORT include lib in src -> ${ONNX_RUNTIME_SESSION_INCLUDE_DIRS}")
10+
STATUS "Added ORT include lib in src -> ${ONNX_RUNTIME_SESSION_INCLUDE_DIRS}")
811
include_directories(${ONNX_RUNTIME_SESSION_INCLUDE_DIRS})
912
include_directories(../../third_party/predictors/onnx/include)
1013
endif()
14+
if (PREDICTOR_IMPL STREQUAL "llm_api")
15+
message(
16+
STATUS "Added LLM API include directories")
17+
include_directories(../../third_party/predictors/llm_api/include)
18+
endif()
1119
endif()
1220

1321
add_subdirectory(expression_executor)

src/execution/operator/projection/physical_gnn_predict.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#include <iostream>
1010
#include <map>
1111

12-
#if ENABLE_PREDICT
12+
#if defined(ENABLE_PREDICT) && PREDICTOR_IMPL == 2
1313
#include "duckdb_onnx.hpp"
1414
#endif
1515

@@ -67,7 +67,7 @@ PhysicalGNNPredict::PhysicalGNNPredict(vector<LogicalType> types_p, idx_t estima
6767
}
6868

6969
unique_ptr<Predictor> PhysicalGNNPredict::InitPredictor() const {
70-
#if defined(ENABLE_PREDICT)
70+
#if defined(ENABLE_PREDICT) && PREDICTOR_IMPL == 2
7171
return make_uniq<ONNXPredictor>();
7272
#else
7373
return nullptr;

src/execution/operator/projection/physical_predict.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
#include <iostream>
44
#include <map>
55

6-
#if defined(ENABLE_PREDICT) && defined(USE_TORCH)
6+
#if defined(ENABLE_PREDICT) && PREDICTOR_IMPL == 1
77
#include "duckdb_torch.hpp"
8-
#elif ENABLE_PREDICT
8+
#elif defined(ENABLE_PREDICT) && PREDICTOR_IMPL == 2
99
#include "duckdb_onnx.hpp"
10+
#elif defined(ENABLE_PREDICT) && PREDICTOR_IMPL == 3
11+
#include "duckdb_llm_api.hpp"
1012
#endif
1113

1214
#define CHUNK_PRED 1
@@ -45,10 +47,12 @@ PhysicalPredict::PhysicalPredict(vector<LogicalType> types_p, unique_ptr<Physica
4547
}
4648

4749
unique_ptr<Predictor> PhysicalPredict::InitPredictor() const {
48-
#if defined(ENABLE_PREDICT) && USE_TORCH
50+
#if defined(ENABLE_PREDICT) && PREDICTOR_IMPL == 1
4951
return make_uniq<TorchPredictor>();
50-
#elif defined(ENABLE_PREDICT)
52+
#elif defined(ENABLE_PREDICT) && PREDICTOR_IMPL == 2
5153
return make_uniq<ONNXPredictor>();
54+
#elif defined(ENABLE_PREDICT) && PREDICTOR_IMPL == 3
55+
return make_uniq<LlmApiPredictor>();
5256
#else
5357
return nullptr;
5458
#endif
@@ -83,6 +87,9 @@ OperatorResultType PhysicalPredict::Execute(ExecutionContext &context, DataChunk
8387
} else if (predictor.task == PREDICT_LLM_TASK) {
8488
predictor.PredictLMChunk(input, predictions, (int)input.size(), this->input_mask, (int)result_set_types.size(),
8589
state.stats);
90+
} else if (predictor.task == PREDICT_LLM_API_TASK) {
91+
predictor.PredictChunk(input, predictions, (int)input.size(), this->input_mask, (int)result_set_types.size(),
92+
state.stats);
8693
}
8794
#elif VEC_PRED
8895
std::vector<float> inputs;

src/execution/physical_plan/plan_predict.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ namespace duckdb {
99
unique_ptr<PhysicalOperator> PhysicalPlanGenerator::CreatePlan(LogicalPredict &op) {
1010
switch (op.bound_predict.model_type) {
1111
case ModelType::TABULAR:
12-
case ModelType::LLM: {
12+
case ModelType::LLM:
13+
case ModelType::LLM_API: {
1314
D_ASSERT(op.children.size() == 1);
1415
auto child_plan = CreatePlan(*op.children[0]);
1516
auto predict = make_uniq<PhysicalPredict>(std::move(op.types), std::move(child_plan));

src/include/duckdb/common/enums/model_type.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ namespace duckdb {
1717
enum class ModelType : uint8_t {
1818
TABULAR,
1919
LLM,
20-
GNN
20+
GNN,
21+
LLM_API,
2122
};
2223

2324
} // namespace duckdb

src/include/duckdb/execution/operator/projection/physical_predict.hpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,12 @@
1515
#include "duckdb/planner/expression.hpp"
1616

1717
namespace duckdb {
18-
typedef enum PredictorTask { PREDICT_TABULAR_TASK = 0, PREDICT_LLM_TASK = 1, PREDICT_GNN_TASK = 2 } PredictorTask;
18+
typedef enum PredictorTask {
19+
PREDICT_TABULAR_TASK = 0,
20+
PREDICT_LLM_TASK = 1,
21+
PREDICT_GNN_TASK = 2,
22+
PREDICT_LLM_API_TASK = 3
23+
} PredictorTask;
1924

2025
struct PredictStats {
2126
long load;

0 commit comments

Comments
 (0)