diff --git a/apisix/consumer.lua b/apisix/consumer.lua index 0ec39d7190ee..1f0c0855e313 100644 --- a/apisix/consumer.lua +++ b/apisix/consumer.lua @@ -290,6 +290,8 @@ local function filter(consumer) return end + plugin.set_plugins_meta_parent(consumer.value.plugins, consumer) + -- We expect the id is the same as username. Fix up it here if it isn't. consumer.value.id = consumer.value.username end diff --git a/apisix/http/service.lua b/apisix/http/service.lua index 97b224d622c8..cdf6c122981c 100644 --- a/apisix/http/service.lua +++ b/apisix/http/service.lua @@ -16,6 +16,7 @@ -- local core = require("apisix.core") local apisix_upstream = require("apisix.upstream") +local plugin = require("apisix.plugin") local plugin_checker = require("apisix.plugin").plugin_checker local services local error = error @@ -46,6 +47,8 @@ local function filter(service) return end + plugin.set_plugins_meta_parent(service.value.plugins, service) + apisix_upstream.filter_upstream(service.value.upstream, service) core.log.info("filter service: ", core.json.delay_encode(service, true)) diff --git a/apisix/plugin.lua b/apisix/plugin.lua index 342bd9680e47..64b386d870bf 100644 --- a/apisix/plugin.lua +++ b/apisix/plugin.lua @@ -37,6 +37,9 @@ local error = error -- make linter happy to avoid error: getting the Lua global "load" -- luacheck: globals load, ignore lua_load local lua_load = load +local getmetatable = getmetatable +local setmetatable = setmetatable + local is_http = ngx.config.subsystem == "http" local local_plugins_hash = core.table.new(0, 32) local stream_local_plugins = core.table.new(32, 0) @@ -1161,6 +1164,28 @@ local function run_meta_pre_function(conf, api_ctx, name) end end + +function _M.set_plugins_meta_parent(plugins, parent) + if not plugins then + return + end + for _, plugin_conf in pairs(plugins) do + if not plugin_conf._meta then + plugin_conf._meta = {} + end + if not plugin_conf._meta.parent then + local mt_table = getmetatable(plugin_conf._meta) + if mt_table then + mt_table.parent = parent + else + plugin_conf._meta = setmetatable(plugin_conf._meta, + { __index = {parent = parent} }) + end + end + end +end + + function _M.run_plugin(phase, plugins, api_ctx) local plugin_run = false api_ctx = api_ctx or ngx.ctx.api_ctx diff --git a/apisix/plugins/ai-proxy-multi.lua b/apisix/plugins/ai-proxy-multi.lua index 7ac8bb206137..26ec943c253c 100644 --- a/apisix/plugins/ai-proxy-multi.lua +++ b/apisix/plugins/ai-proxy-multi.lua @@ -19,6 +19,12 @@ local core = require("apisix.core") local schema = require("apisix.plugins.ai-proxy.schema") local base = require("apisix.plugins.ai-proxy.base") local plugin = require("apisix.plugin") +local ipmatcher = require("resty.ipmatcher") +local events = require("apisix.events") + +local tonumber = tonumber +local pairs = pairs +local tostring = tostring local require = require local pcall = pcall @@ -26,6 +32,7 @@ local ipairs = ipairs local type = type local priority_balancer = require("apisix.balancer.priority") +local healthcheck local pickers = {} local lrucache_server_picker = core.lrucache.new({ @@ -119,16 +126,197 @@ local function transform_instances(new_instances, instance) end -local function create_server_picker(conf, ups_tab) +local function parse_domain_for_node(node) + local host = node.domain or node.host + if not ipmatcher.parse_ipv4(host) + and not ipmatcher.parse_ipv6(host) + then + node.domain = host + + local ip, err = core.resolver.parse_domain(host) + if ip then + node.host = ip + end + + if err then + core.log.error("dns resolver domain: ", 
host, " error: ", err) + end + end +end + + +local function resolve_endpoint(instance_conf) + local endpoint = core.table.try_read_attr(instance_conf, "override", "endpoint") + local scheme, host, port, _ = endpoint:match("^(https?)://([^:/]+):?(%d*)(/?.*)$") + if port == "" then + port = (scheme == "https") and "443" or "80" + end + local node = { + host = host, + port = tonumber(port), + scheme = scheme, + } + parse_domain_for_node(node) + return node +end + + +local function get_healthchecker_name(conf, instance_name) + return core.table.concat({plugin_name, tostring(conf), instance_name}, "#") +end + + +local function release_checkers(healthcheck_parent) + local ai_checkers = healthcheck_parent.ai_checkers + core.log.info("try to release ai_checkers: ", tostring(ai_checkers)) + for _, checker in pairs(ai_checkers) do + checker:clear() + checker:stop() + end +end + + +local function get_checkers_status_ver(checkers) + local status_ver_total = 0 + for _, checker in pairs(checkers) do + status_ver_total = status_ver_total + checker.status_ver + end + return status_ver_total +end + + +local function create_checkers(conf) + if healthcheck == nil then + healthcheck = require("resty.healthcheck") + end + + local healthcheck_parent = conf._meta.parent + if healthcheck_parent.ai_checkers and healthcheck_parent.ai_checker_conf == conf then + return healthcheck_parent.ai_checkers + end + + if conf.is_creating_ai_checkers then + core.log.info("another request is creating new checker") + return nil + end + conf.is_creating_ai_checkers = true + + local ai_checkers = core.table.new(0, #conf.instances) + + for _, ins in ipairs(conf.instances) do + if ins.checks then + core.log.info("create new healthcheck instance for ai_instance: ", ins.name, + " checks: ", core.json.delay_encode(ins.checks, true)) + local checker, err = healthcheck.new({ + name = get_healthchecker_name(conf, ins.name), + shm_name = "upstream-healthcheck", + checks = ins.checks, + events_module = events:get_healthcheck_events_modele(), + }) + if not checker then + core.log.error("failed to create healthcheck instance: ", err) + conf.is_creating_ai_checkers = nil + return nil + end + ai_checkers[ins.name] = checker + end + end + + if healthcheck_parent.ai_checkers then + local ok, err = pcall(core.config_util.cancel_clean_handler, healthcheck_parent, + healthcheck_parent.ai_checkers_idx, true) + if not ok then + core.log.error("cancel clean handler error: ", err) + end + end + + for _, ins in ipairs(conf.instances) do + local node = resolve_endpoint(ins) + local host = ins.checks and ins.checks.active and ins.checks.active.host + local port = ins.checks and ins.checks.active and ins.checks.active.port + local checker = ai_checkers[ins.name] + if checker then + local ok, err = checker:add_target(node.host, port or node.port, host) + if not ok then + core.log.error("failed to add new health check target: ", node.host, ":", + port or node.port, " err: ", err) + end + end + end + + healthcheck_parent.clean_handlers = healthcheck_parent.clean_handlers or {} + local check_idx, err = core.config_util.add_clean_handler(healthcheck_parent, release_checkers) + if not check_idx then + conf.is_creating_ai_checkers = nil + for _, checker in pairs(ai_checkers) do + checker:clear() + checker:stop() + end + core.log.error("failed to add clean handler, err:", + err, " healthcheck parent:", core.json.delay_encode(healthcheck_parent, true)) + + return nil + end + + healthcheck_parent.ai_checkers = ai_checkers + healthcheck_parent.ai_checkers_idx = 
check_idx + healthcheck_parent.ai_checker_conf = conf + + conf.is_creating_ai_checkers = nil + + return ai_checkers +end + + +local function fetch_health_instances(conf, checkers) + local instances = conf.instances + local new_instances = core.table.new(0, #instances) + if not checkers then + for _, ins in ipairs(conf.instances) do + transform_instances(new_instances, ins) + end + return new_instances + end + + for _, ins in ipairs(instances) do + local checker = checkers[ins.name] + if checker then + local host = ins.checks and ins.checks.active and ins.checks.active.host + local port = ins.checks and ins.checks.active and ins.checks.active.port + + local node = resolve_endpoint(ins) + local ok, err = checker:get_target_status(node.host, port or node.port, host) + if ok then + transform_instances(new_instances, ins) + elseif err then + core.log.error("failed to get health check target status, addr: ", + node.host, ":", port or node.port, ", host: ", host, ", err: ", err) + end + else + transform_instances(new_instances, ins) + end + end + + if core.table.nkeys(new_instances) == 0 then + core.log.warn("all upstream nodes is unhealthy, use default") + for _, ins in ipairs(instances) do + transform_instances(new_instances, ins) + end + end + + return new_instances +end + + +local function create_server_picker(conf, ups_tab, checkers) local picker = pickers[conf.balancer.algorithm] -- nil check if not picker then pickers[conf.balancer.algorithm] = require("apisix.balancer." .. conf.balancer.algorithm) picker = pickers[conf.balancer.algorithm] end - local new_instances = {} - for _, ins in ipairs(conf.instances) do - transform_instances(new_instances, ins) - end + + local new_instances = fetch_health_instances(conf, checkers) + core.log.info("fetch health instances: ", core.json.delay_encode(new_instances)) if #new_instances._priority_index > 1 then core.log.info("new instances: ", core.json.delay_encode(new_instances)) @@ -150,10 +338,18 @@ end local function pick_target(ctx, conf, ups_tab) + local checkers = #conf.instances > 1 and create_checkers(conf) + + local version = plugin.conf_version(conf) + if checkers then + local status_ver = get_checkers_status_ver(checkers) + version = version .. "#" .. status_ver + end + local server_picker = ctx.server_picker if not server_picker then - server_picker = lrucache_server_picker(ctx.matched_route.key, plugin.conf_version(conf), - create_server_picker, conf, ups_tab) + server_picker = lrucache_server_picker(ctx.matched_route.key, version, + create_server_picker, conf, ups_tab, checkers) end if not server_picker then return nil, nil, "failed to fetch server picker" diff --git a/apisix/plugins/ai-proxy/schema.lua b/apisix/plugins/ai-proxy/schema.lua index 1b9d07b1ccdc..007672f7d935 100644 --- a/apisix/plugins/ai-proxy/schema.lua +++ b/apisix/plugins/ai-proxy/schema.lua @@ -14,6 +14,8 @@ -- See the License for the specific language governing permissions and -- limitations under the License. 
-- +local schema_def = require("apisix.schema_def") + local _M = {} local auth_item_schema = { @@ -120,6 +122,13 @@ local ai_instance_schema = { }, }, }, + checks = { + type = "object", + properties = { + active = schema_def.health_checker_active, + }, + required = {"active"} + }, required = {"name", "provider", "auth", "weight"} }, } diff --git a/apisix/router.lua b/apisix/router.lua index 93b123e5b004..a9013a9ddf19 100644 --- a/apisix/router.lua +++ b/apisix/router.lua @@ -18,6 +18,8 @@ local require = require local http_route = require("apisix.http.route") local apisix_upstream = require("apisix.upstream") local core = require("apisix.core") +local set_plugins_meta_parent = require("apisix.plugin").set_plugins_meta_parent + local str_lower = string.lower local ipairs = ipairs @@ -33,6 +35,8 @@ local function filter(route) return end + set_plugins_meta_parent(route.value.plugins, route) + if route.value.host then route.value.host = str_lower(route.value.host) elseif route.value.hosts then diff --git a/apisix/schema_def.lua b/apisix/schema_def.lua index fd773990d2c1..4f21600c4f76 100644 --- a/apisix/schema_def.lua +++ b/apisix/schema_def.lua @@ -126,171 +126,179 @@ local timeout_def = { } -local health_checker = { +local health_checker_active = { type = "object", properties = { - active = { + type = { + type = "string", + enum = {"http", "https", "tcp"}, + default = "http" + }, + timeout = {type = "number", default = 1}, + concurrency = {type = "integer", default = 10}, + host = host_def, + port = { + type = "integer", + minimum = 1, + maximum = 65535 + }, + http_path = {type = "string", default = "/"}, + https_verify_certificate = {type = "boolean", default = true}, + healthy = { type = "object", properties = { - type = { - type = "string", - enum = {"http", "https", "tcp"}, - default = "http" + interval = {type = "integer", minimum = 1, default = 1}, + http_statuses = { + type = "array", + minItems = 1, + items = { + type = "integer", + minimum = 200, + maximum = 599 + }, + uniqueItems = true, + default = {200, 302} + }, + successes = { + type = "integer", + minimum = 1, + maximum = 254, + default = 2 + } + } + }, + unhealthy = { + type = "object", + properties = { + interval = {type = "integer", minimum = 1, default = 1}, + http_statuses = { + type = "array", + minItems = 1, + items = { + type = "integer", + minimum = 200, + maximum = 599 + }, + uniqueItems = true, + default = {429, 404, 500, 501, 502, 503, 504, 505} }, - timeout = {type = "number", default = 1}, - concurrency = {type = "integer", default = 10}, - host = host_def, - port = { + http_failures = { type = "integer", minimum = 1, - maximum = 65535 + maximum = 254, + default = 5 }, - http_path = {type = "string", default = "/"}, - https_verify_certificate = {type = "boolean", default = true}, - healthy = { - type = "object", - properties = { - interval = {type = "integer", minimum = 1, default = 1}, - http_statuses = { - type = "array", - minItems = 1, - items = { - type = "integer", - minimum = 200, - maximum = 599 - }, - uniqueItems = true, - default = {200, 302} - }, - successes = { - type = "integer", - minimum = 1, - maximum = 254, - default = 2 - } - } + tcp_failures = { + type = "integer", + minimum = 1, + maximum = 254, + default = 2 }, - unhealthy = { - type = "object", - properties = { - interval = {type = "integer", minimum = 1, default = 1}, - http_statuses = { - type = "array", - minItems = 1, - items = { - type = "integer", - minimum = 200, - maximum = 599 - }, - uniqueItems = true, - default = {429, 404, 500, 
501, 502, 503, 504, 505} - }, - http_failures = { - type = "integer", - minimum = 1, - maximum = 254, - default = 5 - }, - tcp_failures = { - type = "integer", - minimum = 1, - maximum = 254, - default = 2 - }, - timeouts = { - type = "integer", - minimum = 1, - maximum = 254, - default = 3 - } - } + timeouts = { + type = "integer", + minimum = 1, + maximum = 254, + default = 3 + } + } + }, + req_headers = { + type = "array", + minItems = 1, + items = { + type = "string", + uniqueItems = true, + }, + } + } +} +_M.health_checker_active = health_checker_active + + +local health_checker_passive = { + type = "object", + properties = { + type = { + type = "string", + enum = {"http", "https", "tcp"}, + default = "http" + }, + healthy = { + type = "object", + properties = { + http_statuses = { + type = "array", + minItems = 1, + items = { + type = "integer", + minimum = 200, + maximum = 599, + }, + uniqueItems = true, + default = {200, 201, 202, 203, 204, 205, 206, 207, + 208, 226, 300, 301, 302, 303, 304, 305, + 306, 307, 308} }, - req_headers = { - type = "array", - minItems = 1, - items = { - type = "string", - uniqueItems = true, - }, + successes = { + type = "integer", + minimum = 0, + maximum = 254, + default = 5 } } }, - passive = { + unhealthy = { type = "object", properties = { - type = { - type = "string", - enum = {"http", "https", "tcp"}, - default = "http" + http_statuses = { + type = "array", + minItems = 1, + items = { + type = "integer", + minimum = 200, + maximum = 599, + }, + uniqueItems = true, + default = {429, 500, 503} }, - healthy = { - type = "object", - properties = { - http_statuses = { - type = "array", - minItems = 1, - items = { - type = "integer", - minimum = 200, - maximum = 599, - }, - uniqueItems = true, - default = {200, 201, 202, 203, 204, 205, 206, 207, - 208, 226, 300, 301, 302, 303, 304, 305, - 306, 307, 308} - }, - successes = { - type = "integer", - minimum = 0, - maximum = 254, - default = 5 - } - } + tcp_failures = { + type = "integer", + minimum = 0, + maximum = 254, + default = 2 }, - unhealthy = { - type = "object", - properties = { - http_statuses = { - type = "array", - minItems = 1, - items = { - type = "integer", - minimum = 200, - maximum = 599, - }, - uniqueItems = true, - default = {429, 500, 503} - }, - tcp_failures = { - type = "integer", - minimum = 0, - maximum = 254, - default = 2 - }, - timeouts = { - type = "integer", - minimum = 0, - maximum = 254, - default = 7 - }, - http_failures = { - type = "integer", - minimum = 0, - maximum = 254, - default = 5 - }, - } - } - }, + timeouts = { + type = "integer", + minimum = 0, + maximum = 254, + default = 7 + }, + http_failures = { + type = "integer", + minimum = 0, + maximum = 254, + default = 5 + }, + } } }, +} +_M.health_checker_passive = health_checker_passive + + +local health_checker = { + type = "object", + properties = { + active = health_checker_active, + passive = health_checker_passive, + }, anyOf = { {required = {"active"}}, {required = {"active", "passive"}}, }, additionalProperties = false, } - +_M.health_checker = health_checker local nodes_schema = { anyOf = { diff --git a/t/APISIX.pm b/t/APISIX.pm index f4b9b8055d4a..8bddd24a9596 100644 --- a/t/APISIX.pm +++ b/t/APISIX.pm @@ -585,6 +585,7 @@ _EOC_ lua_shared_dict xds-config 1m; lua_shared_dict xds-config-version 1m; lua_shared_dict cas_sessions 10m; + lua_shared_dict test 5m; proxy_ssl_name \$upstream_host; proxy_ssl_server_name on; diff --git a/t/control/services.t b/t/control/services.t index 0003bcc9d1aa..257e6da56c45 100644 
--- a/t/control/services.t +++ b/t/control/services.t @@ -157,7 +157,7 @@ services: } } --- response_body -{"id":"5","plugins":{"limit-count":{"allow_degradation":false,"count":2,"key":"remote_addr","key_type":"var","policy":"local","rejected_code":503,"show_limit_quota_header":true,"time_window":60}},"upstream":{"hash_on":"vars","nodes":[{"host":"127.0.0.1","port":1980,"weight":1}],"pass_host":"pass","scheme":"http","type":"roundrobin"}} +{"id":"5","plugins":{"limit-count":{"_meta":{},"allow_degradation":false,"count":2,"key":"remote_addr","key_type":"var","policy":"local","rejected_code":503,"show_limit_quota_header":true,"sync_interval":-1,"time_window":60}},"upstream":{"hash_on":"vars","nodes":[{"host":"127.0.0.1","port":1980,"weight":1}],"pass_host":"pass","scheme":"http","type":"roundrobin"}} diff --git a/t/plugin/ai-proxy-multi3.t b/t/plugin/ai-proxy-multi3.t new file mode 100644 index 000000000000..3ace9d798018 --- /dev/null +++ b/t/plugin/ai-proxy-multi3.t @@ -0,0 +1,877 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +use t::APISIX 'no_plan'; + +log_level("info"); +repeat_each(1); +no_long_string(); +no_root_location(); + + +add_block_preprocessor(sub { + my ($block) = @_; + + if (!defined $block->request) { + $block->set_value("request", "GET /t"); + } + + my $http_config = $block->http_config // <<_EOC_; + server { + server_name openai; + listen 16724; + + default_type 'application/json'; + + location /anything { + content_by_lua_block { + local json = require("cjson.safe") + + if ngx.req.get_method() ~= "POST" then + ngx.status = 400 + ngx.say("Unsupported request method: ", ngx.req.get_method()) + end + ngx.req.read_body() + local body = ngx.req.get_body_data() + + if body ~= "SELECT * FROM STUDENTS" then + ngx.status = 503 + ngx.say("passthrough doesn't work") + return + end + ngx.say('{"foo", "bar"}') + } + } + + location /v1/chat/completions { + content_by_lua_block { + local json = require("cjson.safe") + + if ngx.req.get_method() ~= "POST" then + ngx.status = 400 + ngx.say("Unsupported request method: ", ngx.req.get_method()) + end + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + local test_type = ngx.req.get_headers()["test-type"] + if test_type == "options" then + if body.foo == "bar" then + ngx.status = 200 + ngx.say("options works") + else + ngx.status = 500 + ngx.say("model options feature doesn't work") + end + return + end + + local header_auth = ngx.req.get_headers()["authorization"] + local query_auth = ngx.req.get_uri_args()["apikey"] + + if header_auth ~= "Bearer token" and query_auth ~= "apikey" then + ngx.status = 401 + ngx.say("Unauthorized") + return + end + + if header_auth == "Bearer token" or query_auth == "apikey" then + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + if not body.messages or #body.messages < 1 then + ngx.status = 400 + ngx.say([[{ "error": "bad request"}]]) + return + end + + if body.messages[1].content == "write an SQL query to get all rows from student table" then + ngx.print("SELECT * FROM STUDENTS") + return + end + + ngx.status = 200 + ngx.say(string.format([[ +{ + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { "content": "1 + 1 = 2.", "role": "assistant" } + } + ], + "created": 1723780938, + "id": "chatcmpl-9wiSIg5LYrrpxwsr2PubSQnbtod1P", + "model": "%s", + "object": "chat.completion", + "system_fingerprint": "fp_abc28019ad", + "usage": { "completion_tokens": 5, "prompt_tokens": 8, "total_tokens": 10 } +} + ]], body.model)) + return + end + + + ngx.status = 503 + ngx.say("reached the end of the test suite") + } + } + + location /random { + content_by_lua_block { + ngx.say("path override works") + } + } + + location ~ ^/status.* { + content_by_lua_block { + local test_dict = ngx.shared["test"] + local uri = ngx.var.uri + local total_key = uri .. "#total" + local count_key = uri .. 
"#count" + local total = test_dict:get(total_key) + if not total then + return + end + + local count = test_dict:incr(count_key, 1, 0) + ngx.log(ngx.INFO, "uri: ", uri, " total: ", total, " count: ", count) + if count < total then + return + end + ngx.status = 500 + ngx.say("error") + } + } + + location /error { + content_by_lua_block { + ngx.status = 500 + ngx.say("error") + } + } + } +_EOC_ + + $block->set_value("http_config", $http_config); +}); + +run_tests(); + +__DATA__ + +=== TEST 1: set route, only one instance has checker +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/ai", + "plugins": { + "ai-proxy-multi": { + "fallback_strategy": "instance_health_and_rate_limiting", + "instances": [ + { + "name": "openai-gpt4", + "provider": "openai", + "weight": 1, + "priority": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "model": "gpt-4" + }, + "override": { + "endpoint": "http://localhost:16724" + }, + "checks": { + "active": { + "timeout": 5, + "http_path": "/status/gpt4", + "host": "foo.com", + "healthy": { + "interval": 1, + "successes": 1 + }, + "unhealthy": { + "interval": 1, + "http_failures": 1 + }, + "req_headers": ["User-Agent: curl/7.29.0"] + } + } + }, + { + "name": "openai-gpt3", + "provider": "openai", + "weight": 1, + "priority": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "model": "gpt-3" + }, + "override": { + "endpoint": "http://localhost:16724" + } + } + ], + "ssl_verify": false + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 2: once instance changes from unhealthy to healthy +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local core = require("apisix.core") + local test_dict = ngx.shared["test"] + + local send_request = function() + local code, _, body = t("/ai", + ngx.HTTP_POST, + [[{ + "messages": [ + { "role": "system", "content": "You are a mathematician" }, + { "role": "user", "content": "What is 1+1?" 
} + ] + }]], + nil, + { + ["test-type"] = "options", + ["Content-Type"] = "application/json", + } + ) + assert(code == 200, "request should be successful") + return body + end + + -- set the instance to unhealthy + test_dict:set("/status/gpt4#total", 0) + -- trigger the health check + send_request() + ngx.sleep(1) + + local instances_count = { + ["gpt-4"] = 0, + ["gpt-3"] = 0, + } + for i = 1, 10 do + local resp = send_request() + if core.string.find(resp, "gpt-4") then + instances_count["gpt-4"] = instances_count["gpt-4"] + 1 + else + instances_count["gpt-3"] = instances_count["gpt-3"] + 1 + end + end + + ngx.log(ngx.INFO, "instances_count test:", core.json.delay_encode(instances_count)) + assert(instances_count["gpt-4"] <= 2, "gpt-4 should be unhealthy") + assert(instances_count["gpt-3"] >= 8, "gpt-3 should be healthy") + + -- set the instance to healthy + test_dict:set("/status/gpt4#total", 30) + ngx.sleep(1) + + local instances_count = { + ["gpt-4"] = 0, + ["gpt-3"] = 0, + } + for i = 1, 10 do + local resp = send_request() + if core.string.find(resp, "gpt-4") then + instances_count["gpt-4"] = instances_count["gpt-4"] + 1 + else + instances_count["gpt-3"] = instances_count["gpt-3"] + 1 + end + end + ngx.log(ngx.INFO, "instances_count test:", core.json.delay_encode(instances_count)) + + local v = instances_count["gpt-4"] - instances_count["gpt-3"] + assert(v <= 2, "difference between gpt-4 and gpt-3 should be less than 2") + ngx.say("passed") + } + } +--- timeout: 10 +--- response_body +passed + + + +=== TEST 3: set service, only one instance has checker +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/services/1', + ngx.HTTP_PUT, + [[{ + "plugins": { + "ai-proxy-multi": { + "fallback_strategy": "instance_health_and_rate_limiting", + "instances": [ + { + "name": "openai-gpt4", + "provider": "openai", + "weight": 1, + "priority": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "model": "gpt-4" + }, + "override": { + "endpoint": "http://localhost:16724" + }, + "checks": { + "active": { + "timeout": 5, + "http_path": "/status/gpt4", + "host": "foo.com", + "healthy": { + "interval": 1, + "successes": 1 + }, + "unhealthy": { + "interval": 1, + "http_failures": 1 + }, + "req_headers": ["User-Agent: curl/7.29.0"] + } + } + }, + { + "name": "openai-gpt3", + "provider": "openai", + "weight": 1, + "priority": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "model": "gpt-3" + }, + "override": { + "endpoint": "http://localhost:16724" + } + } + ], + "ssl_verify": false + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 4: set route 1 related to service 1 +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/ai", + "service_id": 1 + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 5: instance changes from unhealthy to healthy +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local core = require("apisix.core") + local test_dict = ngx.shared["test"] + + local send_request = function() + local code, _, body = t("/ai", + ngx.HTTP_POST, + [[{ + "messages": [ + { "role": "system", "content": "You are a mathematician" 
}, + { "role": "user", "content": "What is 1+1?" } + ] + }]], + nil, + { + ["test-type"] = "options", + ["Content-Type"] = "application/json", + } + ) + assert(code == 200, "request should be successful") + return body + end + + -- set the instance to unhealthy + test_dict:set("/status/gpt4#total", 0) + -- trigger the health check + send_request() + ngx.sleep(1.2) + + local instances_count = { + ["gpt-4"] = 0, + ["gpt-3"] = 0, + } + for i = 1, 10 do + local resp = send_request() + if core.string.find(resp, "gpt-4") then + instances_count["gpt-4"] = instances_count["gpt-4"] + 1 + else + instances_count["gpt-3"] = instances_count["gpt-3"] + 1 + end + end + + ngx.log(ngx.INFO, "instances_count test:", core.json.delay_encode(instances_count)) + assert(instances_count["gpt-4"] <= 2, "gpt-4 should be unhealthy") + assert(instances_count["gpt-3"] >= 8, "gpt-3 should be healthy") + + -- set the instance to healthy + test_dict:set("/status/gpt4#total", 30) + ngx.sleep(1.2) + + local instances_count = { + ["gpt-4"] = 0, + ["gpt-3"] = 0, + } + for i = 1, 10 do + local resp = send_request() + if core.string.find(resp, "gpt-4") then + instances_count["gpt-4"] = instances_count["gpt-4"] + 1 + else + instances_count["gpt-3"] = instances_count["gpt-3"] + 1 + end + end + ngx.log(ngx.INFO, "instances_count test:", core.json.delay_encode(instances_count)) + + local diff = instances_count["gpt-4"] - instances_count["gpt-3"] + assert(diff <= 2, "difference between gpt-4 and gpt-3 should be less than 2") + ngx.say("passed") + } + } +--- timeout: 10 +--- response_body +passed + + + +=== TEST 6: set route, two instances have checker +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local checks_tmp = [[ + "checks": { + "active": { + "timeout": 5, + "http_path": "/status/%s", + "host": "foo.com", + "healthy": { + "interval": 1, + "successes": 1 + }, + "unhealthy": { + "interval": 1, + "http_failures": 1 + }, + "req_headers": ["User-Agent: curl/7.29.0"] + } + } + ]] + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/ai", + "plugins": { + "ai-proxy-multi": { + "fallback_strategy": "instance_health_and_rate_limiting", + "instances": [ + { + "name": "openai-gpt4", + "provider": "openai", + "weight": 1, + "priority": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "model": "gpt-4" + }, + "override": { + "endpoint": "http://localhost:16724" + }, + ]] .. string.format(checks_tmp, "gpt4").. [[ + }, + { + "name": "openai-gpt3", + "provider": "openai", + "weight": 1, + "priority": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "model": "gpt-3" + }, + "override": { + "endpoint": "http://localhost:16724" + }, + ]] .. string.format(checks_tmp, "gpt3") .. [[ + } + ], + "ssl_verify": false + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 7: healthy conversion of two instances +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local core = require("apisix.core") + local test_dict = ngx.shared["test"] + + local send_request = function() + local code, _, body = t("/ai", + ngx.HTTP_POST, + [[{ + "messages": [ + { "role": "system", "content": "You are a mathematician" }, + { "role": "user", "content": "What is 1+1?" 
} + ] + }]], + nil, + { + ["test-type"] = "options", + ["Content-Type"] = "application/json", + } + ) + assert(code == 200, "request should be successful") + return body + end + + -- set the gpt4 instance to unhealthy + -- set the gpt3 instance to healthy + test_dict:set("/status/gpt4#total", 0) + test_dict:set("/status/gpt3#total", 50) + -- trigger the health check + send_request() + ngx.sleep(1.2) + + local instances_count = { + ["gpt-4"] = 0, + ["gpt-3"] = 0, + } + for i = 1, 10 do + local resp = send_request() + if core.string.find(resp, "gpt-4") then + instances_count["gpt-4"] = instances_count["gpt-4"] + 1 + else + instances_count["gpt-3"] = instances_count["gpt-3"] + 1 + end + end + + ngx.log(ngx.INFO, "instances_count test:", core.json.delay_encode(instances_count)) + assert(instances_count["gpt-4"] <= 2, "gpt-4 should be unhealthy") + assert(instances_count["gpt-3"] >= 8, "gpt-3 should be healthy") + + -- set the gpt4 instance to healthy + -- set the gpt3 instance to unhealthy + test_dict:set("/status/gpt4#total", 50) + test_dict:set("/status/gpt3#total", 0) + ngx.sleep(1.2) + + local instances_count = { + ["gpt-4"] = 0, + ["gpt-3"] = 0, + } + for i = 1, 10 do + local resp = send_request() + if core.string.find(resp, "gpt-4") then + instances_count["gpt-4"] = instances_count["gpt-4"] + 1 + else + instances_count["gpt-3"] = instances_count["gpt-3"] + 1 + end + end + ngx.log(ngx.INFO, "instances_count test:", core.json.delay_encode(instances_count)) + + assert(instances_count["gpt-4"] >= 8, "gpt-4 should be healthy") + assert(instances_count["gpt-3"] <= 2, "gpt-3 should be unhealthy") + ngx.say("passed") + } + } +--- timeout: 10 +--- response_body +passed + + + +=== TEST 8: set route, two instances have checker +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local checks_tmp = [[ + "checks": { + "active": { + "timeout": 5, + "http_path": "/status/%s", + "host": "foo.com", + "healthy": { + "interval": 1, + "successes": 1 + }, + "unhealthy": { + "interval": 1, + "http_failures": 1 + }, + "req_headers": ["User-Agent: curl/7.29.0"] + } + } + ]] + local code, body = t('/apisix/admin/services/1', + ngx.HTTP_PUT, + [[{ + "plugins": { + "ai-proxy-multi": { + "fallback_strategy": "instance_health_and_rate_limiting", + "instances": [ + { + "name": "openai-gpt4", + "provider": "openai", + "weight": 1, + "priority": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "model": "gpt-4" + }, + "override": { + "endpoint": "http://localhost:16724" + }, + ]] .. string.format(checks_tmp, "gpt4").. [[ + }, + { + "name": "openai-gpt3", + "provider": "openai", + "weight": 1, + "priority": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "model": "gpt-3" + }, + "override": { + "endpoint": "http://localhost:16724" + }, + ]] .. string.format(checks_tmp, "gpt3") .. 
[[ + } + ], + "ssl_verify": false + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 9: set route 1 related to service 1 +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/ai", + "service_id": 1 + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 10: healthy conversion of two instances +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local core = require("apisix.core") + local test_dict = ngx.shared["test"] + + local send_request = function() + local code, _, body = t("/ai", + ngx.HTTP_POST, + [[{ + "messages": [ + { "role": "system", "content": "You are a mathematician" }, + { "role": "user", "content": "What is 1+1?" } + ] + }]], + nil, + { + ["test-type"] = "options", + ["Content-Type"] = "application/json", + } + ) + assert(code == 200, "request should be successful") + return body + end + + -- set the gpt4 instance to unhealthy + -- set the gpt3 instance to healthy + test_dict:set("/status/gpt4#total", 0) + test_dict:set("/status/gpt3#total", 50) + -- trigger the health check + send_request() + ngx.sleep(1.2) + + local instances_count = { + ["gpt-4"] = 0, + ["gpt-3"] = 0, + } + for i = 1, 10 do + local resp = send_request() + if core.string.find(resp, "gpt-4") then + instances_count["gpt-4"] = instances_count["gpt-4"] + 1 + else + instances_count["gpt-3"] = instances_count["gpt-3"] + 1 + end + end + + ngx.log(ngx.INFO, "instances_count test:", core.json.delay_encode(instances_count)) + assert(instances_count["gpt-4"] <= 2, "gpt-4 should be unhealthy") + assert(instances_count["gpt-3"] >= 8, "gpt-3 should be healthy") + + -- set the gpt4 instance to healthy + -- set the gpt3 instance to unhealthy + test_dict:set("/status/gpt4#total", 50) + test_dict:set("/status/gpt3#total", 0) + ngx.sleep(1.2) + + local instances_count = { + ["gpt-4"] = 0, + ["gpt-3"] = 0, + } + for i = 1, 10 do + local resp = send_request() + if core.string.find(resp, "gpt-4") then + instances_count["gpt-4"] = instances_count["gpt-4"] + 1 + else + instances_count["gpt-3"] = instances_count["gpt-3"] + 1 + end + end + ngx.log(ngx.INFO, "instances_count test:", core.json.delay_encode(instances_count)) + + assert(instances_count["gpt-4"] >= 8, "gpt-4 should be healthy") + assert(instances_count["gpt-3"] <= 2, "gpt-3 should be unhealthy") + ngx.say("passed") + } + } +--- timeout: 10 +--- response_body +passed
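
Reviewer note, not part of the patch: the new plugin.set_plugins_meta_parent helper makes the owning route/service/consumer object reachable from every plugin conf as conf._meta.parent, while keeping it out of the conf's own keys so it never leaks into JSON encoding or schema validation. ai-proxy-multi relies on that link: create_checkers stores its health checkers on conf._meta.parent and registers a clean handler there, so the checkers are released together with the parent object. A minimal standalone Lua sketch of the lookup behaviour, using hypothetical route/conf tables:

-- Simplified illustration of the metatable-based parent link
-- (hypothetical data; the real helper lives in apisix/plugin.lua).
local route = { value = { id = "1" } }   -- stand-in for a route object
local conf  = { _meta = {} }             -- a plugin conf attached to that route

-- what set_plugins_meta_parent does when _meta has no metatable yet
setmetatable(conf._meta, { __index = { parent = route } })

-- the parent is reachable through __index ...
assert(conf._meta.parent == route)

-- ... but it is not a real key of _meta, so pairs()/JSON encoding of the
-- plugin conf stay unchanged
assert(next(conf._meta) == nil)

The picker cache key is versioned along the same lines: pick_target suffixes plugin.conf_version(conf) with the summed checker status_ver, so any healthy/unhealthy transition invalidates the cached server picker and create_server_picker rebuilds it from fetch_health_instances.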