Skip to content

Commit 4507558

Browse files
authored
refactor: [ENG-3156] Determine response format in provider class (#4858)
* add determineResponseFormat to provider classes * add err msg * fixg
1 parent 64003f9 commit 4507558

File tree

8 files changed

+54
-25
lines changed

8 files changed

+54
-25
lines changed

packages/__tests__/cost/providers/vertex.test.ts

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,6 @@ describe("VertexProvider", () => {
241241

242242
const parsed = JSON.parse(body);
243243
expect(parsed.anthropic_version).toBe("vertex-2023-10-16");
244-
expect(parsed.model).toBe("claude-3-haiku");
245244
expect(parsed.anthropic_content).toBe(true);
246245
});
247246

@@ -260,7 +259,6 @@ describe("VertexProvider", () => {
260259

261260
const parsed = JSON.parse(body);
262261
expect(parsed.anthropic_version).toBe("vertex-2023-10-16");
263-
expect(parsed.model).toBe("claude-3-sonnet");
264262
expect(parsed.should_not_appear).toBeUndefined();
265263
});
266264
});

packages/cost/models/provider-helpers.ts

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import type {
77
AuthResult,
88
RequestBodyContext,
99
RequestParams,
10+
ResponseFormat,
1011
} from "./types";
1112
import { providers, ModelProviderName } from "./providers";
1213
import { BaseProvider } from "./providers/base";
@@ -265,3 +266,17 @@ export async function buildErrorMessage(
265266

266267
return ok(await provider.buildErrorMessage(response));
267268
}
269+
270+
export function determineResponseFormat(endpoint: Endpoint): Result<ResponseFormat> {
271+
const providerResult = getProvider(endpoint.provider);
272+
if (providerResult.error) {
273+
return err(providerResult.error);
274+
}
275+
276+
const provider = providerResult.data;
277+
if (!provider) {
278+
return err(`Provider data is null for: ${endpoint.provider}`);
279+
}
280+
281+
return ok(provider.determineResponseFormat(endpoint));
282+
}

packages/cost/models/providers/anthropic.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import type {
55
Endpoint,
66
RequestBodyContext,
77
RequestParams,
8+
ResponseFormat,
89
} from "../types";
910

1011
export class AnthropicProvider extends BaseProvider {
@@ -41,4 +42,11 @@ export class AnthropicProvider extends BaseProvider {
4142
const anthropicBody = context.toAnthropic(context.parsedBody, endpoint.providerModelId);
4243
return JSON.stringify(anthropicBody);
4344
}
45+
46+
determineResponseFormat(endpoint: Endpoint): ResponseFormat {
47+
if (endpoint.author === "anthropic" || endpoint.providerModelId.includes("claude-")) {
48+
return "ANTHROPIC";
49+
}
50+
return "OPENAI";
51+
}
4452
}

packages/cost/models/providers/base.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import type {
66
Endpoint,
77
RequestParams,
88
ModelProviderConfig,
9+
ResponseFormat,
910
} from "../types";
1011
import { CacheProvider } from "../../../common/cache/provider";
1112

@@ -67,4 +68,8 @@ export abstract class BaseProvider {
6768
return `Request failed with status ${response.status}`;
6869
}
6970
}
71+
72+
determineResponseFormat(endpoint: Endpoint): ResponseFormat {
73+
return "OPENAI";
74+
}
7075
}

packages/cost/models/providers/bedrock.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import type {
1010
RequestParams,
1111
ModelProviderConfig,
1212
UserEndpointConfig,
13+
ResponseFormat,
1314
} from "../types";
1415

1516
export class BedrockProvider extends BaseProvider {
@@ -121,4 +122,11 @@ export class BedrockProvider extends BaseProvider {
121122
stream_options: undefined,
122123
});
123124
}
125+
126+
determineResponseFormat(endpoint: Endpoint): ResponseFormat {
127+
if (endpoint.author === "anthropic" || endpoint.providerModelId.includes("claude-")) {
128+
return "ANTHROPIC";
129+
}
130+
return "OPENAI";
131+
}
124132
}

packages/cost/models/providers/vertex.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import type {
55
AuthContext,
66
AuthResult,
77
RequestParams,
8+
ResponseFormat,
89
} from "../types";
910
import { getGoogleAccessToken } from "../../auth/gcpServiceAccountAuth";
1011
import { CacheProvider } from "../../../common/cache/provider";
@@ -125,4 +126,11 @@ export class VertexProvider extends BaseProvider {
125126
return `Request failed with status ${response.status}`;
126127
}
127128
}
129+
130+
determineResponseFormat(endpoint: Endpoint): ResponseFormat {
131+
if (endpoint.author === "anthropic" || endpoint.providerModelId.includes("claude-")) {
132+
return "ANTHROPIC";
133+
}
134+
return "OPENAI";
135+
}
128136
}

packages/cost/models/types.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ export interface Modality {
3838
outputs: OutputModality[];
3939
}
4040

41+
export type ResponseFormat = "ANTHROPIC" | "OPENAI";
42+
4143
export type Tokenizer =
4244
| "Claude"
4345
| "GPT"

worker/src/lib/ai-gateway/SimpleAIGateway.ts

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ import { AttemptExecutor } from "./AttemptExecutor";
1313
import { Attempt, DisallowListEntry, EscrowInfo } from "./types";
1414
import { oai2antResponse } from "../clients/llmmapper/router/oai2ant/nonStream";
1515
import { oai2antStreamResponse } from "../clients/llmmapper/router/oai2ant/stream";
16-
import { ModelProviderName } from "@helicone-package/cost/models/providers";
1716
import { RequestParams } from "@helicone-package/cost/models/types";
1817
import { SecureCacheProvider } from "../util/cache/secureCache";
18+
import { determineResponseFormat } from "@helicone-package/cost/models/provider-helpers";
1919

2020
export interface AuthContext {
2121
orgId: string;
@@ -303,22 +303,6 @@ export class SimpleAIGateway {
303303
);
304304
}
305305

306-
private determineResponseFormat(
307-
provider: ModelProviderName,
308-
modelId: string
309-
): "ANTHROPIC" | "OPENAI" {
310-
// TODO: make enum type when there's more map formats
311-
if (
312-
provider === "anthropic" ||
313-
(provider === "bedrock" && modelId.includes("claude-")) ||
314-
(provider === "vertex" && modelId.includes("claude-"))
315-
) {
316-
return "ANTHROPIC";
317-
}
318-
319-
return "OPENAI";
320-
}
321-
322306
private async mapResponse(
323307
attempt: Attempt,
324308
response: Response,
@@ -328,17 +312,18 @@ export class SimpleAIGateway {
328312
return ok(response); // do not map response
329313
}
330314

331-
const mappingType = this.determineResponseFormat(
332-
attempt.endpoint.provider,
333-
attempt.endpoint.providerModelId
334-
); // finds format of the response that we are mapping to OPENAI
315+
const mappingType = determineResponseFormat(attempt.endpoint);
316+
317+
if (isErr(mappingType)) {
318+
return err(`Failed to determine response format: ${mappingType.error}`);
319+
}
335320

336-
if (mappingType === "OPENAI") {
321+
if (mappingType.data === "OPENAI") {
337322
return ok(response); // already in OPENAI format
338323
}
339324

340325
try {
341-
if (mappingType === "ANTHROPIC") {
326+
if (mappingType.data === "ANTHROPIC") {
342327
const contentType = response.headers.get("content-type");
343328
const isStream = contentType?.includes("text/event-stream");
344329

0 commit comments

Comments
 (0)