Commit 1ae26c3

Merge pull request #777 from ai16z/shaw/refactor-image-interface
Fix: Refactor image interface and update to move llama cloud -> together provider
2 parents dadef5b + 45d3c8f commit 1ae26c3

File tree: 7 files changed (+139, -74 lines)

agent/src/index.ts (+1)

@@ -197,6 +197,7 @@ export function getTokenForProvider(
                 settings.ETERNALAI_API_KEY
             );
         case ModelProviderName.LLAMACLOUD:
+        case ModelProviderName.TOGETHER:
             return (
                 character.settings?.secrets?.LLAMACLOUD_API_KEY ||
                 settings.LLAMACLOUD_API_KEY ||
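
The new case shares the LLAMACLOUD branch, so a character configured for the Together provider still authenticates through the existing LLAMACLOUD_API_KEY secret or environment setting. A minimal standalone sketch of that fallthrough; the helper name resolveKey and its parameters are illustrative, not part of the commit:

// Minimal sketch of the case fallthrough above, not the actual
// getTokenForProvider implementation.
function resolveKey(
    provider: "llama_cloud" | "together",
    secrets: Record<string, string | undefined>
): string | undefined {
    switch (provider) {
        case "llama_cloud":
        case "together":
            // Both providers resolve through the same LLAMACLOUD_API_KEY secret,
            // which is what the added case relies on.
            return secrets.LLAMACLOUD_API_KEY ?? process.env.LLAMACLOUD_API_KEY;
        default:
            return undefined;
    }
}

// resolveKey("together", { LLAMACLOUD_API_KEY: "tk-example" }) -> "tk-example"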

packages/core/src/generation.ts (+30, -23)

@@ -78,17 +78,25 @@ export async function generateText({

     // if runtime.getSetting("LLAMACLOUD_MODEL_LARGE") is true and modelProvider is LLAMACLOUD, then use the large model
     if (
-        runtime.getSetting("LLAMACLOUD_MODEL_LARGE") &&
-        provider === ModelProviderName.LLAMACLOUD
+        (runtime.getSetting("LLAMACLOUD_MODEL_LARGE") &&
+            provider === ModelProviderName.LLAMACLOUD) ||
+        (runtime.getSetting("TOGETHER_MODEL_LARGE") &&
+            provider === ModelProviderName.TOGETHER)
     ) {
-        model = runtime.getSetting("LLAMACLOUD_MODEL_LARGE");
+        model =
+            runtime.getSetting("LLAMACLOUD_MODEL_LARGE") ||
+            runtime.getSetting("TOGETHER_MODEL_LARGE");
     }

     if (
-        runtime.getSetting("LLAMACLOUD_MODEL_SMALL") &&
-        provider === ModelProviderName.LLAMACLOUD
+        (runtime.getSetting("LLAMACLOUD_MODEL_SMALL") &&
+            provider === ModelProviderName.LLAMACLOUD) ||
+        (runtime.getSetting("TOGETHER_MODEL_SMALL") &&
+            provider === ModelProviderName.TOGETHER)
     ) {
-        model = runtime.getSetting("LLAMACLOUD_MODEL_SMALL");
+        model =
+            runtime.getSetting("LLAMACLOUD_MODEL_SMALL") ||
+            runtime.getSetting("TOGETHER_MODEL_SMALL");
     }

     elizaLogger.info("Selected model:", model);

@@ -120,7 +128,8 @@ export async function generateText({
         case ModelProviderName.ETERNALAI:
         case ModelProviderName.ALI_BAILIAN:
         case ModelProviderName.VOLENGINE:
-        case ModelProviderName.LLAMACLOUD: {
+        case ModelProviderName.LLAMACLOUD:
+        case ModelProviderName.TOGETHER: {
             elizaLogger.debug("Initializing OpenAI model.");
             const openai = createOpenAI({ apiKey, baseURL: endpoint });

@@ -806,12 +815,6 @@ export const generateImage = async (
     data?: string[];
     error?: any;
 }> => {
-    const { prompt, width, height } = data;
-    let { count } = data;
-    if (!count) {
-        count = 1;
-    }
-
     const model = getModel(runtime.imageModelProvider, ModelClass.IMAGE);
     const modelSettings = models[runtime.imageModelProvider].imageSettings;

@@ -866,16 +869,19 @@ export const generateImage = async (
         const imageURL = await response.json();
         return { success: true, data: [imageURL] };
     } else if (
+        runtime.imageModelProvider === ModelProviderName.TOGETHER ||
+        // for backwards compat
         runtime.imageModelProvider === ModelProviderName.LLAMACLOUD
     ) {
         const together = new Together({ apiKey: apiKey as string });
+        // Fix: steps 4 is for schnell; 28 is for dev.
         const response = await together.images.create({
             model: "black-forest-labs/FLUX.1-schnell",
-            prompt,
-            width,
-            height,
+            prompt: data.prompt,
+            width: data.width,
+            height: data.height,
             steps: modelSettings?.steps ?? 4,
-            n: count,
+            n: data.count,
         });
         const urls: string[] = [];
         for (let i = 0; i < response.data.length; i++) {

@@ -902,11 +908,11 @@ export const generateImage = async (

         // Prepare the input parameters according to their schema
         const input = {
-            prompt: prompt,
+            prompt: data.prompt,
             image_size: "square" as const,
             num_inference_steps: modelSettings?.steps ?? 50,
-            guidance_scale: 3.5,
-            num_images: count,
+            guidance_scale: data.guidanceScale || 3.5,
+            num_images: data.count,
             enable_safety_checker: true,
             output_format: "png" as const,
             seed: data.seed ?? 6252023,

@@ -945,7 +951,7 @@ export const generateImage = async (
         const base64s = await Promise.all(base64Promises);
         return { success: true, data: base64s };
     } else {
-        let targetSize = `${width}x${height}`;
+        let targetSize = `${data.width}x${data.height}`;
         if (
             targetSize !== "1024x1024" &&
             targetSize !== "1792x1024" &&

@@ -956,9 +962,9 @@ export const generateImage = async (
         const openai = new OpenAI({ apiKey: apiKey as string });
         const response = await openai.images.generate({
             model,
-            prompt,
+            prompt: data.prompt,
             size: targetSize as "1024x1024" | "1792x1024" | "1024x1792",
-            n: count,
+            n: data.count,
             response_format: "b64_json",
         });
         const base64s = response.data.map(

@@ -1157,6 +1163,7 @@ export async function handleProvider(
         case ModelProviderName.ALI_BAILIAN:
         case ModelProviderName.VOLENGINE:
         case ModelProviderName.LLAMACLOUD:
+        case ModelProviderName.TOGETHER:
             return await handleOpenAI(options);
         case ModelProviderName.ANTHROPIC:
             return await handleAnthropic(options);
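
Both override blocks in the first hunk follow the same pattern: when the active provider matches and the corresponding *_MODEL_LARGE or *_MODEL_SMALL setting is present, that setting replaces the model resolved from the provider table. A condensed sketch of the large-model case, assuming getSetting returns a string or a falsy value; the helper name pickLargeModel is illustrative and not part of the commit:

// Illustrative condensation of the override logic in generateText; not the
// exact implementation from the diff.
type Provider = "llama_cloud" | "together";

function pickLargeModel(
    provider: Provider,
    getSetting: (key: string) => string | undefined,
    defaultModel: string
): string {
    const override =
        (provider === "llama_cloud" && getSetting("LLAMACLOUD_MODEL_LARGE")) ||
        (provider === "together" && getSetting("TOGETHER_MODEL_LARGE"));
    // Fall back to the provider's default large model when no override is set.
    return override || defaultModel;
}

// pickLargeModel("together", (k) => (k === "TOGETHER_MODEL_LARGE" ? "my-model" : undefined), "default")
// -> "my-model"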

packages/core/src/models.ts (+21)

@@ -128,6 +128,27 @@ export const models: Models = {
             [ModelClass.IMAGE]: "black-forest-labs/FLUX.1-schnell",
         },
     },
+    [ModelProviderName.TOGETHER]: {
+        settings: {
+            stop: [],
+            maxInputTokens: 128000,
+            maxOutputTokens: 8192,
+            repetition_penalty: 0.4,
+            temperature: 0.7,
+        },
+        imageSettings: {
+            steps: 4,
+        },
+        endpoint: "https://api.together.ai/v1",
+        model: {
+            [ModelClass.SMALL]: "meta-llama/Llama-3.2-3B-Instruct-Turbo",
+            [ModelClass.MEDIUM]: "meta-llama-3.1-8b-instruct",
+            [ModelClass.LARGE]: "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
+            [ModelClass.EMBEDDING]:
+                "togethercomputer/m2-bert-80M-32k-retrieval",
+            [ModelClass.IMAGE]: "black-forest-labs/FLUX.1-schnell",
+        },
+    },
     [ModelProviderName.LLAMALOCAL]: {
         settings: {
             stop: ["<|eot_id|>", "<|eom_id|>"],

packages/core/src/tests/generation.test.ts (+56, -46)

@@ -1,7 +1,12 @@
 import { describe, expect, it, vi, beforeEach } from "vitest";
 import { ModelProviderName, IAgentRuntime } from "../types";
 import { models } from "../models";
-import { generateText, generateTrueOrFalse, splitChunks, trimTokens } from "../generation";
+import {
+    generateText,
+    generateTrueOrFalse,
+    splitChunks,
+    trimTokens,
+} from "../generation";
 import type { TiktokenModel } from "js-tiktoken";

 // Mock the elizaLogger

@@ -42,6 +47,8 @@ describe("Generation", () => {
         getSetting: vi.fn().mockImplementation((key: string) => {
             if (key === "LLAMACLOUD_MODEL_LARGE") return false;
             if (key === "LLAMACLOUD_MODEL_SMALL") return false;
+            if (key === "TOGETHER_MODEL_LARGE") return false;
+            if (key === "TOGETHER_MODEL_SMALL") return false;
             return undefined;
         }),
     } as unknown as IAgentRuntime;

@@ -122,53 +129,56 @@ describe("Generation", () => {
         });
     });

-    describe("trimTokens", () => {
-        const model = "gpt-4" as TiktokenModel;
-
-        it("should return empty string for empty input", () => {
-            const result = trimTokens("", 100, model);
-            expect(result).toBe("");
-        });
-
-        it("should throw error for negative maxTokens", () => {
-            expect(() => trimTokens("test", -1, model)).toThrow("maxTokens must be positive");
-        });
-
-        it("should return unchanged text if within token limit", () => {
-            const shortText = "This is a short text";
-            const result = trimTokens(shortText, 10, model);
-            expect(result).toBe(shortText);
-        });
-
-        it("should truncate text to specified token limit", () => {
-            // Using a longer text that we know will exceed the token limit
-            const longText = "This is a much longer text that will definitely exceed our very small token limit and need to be truncated to fit within the specified constraints."
-            const result = trimTokens(longText, 5, model);
-
-            // The exact result will depend on the tokenizer, but we can verify:
-            // 1. Result is shorter than original
-            expect(result.length).toBeLessThan(longText.length);
-            // 2. Result is not empty
-            expect(result.length).toBeGreaterThan(0);
-            // 3. Result is a proper substring of the original text
-            expect(longText.includes(result)).toBe(true);
-        });
-
-        it("should handle non-ASCII characters", () => {
-            const unicodeText = "Hello 👋 World 🌍";
-            const result = trimTokens(unicodeText, 5, model);
-            expect(result.length).toBeGreaterThan(0);
-        });
-
-        it("should handle multiline text", () => {
-            const multilineText = `Line 1
+    describe("trimTokens", () => {
+        const model = "gpt-4" as TiktokenModel;
+
+        it("should return empty string for empty input", () => {
+            const result = trimTokens("", 100, model);
+            expect(result).toBe("");
+        });
+
+        it("should throw error for negative maxTokens", () => {
+            expect(() => trimTokens("test", -1, model)).toThrow(
+                "maxTokens must be positive"
+            );
+        });
+
+        it("should return unchanged text if within token limit", () => {
+            const shortText = "This is a short text";
+            const result = trimTokens(shortText, 10, model);
+            expect(result).toBe(shortText);
+        });
+
+        it("should truncate text to specified token limit", () => {
+            // Using a longer text that we know will exceed the token limit
+            const longText =
+                "This is a much longer text that will definitely exceed our very small token limit and need to be truncated to fit within the specified constraints.";
+            const result = trimTokens(longText, 5, model);
+
+            // The exact result will depend on the tokenizer, but we can verify:
+            // 1. Result is shorter than original
+            expect(result.length).toBeLessThan(longText.length);
+            // 2. Result is not empty
+            expect(result.length).toBeGreaterThan(0);
+            // 3. Result is a proper substring of the original text
+            expect(longText.includes(result)).toBe(true);
+        });
+
+        it("should handle non-ASCII characters", () => {
+            const unicodeText = "Hello 👋 World 🌍";
+            const result = trimTokens(unicodeText, 5, model);
+            expect(result.length).toBeGreaterThan(0);
+        });
+
+        it("should handle multiline text", () => {
+            const multilineText = `Line 1
 Line 2
 Line 3
 Line 4
 Line 5`;
-            const result = trimTokens(multilineText, 5, model);
-            expect(result.length).toBeGreaterThan(0);
-            expect(result.length).toBeLessThan(multilineText.length);
-        });
-    });
+            const result = trimTokens(multilineText, 5, model);
+            expect(result.length).toBeGreaterThan(0);
+            expect(result.length).toBeLessThan(multilineText.length);
+        });
+    });
 });

packages/core/src/types.ts (+2)

@@ -192,6 +192,7 @@ export type Models = {
     [ModelProviderName.GROK]: Model;
     [ModelProviderName.GROQ]: Model;
     [ModelProviderName.LLAMACLOUD]: Model;
+    [ModelProviderName.TOGETHER]: Model;
     [ModelProviderName.LLAMALOCAL]: Model;
     [ModelProviderName.GOOGLE]: Model;
     [ModelProviderName.CLAUDE_VERTEX]: Model;

@@ -216,6 +217,7 @@ export enum ModelProviderName {
     GROK = "grok",
     GROQ = "groq",
     LLAMACLOUD = "llama_cloud",
+    TOGETHER = "together",
     LLAMALOCAL = "llama_local",
     GOOGLE = "google",
     CLAUDE_VERTEX = "claude_vertex",
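
Because the enum member serializes to the string "together", a character that wants this provider selects it by that string in its character file, or by the enum member in TypeScript. A hypothetical, abbreviated character object, not part of the commit; field names follow the pattern already used for "llama_cloud":

// Hypothetical character snippet showing the new enum value in use; only the
// fields relevant to provider selection are shown, and the secret value is a
// placeholder.
import { ModelProviderName } from "./types";

export const exampleCharacter = {
    name: "ExampleAgent",
    modelProvider: ModelProviderName.TOGETHER, // serializes to "together"
    settings: {
        secrets: {
            // Token resolution still reads LLAMACLOUD_API_KEY (see agent/src/index.ts above).
            LLAMACLOUD_API_KEY: "YOUR_TOGETHER_API_KEY",
        },
    },
};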

packages/plugin-image-generation/src/index.ts (+29, -5)

@@ -11,7 +11,7 @@ import { generateImage } from "@ai16z/eliza";

 import fs from "fs";
 import path from "path";
-import { validateImageGenConfig } from "./enviroment";
+import { validateImageGenConfig } from "./environment";

 export function saveBase64Image(base64Data: string, filename: string): string {
     // Create generatedImages directory if it doesn't exist

@@ -97,7 +97,17 @@ const imageGeneration: Action = {
         runtime: IAgentRuntime,
         message: Memory,
         state: State,
-        options: any,
+        options: {
+            width?: number;
+            height?: number;
+            count?: number;
+            negativePrompt?: string;
+            numIterations?: number;
+            guidanceScale?: number;
+            seed?: number;
+            modelId?: string;
+            jobId?: string;
+        },
         callback: HandlerCallback
     ) => {
         elizaLogger.log("Composing state for message:", message);

@@ -116,9 +126,23 @@ const imageGeneration: Action = {
         const images = await generateImage(
             {
                 prompt: imagePrompt,
-                width: 1024,
-                height: 1024,
-                count: 1,
+                width: options.width || 1024,
+                height: options.height || 1024,
+                ...(options.count != null ? { count: options.count || 1 } : {}),
+                ...(options.negativePrompt != null
+                    ? { negativePrompt: options.negativePrompt }
+                    : {}),
+                ...(options.numIterations != null
+                    ? { numIterations: options.numIterations }
+                    : {}),
+                ...(options.guidanceScale != null
+                    ? { guidanceScale: options.guidanceScale }
+                    : {}),
+                ...(options.seed != null ? { seed: options.seed } : {}),
+                ...(options.modelId != null
+                    ? { modelId: options.modelId }
+                    : {}),
+                ...(options.jobId != null ? { jobId: options.jobId } : {}),
             },
             runtime
         );
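
The conditional spreads forward an option only when the caller actually supplied it, so generateImage keeps applying its own defaults for anything omitted. A small standalone illustration of that pattern; the names here are generic and not tied to the plugin:

// Generic demonstration of the "spread only when present" pattern used above.
function buildParams(options: { count?: number; seed?: number }) {
    return {
        width: 1024, // always present, with a default
        ...(options.count != null ? { count: options.count } : {}),
        ...(options.seed != null ? { seed: options.seed } : {}),
    };
}

console.log(buildParams({}));           // { width: 1024 }
console.log(buildParams({ count: 2 })); // { width: 1024, count: 2 }
console.log(buildParams({ seed: 42 })); // { width: 1024, seed: 42 }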
