Skip to content

Commit 55902fb

Browse files
committed
API key manager based on request_id
1 parent 2bb3831 commit 55902fb

File tree

8 files changed

+595
-13
lines changed

8 files changed

+595
-13
lines changed

src/EasyContext.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ include("stateful_transformation/StatefulTransformators.jl")
3737

3838
include("utils/AIGenerateFallback.jl")
3939
include("utils/APIKeyRotation.jl")
40+
include("utils/APIKeyManager.jl")
41+
4042
include("rerankers/ChunkBatchers.jl")
4143
include("rerankers/rerank_prompts.jl")
4244
include("rerankers/ReduceGPTReranker.jl")

src/contexts/CTX_workspace.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ function init_workspace_context(project_paths::Vector{<:AbstractString};
4242
)
4343
end
4444

45-
function process_workspace_context(workspace_context::WorkspaceCTX, embedder_query; rerank_query=embedder_query, enabled=true, age_tracker=nothing, extractor=nothing, io::Union{IO, Nothing}=nothing, query_images::Union{AbstractVector{<:AbstractString}, Nothing}=nothing, )
45+
function process_workspace_context(workspace_context::WorkspaceCTX, embedder_query; rerank_query=embedder_query, enabled=true, age_tracker=nothing, extractor=nothing, io::Union{IO, Nothing}=nothing, query_images::Union{AbstractVector{<:AbstractString}, Nothing}=nothing, request_id=nothing)
4646
!enabled || isempty(workspace_context.workspace) && return ("", nothing, nothing, nothing)
4747

4848
start_time = time()
@@ -51,7 +51,7 @@ function process_workspace_context(workspace_context::WorkspaceCTX, embedder_que
5151
isempty(file_chunks) && return ("", nothing, nothing, nothing)
5252

5353
cost_tracker = Threads.Atomic{Float64}(0.0)
54-
file_chunks_reranked = search(workspace_context.rag_pipeline, file_chunks, embedder_query; rerank_query, cost_tracker, query_images)
54+
file_chunks_reranked = search(workspace_context.rag_pipeline, file_chunks, embedder_query; rerank_query, cost_tracker, query_images, request_id)
5555
merged_file_chunks = merge!(workspace_context.tracker_context, file_chunks_reranked)
5656

5757
!isnothing(extractor) && update_changes_from_extractor!(workspace_context.changes_tracker, extractor)

src/embedders/TopK.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ TopK(embedder::RAGTools.AbstractEmbedder; top_k::Int=DEFAULT_TOP_K) = TopK(embed
3737

3838
# Main interface
3939
function search(topn::TopK, chunks::AbstractVector{T}, query::AbstractString;
40-
cost_tracker = Threads.Atomic{Float64}(0.0), query_images::Union{Nothing,Vector{String}}=nothing) where T
40+
cost_tracker = Threads.Atomic{Float64}(0.0), query_images::Union{Nothing,Vector{String}}=nothing, request_id=nothing) where T
4141
scores = get_score(topn.embedder, chunks, query; cost_tracker, query_images)
4242
result = topN(scores, chunks, topn.top_k)
4343
return result
@@ -53,7 +53,7 @@ export humanize
5353
struct NullTopK <: AbstractRAGPipeline end
5454

5555
# Return no retrieved chunks
56-
function search(::NullTopK, chunks::Vector{T}, _query::AbstractString; query_images=nothing, cost_tracker=Threads.Atomic{Float64}(0.0)) where T
56+
function search(::NullTopK, chunks::Vector{T}, _query::AbstractString; query_images=nothing, cost_tracker=Threads.Atomic{Float64}(0.0), request_id=nothing) where T
5757
T[]
5858
end
5959

src/rag/AdvancedRAG.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@ function search(method::TwoLayerRAG, chunks::Vector{T}, query::AbstractString;
2020
rerank_query::Union{AbstractString, Nothing}=nothing,
2121
cost_tracker = Threads.Atomic{Float64}(0.0),
2222
query_images::Union{AbstractVector{<:AbstractString}, Nothing}=nothing,
23+
request_id=nothing,
2324
) where T
2425

2526
rerank_query = rerank_query === nothing ? query : rerank_query
26-
results = search(method.topK, chunks, query; query_images, cost_tracker)
27-
return rerank(method.reranker, results, rerank_query; cost_tracker, query_images)
27+
results = search(method.topK, chunks, query; query_images, cost_tracker, request_id)
28+
return rerank(method.reranker, results, rerank_query; cost_tracker, query_images, request_id)
2829
end
2930

3031
@kwdef mutable struct TwoLayerRAGWithTimings <: AbstractRAGPipeline
@@ -38,16 +39,17 @@ function search(method::TwoLayerRAGWithTimings, chunks::Vector{T}, query::Abstra
3839
rerank_query::Union{AbstractString, Nothing}=nothing,
3940
cost_tracker = Threads.Atomic{Float64}(0.0),
4041
query_images::Union{AbstractVector{<:AbstractString}, Nothing}=nothing,
42+
request_id=nothing,
4143
) where T
4244

4345
rerank_query = rerank_query === nothing ? query : rerank_query
4446

4547
start_search = time()
46-
results = search(method.topK, chunks, query; query_images, cost_tracker)
48+
results = search(method.topK, chunks, query; query_images, cost_tracker, request_id)
4749
push!(method.search_times, time() - start_search)
4850

4951
start_rerank = time()
50-
final_results = rerank(method.reranker, results, rerank_query; cost_tracker, query_images)
52+
final_results = rerank(method.reranker, results, rerank_query; cost_tracker, query_images, request_id)
5153
push!(method.rerank_times, time() - start_rerank)
5254

5355
return final_results

src/rerankers/ReduceGPTReranker.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ function rerank(
2525
cost_tracker = Threads.Atomic{Float64}(0.0),
2626
query_images::Union{AbstractVector{<:AbstractString}, Nothing}=nothing,
2727
verbose::Int = reranker.verbose,
28+
request_id=nothing,
2829
) where T
2930
# Initialize AIGenerateFallback with model preferences based on model type
3031
ai_manager = AIGenerateFallback(models=reranker.model)
@@ -51,7 +52,8 @@ function rerank(
5152

5253
response = try_generate(ai_manager, prompt;
5354
api_kwargs=(; temperature, top_p=0.1),
54-
verbose=false
55+
verbose=false,
56+
request_id
5557
)
5658
rankings = extract_ranking(response.content; verbose=verbose)
5759

@@ -157,7 +159,7 @@ end
157159
struct NullReranker <: AbstractReranker end
158160

159161
# Pass-through (will be empty anyway)
160-
function rerank(::NullReranker, results, _rerank_query; cost_tracker=Threads.Atomic{Float64}(0.0), query_images=nothing)
162+
function rerank(::NullReranker, results, _rerank_query; cost_tracker=Threads.Atomic{Float64}(0.0), query_images=nothing, request_id=nothing)
161163
results
162164
end
163165

src/utils/APIKeyManager.jl

Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
export APIKeyManager, get_api_key_for_model, StringApiKey
2+
3+
using PromptingTools
4+
using JSON3
5+
using JLD2
6+
using LLMRateLimiters: CharCountDivTwo, RateLimiterTPM
7+
8+
const DATA_DIR = joinpath(dirname(@__DIR__), "..", "data")
9+
const STATS_FILE = joinpath(DATA_DIR, "credentials_stats.jld2")
10+
const STATS_LOCK = ReentrantLock()
11+
12+
"""
13+
StringApiKey
14+
15+
Represents an API key with proper sliding window rate limiting.
16+
"""
17+
mutable struct StringApiKey
18+
key::String
19+
schema_name::String
20+
rate_limiter::RateLimiterTPM
21+
last_save_time::Float64
22+
save_threshold::Int
23+
tokens_since_save::Int
24+
25+
function StringApiKey(key::String, schema_name::String = "OpenAISchema", max_tokens_per_minute::Int = 1_000_000)
26+
rate_limiter = RateLimiterTPM(
27+
max_tokens = max_tokens_per_minute,
28+
time_window = 60.0,
29+
estimation_method = CharCountDivTwo
30+
)
31+
new(key, schema_name, rate_limiter, time(), 10, 0)
32+
end
33+
end
34+
35+
"""
36+
APIKeyManager
37+
38+
Manages API keys with rate limiting and request routing.
39+
"""
40+
mutable struct APIKeyManager
41+
schema_to_api_keys::Dict{Type{<:AbstractPromptSchema}, Vector{StringApiKey}}
42+
request_affinity::Dict{String, Tuple{String, Float64}}
43+
affinity_window::Float64
44+
45+
function APIKeyManager(affinity_window::Float64 = 300.0)
46+
new(Dict{Type{<:AbstractPromptSchema}, Vector{StringApiKey}}(),
47+
Dict{String,Tuple{String,Float64}}(), affinity_window)
48+
end
49+
end
50+
51+
# Global instance
52+
const GLOBAL_API_KEY_MANAGER = APIKeyManager()
53+
54+
"""
55+
save_stats_to_file!(api_key::StringApiKey)
56+
57+
Save API key statistics to JLD2 file asynchronously with partial updates.
58+
"""
59+
function save_stats_to_file!(api_key::StringApiKey)
60+
@async_showerr begin
61+
lock(STATS_LOCK) do
62+
!isdir(DATA_DIR) && mkpath(DATA_DIR)
63+
key_hash = string(hash(api_key.key)) # Use hash for privacy
64+
jldopen(STATS_FILE, "a+") do file # a+ is OK
65+
haskey(file, key_hash) && delete!(file, key_hash) # required to overwrite
66+
file[key_hash] = Dict(
67+
"schema_name" => api_key.schema_name,
68+
"tokens_used_last_minute" => get_current_usage(api_key),
69+
"last_save_time" => api_key.last_save_time
70+
)
71+
end
72+
end
73+
end
74+
end
75+
76+
"""
77+
update_usage!(api_key::StringApiKey, tokens::Int)
78+
79+
Update token usage using proper rate limiter and save to file periodically.
80+
"""
81+
function update_usage!(api_key::StringApiKey, tokens::Int)
82+
LLMRateLimiters.add_tokens!(api_key.rate_limiter, tokens)
83+
api_key.tokens_since_save += tokens
84+
85+
if api_key.tokens_since_save >= api_key.save_threshold
86+
api_key.tokens_since_save = 0
87+
api_key.last_save_time = time()
88+
save_stats_to_file!(api_key)
89+
end
90+
end
91+
92+
"""
93+
get_current_usage(api_key::StringApiKey) -> Int
94+
95+
Get current token usage in the sliding window.
96+
"""
97+
get_current_usage(api_key::StringApiKey) = LLMRateLimiters.current_usage(api_key.rate_limiter)
98+
99+
"""
100+
can_handle_tokens(api_key::StringApiKey, tokens::Int) -> Bool
101+
102+
Check if the API key can handle the requested number of tokens.
103+
"""
104+
can_handle_tokens(api_key::StringApiKey, tokens::Int) = LLMRateLimiters.can_add_tokens(api_key.rate_limiter, tokens)
105+
106+
"""
107+
add_api_keys!(manager::APIKeyManager, schema_type::Type{<:AbstractPromptSchema}, keys::Vector{String}, max_tokens_per_minute::Int = 1_000_000)
108+
109+
Add API keys for a specific schema type with rate limiting.
110+
"""
111+
function add_api_keys!(manager::APIKeyManager, schema_type::Type{<:AbstractPromptSchema}, keys::Vector{String}, max_tokens_per_minute::Int = 1_000_000)
112+
if !haskey(manager.schema_to_api_keys, schema_type)
113+
manager.schema_to_api_keys[schema_type] = StringApiKey[]
114+
end
115+
schema_name = string(nameof(schema_type))
116+
append!(manager.schema_to_api_keys[schema_type], [StringApiKey(key, schema_name, max_tokens_per_minute) for key in keys])
117+
end
118+
119+
"""
120+
collect_env_keys(base_env_var::String) -> Vector{String}
121+
122+
Collect all API keys for a given base environment variable (base + numbered variants).
123+
"""
124+
function collect_env_keys(base_env_var::String)
125+
keys = String[]
126+
127+
# Check base key
128+
if haskey(ENV, base_env_var) && !isempty(ENV[base_env_var])
129+
push!(keys, ENV[base_env_var])
130+
end
131+
132+
# Check numbered keys (KEY_2, KEY_3, etc.)
133+
for i in 2:100
134+
env_var = "$(base_env_var)_$i"
135+
if haskey(ENV, env_var) && !isempty(ENV[env_var])
136+
push!(keys, ENV[env_var])
137+
else
138+
break # Stop at first missing numbered key
139+
end
140+
end
141+
142+
return keys
143+
end
144+
145+
"""
146+
find_api_key_for_request(manager::APIKeyManager, schema_type::Type{<:AbstractPromptSchema},
147+
request_id::Union{String, Nothing}, estimated_tokens::Int)
148+
149+
Find API key with lowest current usage (with sticky routing preference).
150+
"""
151+
function find_api_key_for_request(manager::APIKeyManager, schema_type::Type{<:AbstractPromptSchema},
152+
request_id::Union{String, Nothing}, estimated_tokens::Int)
153+
!haskey(manager.schema_to_api_keys, schema_type) && return nothing
154+
155+
api_keys = manager.schema_to_api_keys[schema_type]
156+
isempty(api_keys) && return nothing
157+
158+
# 1) Sticky routing if possible
159+
if !isnothing(request_id) && haskey(manager.request_affinity, request_id)
160+
key_str, last_t = manager.request_affinity[request_id]
161+
if time() - last_t <= manager.affinity_window
162+
# Find the matching key object
163+
idx = findfirst(k -> k.key == key_str, api_keys)
164+
if !isnothing(idx) && can_handle_tokens(api_keys[idx], estimated_tokens)
165+
return api_keys[idx]
166+
end
167+
end
168+
end
169+
170+
# 2) Simply choose the key with lowest current usage
171+
return argmin(k -> get_current_usage(k), api_keys)
172+
end
173+
174+
"""
175+
get_model_schema(model::String)
176+
177+
Get the schema for a model from the MODEL_REGISTRY.
178+
"""
179+
function get_model_schema(model::String)
180+
model_spec = get(PromptingTools.MODEL_REGISTRY, model, nothing)
181+
return isnothing(model_spec) ? OpenAISchema() : model_spec.schema
182+
end
183+
184+
"""
185+
get_model_schema(config::ModelConfig)
186+
187+
Get the schema from a ModelConfig.
188+
"""
189+
get_model_schema(config::ModelConfig) = isnothing(config.schema) ? OpenAISchema() : config.schema
190+
191+
"""
192+
initialize_from_env!(manager::APIKeyManager)
193+
194+
Initialize API keys from environment variables.
195+
"""
196+
function initialize_from_env!(manager::APIKeyManager)
197+
isempty(manager.schema_to_api_keys) || return # Already initialized
198+
199+
# Schema type to environment variable mapping
200+
schema_env_mapping = [
201+
(OpenAISchema, "OPENAI_API_KEY"),
202+
(CerebrasOpenAISchema, "CEREBRAS_API_KEY"),
203+
(MistralOpenAISchema, "MISTRAL_API_KEY"),
204+
(AnthropicSchema, "ANTHROPIC_API_KEY"),
205+
(GoogleSchema, "GOOGLE_API_KEY"),
206+
(GoogleOpenAISchema, "GOOGLE_API_KEY"),
207+
(GroqOpenAISchema, "GROQ_API_KEY"),
208+
(TogetherOpenAISchema, "TOGETHER_API_KEY"),
209+
(DeepSeekOpenAISchema, "DEEPSEEK_API_KEY"),
210+
(OpenRouterOpenAISchema, "OPENROUTER_API_KEY"),
211+
(SambaNovaOpenAISchema, "SAMBANOVA_API_KEY")
212+
]
213+
214+
# Additional environment variables that map to existing schemas
215+
additional_env_mapping = [
216+
# ("COHERE_API_KEY", OpenAISchema), #it is worng to assign them to OepnAISchema, also they are embedders and other things.
217+
# ("TAVILY_API_KEY", OpenAISchema),
218+
# ("JINA_API_KEY", OpenAISchema),
219+
# ("VOYAGE_API_KEY", OpenAISchema),
220+
# ("GEMINI_API_KEY", GoogleSchema)
221+
]
222+
223+
# Process all mappings
224+
for (schema_type, base_env_var) in schema_env_mapping
225+
keys = collect_env_keys(base_env_var)
226+
!isempty(keys) && add_api_keys!(manager, schema_type, keys)
227+
end
228+
229+
for (base_env_var, schema_type) in additional_env_mapping
230+
keys = collect_env_keys(base_env_var)
231+
!isempty(keys) && add_api_keys!(manager, schema_type, keys)
232+
end
233+
end
234+
235+
"""
236+
get_api_key_for_model(model::Union{String, ModelConfig},
237+
request_id::Union{String, Nothing} = nothing, prompt::AbstractString = "";
238+
manager::APIKeyManager = GLOBAL_API_KEY_MANAGER)
239+
240+
Get the appropriate API key for a model and request with proper rate limiting.
241+
"""
242+
function get_api_key_for_model(model::Union{String, ModelConfig},
243+
request_id::Union{String, Nothing} = nothing, prompt::AbstractString = "";
244+
manager::APIKeyManager = GLOBAL_API_KEY_MANAGER)
245+
initialize_from_env!(manager)
246+
schema = get_model_schema(model)
247+
schema_type = typeof(schema)
248+
est = LLMRateLimiters.estimate_tokens(prompt, CharCountDivTwo)
249+
key_obj = find_api_key_for_request(manager, schema_type, request_id, est)
250+
251+
isnothing(key_obj) && return nothing
252+
253+
# Update usage and affinity
254+
update_usage!(key_obj, est)
255+
!isnothing(request_id) && (manager.request_affinity[request_id] = (key_obj.key, time()))
256+
257+
return key_obj.key
258+
end

0 commit comments

Comments
 (0)