Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions baml_language/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

161 changes: 161 additions & 0 deletions baml_language/crates/baml_builtins/baml/llm.baml
Original file line number Diff line number Diff line change
Expand Up @@ -308,3 +308,164 @@ function call_llm_function(function_name: string, args: map<string, unknown>) ->

throw "All orchestration steps failed";
}

// ============================================================================
// Streaming Orchestration
// ============================================================================

/// Stream one attempt against a single primitive client over SSE.
///
/// Renders and specializes the prompt, opens the SSE connection, then loops:
/// every event batch is fed to the accumulator, re-emitted as a raw tick,
/// and the accumulated content is partial-parsed and pushed to the partial
/// callback. Once the stream ends, the full content is parsed into the
/// final ExecutionResult.
function stream_primitive(
  primitive: baml.llm.PrimitiveClient,
  context: baml.llm.ExecutionContext,
) -> baml.llm.ExecutionResult {
  let rendered = primitive.render_prompt(context.jinja_string, context.args);
  let prompt_ast = primitive.specialize_prompt(rendered);
  let request = primitive.build_request_stream(prompt_ast);
  let stream = baml.http.fetch_sse(request);
  let acc = primitive.new_stream_accumulator();
  let target_type = baml.llm.get_return_type(context.function_name);

  while (true) {
    let batch = stream.next();
    if (batch == null) { break; }
    acc.add_events(batch);
    baml.stream.emit_tick(batch);

    // Re-parse the whole accumulated text after each batch so subscribers
    // observe a monotonically growing partial value.
    let snapshot = acc.content();
    let partial = primitive.partial_parse(snapshot, target_type);
    baml.stream.emit_partial(partial);
  }

  stream.close();

  let full_text = acc.content();
  let final_value = primitive.parse(full_text, target_type);
  baml.llm.ExecutionResult { ok: true, value: final_value }
}

/// Execute a streaming client with retry semantics, evaluating strategies lazily.
///
/// Mirrors `execute_client` but uses `stream_primitive` at the leaf level.
/// Without a retry policy this is a single attempt; with one, up to
/// `max_retries + 1` attempts run with exponentially growing delays, capped
/// at `max_delay_ms`.
function execute_client_stream(
llm_client: baml.llm.Client,
context: baml.llm.ExecutionContext,
inherited_delay_ms: int,
) -> baml.llm.ExecutionResult {
match (llm_client.retry) {
// No retry policy: single attempt, passing through any delay inherited
// from an enclosing strategy.
null => baml.llm.execute_client_once_stream(
llm_client,
context,
inherited_delay_ms,
),
r: baml.llm.RetryPolicy => {
// `+ 0.0` promotes the int delay to a float so the backoff math below
// stays in floating point until truncation.
let current_delay = r.initial_delay_ms + 0.0;

for (let attempt = 0; attempt <= r.max_retries; attempt += 1) {
// First attempt carries only the inherited delay.
let attempt_delay = inherited_delay_ms;
if (attempt > 0) {
// Use the current backoff value for this attempt, then advance it
// for the next attempt (truncate before multiplying).
attempt_delay = baml.math.trunc(current_delay);
let next = current_delay * r.multiplier;
if (next > r.max_delay_ms + 0.0) {
current_delay = r.max_delay_ms + 0.0;
} else {
current_delay = next;
}
}
// The delay is slept *after* a failed attempt (inside
// `execute_client_once_stream`), so the final attempt reverts to the
// inherited delay: there is no further retry to back off for.
if (attempt == r.max_retries) {
attempt_delay = inherited_delay_ms;
}

let result = baml.llm.execute_client_once_stream(
llm_client,
context,
attempt_delay,
);

if (result.ok) {
return result;
}
}

// Every attempt failed.
baml.llm.ExecutionResult { ok: false, value: null }
}
}
}

/// Execute a single streaming attempt for a client (no retry expansion here).
///
/// Mirrors `execute_client_once` but calls `stream_primitive` for primitive clients.
function execute_client_once_stream(
  llm_client: baml.llm.Client,
  context: baml.llm.ExecutionContext,
  active_delay_ms: int,
) -> baml.llm.ExecutionResult {
  match (llm_client.client_type) {
    // Leaf client: run one streaming call. On failure, honor the delay
    // requested by the surrounding retry strategy before reporting back.
    baml.llm.ClientType.Primitive => {
      let resolver = baml.llm.resolve_client(llm_client.name);
      let resolved = resolver();

      let outcome = baml.llm.stream_primitive(resolved, context);

      if (outcome.ok) {
        return outcome;
      }

      if (active_delay_ms > 0) {
        baml.sys.sleep(active_delay_ms);
      }

      baml.llm.ExecutionResult { ok: false, value: null }
    }

    // Fallback: try sub-clients in declaration order, returning the first
    // success; fail only when all of them fail.
    baml.llm.ClientType.Fallback => {
      for (let candidate in llm_client.sub_clients) {
        let outcome = baml.llm.execute_client_stream(
          candidate,
          context,
          active_delay_ms,
        );
        if (outcome.ok) {
          return outcome;
        }
      }
      baml.llm.ExecutionResult { ok: false, value: null }
    }

    // Round-robin: advance the rotation counter and delegate to the
    // selected sub-client.
    baml.llm.ClientType.RoundRobin => {
      let slot = baml.llm.round_robin_next(llm_client.name) % llm_client.sub_clients.length();
      baml.llm.execute_client_stream(
        llm_client.sub_clients.at(slot),
        context,
        active_delay_ms,
      )
    }
  }
}

/// Stream an LLM function end-to-end with full orchestration.
///
/// The streaming counterpart of `call_llm_function`: resolves the Jinja
/// template and client tree for `function_name`, then drives the same
/// retry/fallback/round-robin logic over SSE, emitting partial values via
/// the streaming callbacks along the way.
function stream_llm_function(function_name: string, args: map<string, unknown>) -> unknown throws string {
  let template = baml.llm.get_jinja_template(function_name);
  let root_client = baml.llm.get_client(function_name);

  let exec_context = baml.llm.ExecutionContext {
    jinja_string: template,
    args: args,
    function_name: function_name,
  };

  // Top-level call: no inherited delay yet.
  let outcome = baml.llm.execute_client_stream(root_client, exec_context, 0);

  if (outcome.ok) {
    return outcome.value;
  }

  throw "All streaming orchestration steps failed";
}
78 changes: 78 additions & 0 deletions baml_language/crates/baml_builtins/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,29 @@ macro_rules! with_builtins {
/// Send an HTTP request and wait for the complete response.
#[sys_op]
#[throws(Io, Timeout)]
fn send(request: Request) -> Response;

/// An SSE (Server-Sent Events) stream, opened via `fetch_sse`.
#[builtin]
struct SseStream {
    // Engine-side resource identifying the open connection.
    private _handle: ResourceHandle,
    // URL the stream was opened against.
    url: String,

    /// Get the next batch of SSE events, or null if the stream is done.
    ///
    /// Events arrive as a single string batch, not one-by-one; callers
    /// typically feed each batch to a `StreamAccumulator`.
    #[sys_op]
    #[throws(Io, Timeout)]
    fn next(self: SseStream) -> Option<String>;

    /// Close the SSE stream and release the underlying connection.
    #[sys_op]
    #[throws(Io)]
    fn close(self: SseStream);
}

/// Open an SSE connection by sending an HTTP request.
///
/// The request should already be built for streaming (see
/// `build_request_stream` on primitive clients). Returns an `SseStream`
/// that is polled with `next` and should be closed with `close`.
#[sys_op]
#[throws(Io, Timeout)]
fn fetch_sse(request: Request) -> SseStream;
}

// =====================================================================
Expand Down Expand Up @@ -418,6 +441,21 @@ macro_rules! with_builtins {
/// Parse a complete HTTP response body into a value of the given type.
#[sys_op]
#[throws(LlmClient)]
fn parse(self: PrimitiveClient, http_response_body: String, type_def: Type) -> Any;

/// Build an HTTP request with streaming enabled, suitable for opening an
/// SSE connection with `fetch_sse`.
#[sys_op]
#[throws(LlmClient)]
fn build_request_stream(self: PrimitiveClient, prompt: PromptAst) -> Request;

/// Create a new stream accumulator for this primitive client; created via
/// the client so event parsing can be provider-specific.
#[sys_op]
#[throws(LlmClient)]
fn new_stream_accumulator(self: PrimitiveClient) -> StreamAccumulator;

/// Parse partially-accumulated streaming content against the expected
/// return type. String-only for now: the parsed partial comes back as a
/// plain String rather than a typed value.
#[sys_op]
#[throws(LlmClient)]
fn partial_parse(self: PrimitiveClient, content: String, type_def: Type) -> String;
}

/// Get the Jinja template for an LLM function.
Expand Down Expand Up @@ -472,6 +510,46 @@ macro_rules! with_builtins {
#[throws(InvalidArgument)]
#[uses(engine_ctx)]
fn get_return_type(function_name: String) -> Type;

/// A stream accumulator that extracts content from SSE events.
#[builtin]
struct StreamAccumulator {
    // Engine-side resource holding the accumulated state.
    private _handle: ResourceHandle,

    /// Add a batch of SSE event data (JSON strings) to the accumulator.
    #[sys_op]
    #[throws(LlmClient)]
    fn add_events(self: StreamAccumulator, events: String);

    /// Get the content accumulated so far across all added batches.
    #[sys_op]
    #[throws(LlmClient)]
    fn content(self: StreamAccumulator) -> String;

    /// Check if the stream is done (based on the events seen so far).
    #[sys_op]
    #[throws(LlmClient)]
    fn is_done(self: StreamAccumulator) -> bool;
}

}

// =====================================================================
// Streaming operations
// =====================================================================
mod stream {
    /// Emit a parsed partial value to the stream callback.
    ///
    /// Pass an empty string to indicate that no value is available yet.
    #[sys_op]
    #[throws(Io)]
    #[uses(engine_ctx)]
    fn emit_partial(value: String);

    /// Emit raw SSE events to the tick callback, letting consumers observe
    /// the unparsed stream alongside the parsed partial values.
    #[sys_op]
    #[throws(Io)]
    #[uses(engine_ctx)]
    fn emit_tick(events: String);
}
}

Expand Down
Loading
Loading