Skip to content
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions test/src/ai.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use worker::{
models::llama_4_scout_17b_16e_instruct::Llama4Scout17b16eInstruct,
worker_sys::AiTextGenerationInput, Env, Request, Response, Result,
};

use crate::SomeSharedData;

const AI_TEST: &str = "AI_TEST";

#[worker::send]
pub async fn simple_ai_text_generation(
_: Request,
env: Env,
_data: SomeSharedData,
) -> Result<Response> {
let ai = env
.ai(AI_TEST)?
.run::<Llama4Scout17b16eInstruct>(
AiTextGenerationInput::new()
.set_prompt("What is the answer to life the universe and everything?"),
)
.await?;
Response::ok(ai.get_response().unwrap_or_default())
}

#[worker::send]
pub async fn streaming_ai_text_generation(
_: Request,
env: Env,
_data: SomeSharedData,
) -> Result<Response> {
let stream = env
.ai(AI_TEST)?
.run_streaming::<Llama4Scout17b16eInstruct>(
AiTextGenerationInput::new()
.set_prompt("What is the answer to life the universe and everything?"),
)
.await?;

Response::from_stream(stream)
}
1 change: 1 addition & 0 deletions test/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use worker::{console_log, event, js_sys, wasm_bindgen, Env, Result};
#[cfg(not(feature = "http"))]
use worker::{Request, Response};

mod ai;
mod alarm;
mod analytics_engine;
mod assets;
Expand Down
8 changes: 5 additions & 3 deletions test/src/router.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
alarm, analytics_engine, assets, auto_response, cache, container, counter, d1, durable, fetch,
form, js_snippets, kv, put_raw, queue, r2, request, secret_store, service, socket, sql_counter,
sql_iterator, user, ws, SomeSharedData, GLOBAL_STATE,
ai, alarm, analytics_engine, assets, auto_response, cache, container, counter, d1, durable,
fetch, form, js_snippets, kv, put_raw, queue, r2, request, secret_store, service, socket,
sql_counter, sql_iterator, user, ws, SomeSharedData, GLOBAL_STATE,
};
#[cfg(feature = "http")]
use std::convert::TryInto;
Expand Down Expand Up @@ -112,6 +112,8 @@ macro_rules! add_route (

macro_rules! add_routes (
($obj:ident) => {
add_route!($obj, get, "/ai", ai::simple_ai_text_generation);
add_route!($obj, get, "/ai/streaming", ai::streaming_ai_text_generation);
add_route!($obj, get, sync, "/request", request::handle_a_request);
add_route!($obj, get, "/analytics-engine", analytics_engine::handle_analytics_event);
add_route!($obj, get, "/async-request", request::handle_async_request);
Expand Down
11 changes: 11 additions & 0 deletions test/tests/ai.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import { describe, expect, test } from "vitest";
import { mf, mfUrl } from "./mf";

async function runTest() {
let normal_response = await mf.dispatchFetch(`${mfUrl}ai`);
expect(normal_response.status).toBe(200);

let streaming_response = await mf.dispatchFetch(`${mfUrl}ai/streaming`);
expect(streaming_response.status).toBe(200);
}
describe("ai", runTest);
25 changes: 14 additions & 11 deletions test/wrangler.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
name = "testing-rust-worker"
workers_dev = true
compatibility_date = "2025-09-23" # required
compatibility_date = "2025-09-23" # required
main = "build/worker/shim.mjs"

kv_namespaces = [
{ binding = "SOME_NAMESPACE", id = "SOME_NAMESPACE", preview_id = "SOME_NAMESPACE" },
{ binding = "FILE_SIZES", id = "FILE_SIZES", preview_id = "FILE_SIZES" },
{ binding = "SOME_NAMESPACE", id = "SOME_NAMESPACE", preview_id = "SOME_NAMESPACE" },
{ binding = "FILE_SIZES", id = "FILE_SIZES", preview_id = "FILE_SIZES" },
]

[vars]
Expand All @@ -22,14 +22,14 @@ service = "remote-service"

[durable_objects]
bindings = [
{ name = "COUNTER", class_name = "Counter" },
{ name = "ALARM", class_name = "AlarmObject" },
{ name = "PUT_RAW_TEST_OBJECT", class_name = "PutRawTestObject" },
{ name = "AUTO", class_name = "AutoResponseObject" },
{ name = "SQL_COUNTER", class_name = "SqlCounter" },
{ name = "SQL_ITERATOR", class_name = "SqlIterator" },
{ name = "MY_CLASS", class_name = "MyClass" },
{ name = "ECHO_CONTAINER", class_name = "EchoContainer" },
{ name = "COUNTER", class_name = "Counter" },
{ name = "ALARM", class_name = "AlarmObject" },
{ name = "PUT_RAW_TEST_OBJECT", class_name = "PutRawTestObject" },
{ name = "AUTO", class_name = "AutoResponseObject" },
{ name = "SQL_COUNTER", class_name = "SqlCounter" },
{ name = "SQL_ITERATOR", class_name = "SqlIterator" },
{ name = "MY_CLASS", class_name = "MyClass" },
{ name = "ECHO_CONTAINER", class_name = "EchoContainer" },
]

[[analytics_engine_datasets]]
Expand Down Expand Up @@ -84,3 +84,6 @@ secret_name = "secret-name"
class_name = "EchoContainer"
image = "./container-echo/Dockerfile"
max_instances = 1

[ai]
binding = "AI_TEST"
1 change: 1 addition & 0 deletions worker-sys/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,4 @@ pub use tls_client_auth::*;
pub use version::*;
pub use websocket_pair::*;
pub use websocket_request_response_pair::*;
pub mod utils;
231 changes: 230 additions & 1 deletion worker-sys/src/types/ai.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use js_sys::Promise;
use std::iter::FromIterator;

use js_sys::{Array, Promise};
use wasm_bindgen::prelude::*;

use crate::typed_array;

#[wasm_bindgen]
extern "C" {
#[wasm_bindgen(extends=::js_sys::Object, js_name=Ai)]
Expand All @@ -10,3 +14,228 @@ extern "C" {
#[wasm_bindgen(structural, method, js_class=Ai, js_name=run)]
pub fn run(this: &Ai, model: &str, input: JsValue) -> Promise;
}

typed_array!(RoleScopedChatInputArray, RoleScopedChatInput);

impl RoleScopedChatInputArray {
pub fn custom_role(&self, role: &str, content: &str) -> &Self {
let message = RoleScopedChatInput::new();
message.set_role_inner(role);
message.set_content_inner(content);
self.push(&message);
self
}

pub fn user(self, content: &str) -> Self {
self.custom_role("user", content);
self
}

pub fn assistant(self, content: &str) -> Self {
self.custom_role("assistant", content);
self
}

pub fn system(self, content: &str) -> Self {
self.custom_role("system", content);
self
}

pub fn tool(self, content: &str) -> Self {
self.custom_role("tool", content);
self
}
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than putting these on the array, let's put these on the message builder itself.

So we would rather use directly:

messages.push(
    RoleScopedChatInput::builder()
        .role("user")
        .content("hello")
        .build()?
);

We may yet go either way on array generics - wrapper types or not. I don't want to assume one direction yet though. We definitely want builders at least.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually a generic array builder would be a useful pattern I suppose, but the trick with this would be constructing a builder API that can in theory support optimized creation of the whole array in a single buffer serialization (as an optimization, we don't actually have to do it yet).

That is, something more like:

messages.builder()
  .push_builder(|b| b.role("user").content("hello"))
  .push_builder(|b| b.role("user").content("hello"))
  .build()

where the push_builder method on TypedArray<T> is works for T impl Builder for some definition of a builder trait?

Or something along those lines. Let me know if that make sense where I'm trying to go with this?

Copy link
Contributor Author

@parzivale parzivale Nov 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll create a generic array builder. I'm a little confused with the single buffer serialization optimization as the items in the array are jsvalues and are not serializeable

Copy link
Contributor Author

@parzivale parzivale Nov 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also currently the generic array isn't really a generic array just a helper macro which generates an array with typed helper methods, I can refactor to make a proper generic array. Though that might come with complications as I would need to cast every time an access is made rather than just use wasm-bindgen to define the return type

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In due course we will be getting a generic array upstream, just using that term here for now. I understand your macro is a custom type and I'm fine with this approach unflagged with the assumption that the change won't be too destructive down the line.

If we could do a generic array that would in theory be preferable (TypedArray<T> wrapping Array), but I've set enough requirements on this PR here I feel! Separately I would be interested to dig into your cast concern more as I didn't follow that, and it would help the upstream conversation there.

As for the optimization - the idea is that every call from Wasm to JS incurs a performance overhead. If we can batch those calls through an optimization as a single call that should be more performant, turning an array of N values from N calls out to JS to 1. Again not a hard constraint, just sharing the design space thinking.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Specifically just that a Rust builder could theoretically implement such optimizations, but we don't actually need to do them now of course.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the casting concerns that was due to me misunderstanding what you wanted and assuming you wanted a new array wrapper which looked like this

pub struct Array<T> {
   array: js_sys::Array,
   phantom: PhantomData<T>
}

Which wrapped all the associated methods on array with an extra casting step to T.

Regarding the optimization, I see what you mean I got a little confused due to the serialized wording and assumed you wanted something like a serde_wasm_bindgen call as part of the array builder build call.

Copy link
Contributor Author

@parzivale parzivale Nov 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've pushed an implementation of this, I tried to keep the bulk move to js in mind when constructing the builder, the push functionality to the array builder should probably have some refinement but I wanted to get it in front of your eyes first. Finally I would like to move the typedArray macro over to a proc macro before merging as I'm reaching for dirtier and dirtier hacks to avoid polluting the name space (that's why the exported typed array is referenced as RoleScopedChatInputArray::RoleScopedChatInputArray


#[wasm_bindgen]
extern "C" {
# [wasm_bindgen (extends = :: js_sys :: Object)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub type RoleScopedChatInput;

#[wasm_bindgen(constructor, js_class = Object)]
fn new() -> RoleScopedChatInput;

#[wasm_bindgen(method, setter = "role")]
fn set_role_inner(this: &RoleScopedChatInput, role: &str);
#[wasm_bindgen(method, getter = "role")]
fn get_role_inner(this: &RoleScopedChatInput) -> Option<String>;

#[wasm_bindgen(method, setter = "content")]
fn set_content_inner(this: &RoleScopedChatInput, content: &str);
#[wasm_bindgen(method, getter = "content")]
fn get_content_inner(this: &RoleScopedChatInput) -> Option<String>;
}

#[derive(Default, Debug)]
pub enum Role {
#[default]
User,
Assistant,
System,
Tool,
Any(String),
}

impl RoleScopedChatInput {
pub fn get_role(&self) -> Role {
match self.get_role_inner().as_deref() {
Some("user") => Role::User,
Some("assistant") => Role::Assistant,
Some("system") => Role::System,
Some("tool") => Role::Tool,
Some(any) => Role::Any(any.to_owned()),
None => Role::default(),
}
}

pub fn get_content(&self) -> String {
self.get_content_inner().unwrap_or_default()
}
}

#[wasm_bindgen]
extern "C" {
# [wasm_bindgen (extends = :: js_sys :: Object)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub type AiTextGenerationInput;

#[wasm_bindgen(constructor, js_class = Object)]
pub fn new() -> AiTextGenerationInput;

#[wasm_bindgen(method, setter = "prompt")]
fn set_prompt_inner(this: &AiTextGenerationInput, prompt: &str);
#[wasm_bindgen(method, getter = "prompt")]
pub fn get_prompt(this: &AiTextGenerationInput) -> Option<String>;

#[wasm_bindgen(method, setter = "raw")]
fn set_raw_inner(this: &AiTextGenerationInput, raw: bool);
#[wasm_bindgen(method, getter = "raw")]
pub fn get_raw(this: &AiTextGenerationInput) -> Option<bool>;

#[wasm_bindgen(method, setter = "max_tokens")]
fn set_max_tokens_inner(this: &AiTextGenerationInput, max_tokens: u32);
#[wasm_bindgen(method, getter = "max_tokens")]
pub fn get_max_tokens(this: &AiTextGenerationInput) -> Option<u32>;

#[wasm_bindgen(method, setter = "temperature")]
fn set_temperature_inner(this: &AiTextGenerationInput, temperature: f32);
#[wasm_bindgen(method, getter = "temperature")]
pub fn get_temperature(this: &AiTextGenerationInput) -> Option<f32>;

#[wasm_bindgen(method, setter = "top_p")]
fn set_top_p_inner(this: &AiTextGenerationInput, top_p: f32);
#[wasm_bindgen(method, getter = "top_p")]
pub fn get_top_p(this: &AiTextGenerationInput) -> Option<f32>;

#[wasm_bindgen(method, setter = "top_k")]
fn set_top_k_inner(this: &AiTextGenerationInput, top_p: u32);
#[wasm_bindgen(method, getter = "top_k")]
pub fn get_top_k(this: &AiTextGenerationInput) -> Option<u32>;

#[wasm_bindgen(method, setter = "seed")]
fn set_seed_inner(this: &AiTextGenerationInput, seed: u64);
#[wasm_bindgen(method, getter = "seed")]
pub fn get_seed(this: &AiTextGenerationInput) -> Option<u64>;

#[wasm_bindgen(method, setter = "repetition_penalty")]
fn set_repetition_penalty_inner(this: &AiTextGenerationInput, repetition_penalty: f32);
#[wasm_bindgen(method, getter = "repetition_penalty")]
pub fn get_repetition_penalty(this: &AiTextGenerationInput) -> Option<f32>;

#[wasm_bindgen(method, setter = "frequency_penalty")]
fn set_frequency_penalty_inner(this: &AiTextGenerationInput, frequency_penalty: f32);
#[wasm_bindgen(method, getter = "frequency_penalty")]
pub fn get_frequency_penalty(this: &AiTextGenerationInput) -> Option<f32>;

#[wasm_bindgen(method, setter = "presence_penalty")]
fn set_presence_penalty_inner(this: &AiTextGenerationInput, presence_penalty: f32);
#[wasm_bindgen(method, getter = "presence_penalty")]
pub fn get_presence_penalty(this: &AiTextGenerationInput) -> Option<f32>;

#[wasm_bindgen(method, setter = "messages")]
fn set_messages_inner(this: &AiTextGenerationInput, messages: Array);
#[wasm_bindgen(method, getter = "messages")]
pub fn get_messages(this: &AiTextGenerationInput) -> Option<Vec<RoleScopedChatInput>>;

}

impl AiTextGenerationInput {
pub fn set_prompt(self, prompt: &str) -> Self {
self.set_prompt_inner(prompt);
self
}

pub fn set_raw(self, raw: bool) -> Self {
self.set_raw_inner(raw);
self
}

pub fn set_max_tokens(self, max_tokens: u32) -> Self {
self.set_max_tokens_inner(max_tokens);
self
}

pub fn set_temperature(self, temperature: f32) -> Self {
self.set_temperature_inner(temperature);
self
}

pub fn set_top_p(self, top_p: f32) -> Self {
self.set_top_p_inner(top_p);
self
}

pub fn set_top_k(self, top_k: u32) -> Self {
self.set_top_k_inner(top_k);
self
}

pub fn set_seed(self, seed: u64) -> Self {
self.set_seed_inner(seed);
self
}

pub fn set_repetition_penalty(self, repetition_penalty: f32) -> Self {
self.set_repetition_penalty_inner(repetition_penalty);
self
}

pub fn set_frequency_penalty(self, frequency_penalty: f32) -> Self {
self.set_frequency_penalty_inner(frequency_penalty);
self
}

pub fn set_presence_penalty(self, presence_penalty: f32) -> Self {
self.set_presence_penalty_inner(presence_penalty);
self
}

pub fn set_messages(self, messages: &[RoleScopedChatInput]) -> Self {
self.set_messages_inner(Array::from_iter(messages));
self
}
}

#[wasm_bindgen]
extern "C" {
# [wasm_bindgen (extends = :: js_sys :: Object)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub type AiTextGenerationOutput;

#[wasm_bindgen(constructor, js_class = Object)]
pub fn new() -> AiTextGenerationOutput;

#[wasm_bindgen(method, getter = "response")]
pub fn get_response(this: &AiTextGenerationOutput) -> Option<String>;

}

impl From<AiTextGenerationOutput> for Vec<u8> {
fn from(value: AiTextGenerationOutput) -> Self {
value
.get_response()
.map(|text| text.into_bytes())
.unwrap_or_default()
}
}
1 change: 1 addition & 0 deletions worker-sys/src/types/utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod typed_array;
Loading