Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions packages/inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ Your access token should be kept private. If you need to protect it in front-end
You can send inference requests to third-party providers with the inference client.

Currently, we support the following providers:
- [AlphaNeural](https://alphaneural.ai)
- [Fal.ai](https://fal.ai)
- [Featherless AI](https://featherless.ai)
- [Fireworks AI](https://fireworks.ai)
Expand Down Expand Up @@ -88,6 +89,7 @@ When authenticated with a Hugging Face access token, the request is routed throu
When authenticated with a third-party provider key, the request is made directly against that provider's inference API.

Only a subset of models are supported when requesting third-party providers. You can check the list of supported models per pipeline tasks here:
- [AlphaNeural supported models](https://huggingface.co/api/partners/alphaneural/models)
- [Fal.ai supported models](https://huggingface.co/api/partners/fal-ai/models)
- [Featherless AI supported models](https://huggingface.co/api/partners/featherless-ai/models)
- [Fireworks AI supported models](https://huggingface.co/api/partners/fireworks-ai/models)
Expand Down
5 changes: 5 additions & 0 deletions packages/inference/src/lib/getProviderHelper.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import * as Alphaneural from "../providers/alphaneural.js";
import * as Baseten from "../providers/baseten.js";
import * as Clarifai from "../providers/clarifai.js";
import * as BlackForestLabs from "../providers/black-forest-labs.js";
Expand Down Expand Up @@ -60,6 +61,10 @@ import type { InferenceProvider, InferenceProviderOrPolicy, InferenceTask } from
import { InferenceClientInputError } from "../errors.js";

export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask, TaskProviderHelper>>> = {
alphaneural: {
conversational: new Alphaneural.AlphaneuralConversationalTask(),
"text-generation": new Alphaneural.AlphaneuralTextGenerationTask(),
},
baseten: {
conversational: new Baseten.BasetenConversationalTask(),
},
Expand Down
73 changes: 73 additions & 0 deletions packages/inference/src/providers/alphaneural.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/**
* See the registered mapping of HF model ID => AlphaNeural model ID here:
*
* https://huggingface.co/api/partners/alphaneural/models
*
* This is a publicly available mapping.
*
* If you want to try to run inference for a new model locally before it's registered on huggingface.co,
* you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
*
* - If you work at AlphaNeural and want to update this mapping, please use the model mapping API we provide on huggingface.co
* - If you're a community member and want to add a new supported HF model to AlphaNeural, please open an issue on the present repo
* and we will tag AlphaNeural team members.
*
* Thanks!
*/
import type { TextGenerationInput, TextGenerationOutput, TextGenerationOutputFinishReason } from "@huggingface/tasks";
import type { BodyParams } from "../types.js";
import { BaseConversationalTask, BaseTextGenerationTask } from "./providerHelper.js";
import { omit } from "../utils/omit.js";
import { InferenceClientProviderOutputError } from "../errors.js";

const ALPHANEURAL_API_BASE_URL = "https://proxy.alfnrl.io";

interface AlphaneuralTextCompletionOutput {
choices: Array<{
text: string;
finish_reason: TextGenerationOutputFinishReason;
index: number;
}>;
model: string;
}

export class AlphaneuralConversationalTask extends BaseConversationalTask {
constructor() {
super("alphaneural", ALPHANEURAL_API_BASE_URL);
}
}

export class AlphaneuralTextGenerationTask extends BaseTextGenerationTask {
constructor() {
super("alphaneural", ALPHANEURAL_API_BASE_URL);
}

override preparePayload(params: BodyParams<TextGenerationInput>): Record<string, unknown> {
return {
model: params.model,
...omit(params.args, ["inputs", "parameters"]),
...(params.args.parameters
? {
max_tokens: params.args.parameters.max_new_tokens,
...omit(params.args.parameters, "max_new_tokens"),
}
: undefined),
prompt: params.args.inputs,
};
}

override async getResponse(response: AlphaneuralTextCompletionOutput): Promise<TextGenerationOutput> {
if (
typeof response === "object" &&
"choices" in response &&
Array.isArray(response?.choices) &&
response.choices.length > 0 &&
typeof response.choices[0]?.text === "string"
) {
return {
generated_text: response.choices[0].text,
};
}
throw new InferenceClientProviderOutputError("Received malformed response from AlphaNeural text generation API");
}
}
1 change: 1 addition & 0 deletions packages/inference/src/providers/consts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
* Example:
* "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
*/
alphaneural: {},
baseten: {},
"black-forest-labs": {},
cerebras: {},
Expand Down
2 changes: 2 additions & 0 deletions packages/inference/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ export interface Options {
export type InferenceTask = Exclude<PipelineType, "other"> | "conversational";

export const INFERENCE_PROVIDERS = [
"alphaneural",
"baseten",
"black-forest-labs",
"cerebras",
Expand Down Expand Up @@ -82,6 +83,7 @@ export type InferenceProviderOrPolicy = (typeof PROVIDERS_OR_POLICIES)[number];
* Whenever possible, InferenceProvider should == org namespace
*/
export const PROVIDERS_HUB_ORGS: Record<InferenceProvider, string> = {
alphaneural: "AlphaNeural",
baseten: "baseten",
"black-forest-labs": "black-forest-labs",
cerebras: "cerebras",
Expand Down
78 changes: 77 additions & 1 deletion packages/inference/test/InferenceClient.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ if (!env.HF_TOKEN) {
console.warn("Set HF_TOKEN in the env to run the tests for better rate limits");
}

describe.skip("InferenceClient", () => {
describe.skip("InferenceClient", () => {
// Individual tests can be ran without providing an api key, however running all tests without an api key will result in rate limiting error.

describe("backward compatibility", () => {
Expand Down Expand Up @@ -1641,6 +1641,82 @@ describe.skip("InferenceClient", () => {
});
});

describe.concurrent(
"AlphaNeural",
() => {
const client = new InferenceClient(env.HF_ALPHANEURAL_KEY ?? "dummy");

HARDCODED_MODEL_INFERENCE_MAPPING["alphaneural"] = {
"qwen/qwen3": {
provider: "alphaneural",
hfModelId: "qwen/qwen3",
providerId: "qwen3",
status: "live",
task: "conversational",
},
"Qwen/Qwen3-8B": {
provider: "alphaneural",
hfModelId: "Qwen/Qwen3-8B",
providerId: "qwen3",
status: "live",
task: "text-generation",
},
};

it("chatCompletion", async () => {
const res = await client.chatCompletion({
model: "qwen/qwen3",
provider: "alphaneural",
messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
});
if (res.choices && res.choices.length > 0) {
const completion = res.choices[0].message?.content;
expect(completion).toContain("two");
}
});

it("chatCompletion stream", async () => {
const stream = client.chatCompletionStream({
model: "qwen/qwen3",
provider: "alphaneural",
messages: [{ role: "user", content: "Say 'this is a test'" }],
stream: true,
}) as AsyncGenerator<ChatCompletionStreamOutput>;

let fullResponse = "";
for await (const chunk of stream) {
if (chunk.choices && chunk.choices.length > 0) {
const content = chunk.choices[0].delta?.content;
if (content) {
fullResponse += content;
}
}
}

// Verify we got a meaningful response
expect(fullResponse).toBeTruthy();
expect(fullResponse.length).toBeGreaterThan(0);
});

it("textGeneration", async () => {
const res = await textGeneration({
accessToken: env.HF_ALPHANEURAL_KEY ?? "dummy",
model: "Qwen/Qwen3-8B",
provider: "alphaneural",
inputs: "The capital of France is",
parameters: {
temperature: 0,
max_tokens: 10,
},
});
expect(res).toBeDefined();
expect(res.generated_text).toBeDefined();
expect(typeof res.generated_text).toBe("string");
});
},
TIMEOUT
);

describe.concurrent(
"Fireworks",
() => {
Expand Down