diff --git a/packages/vertex-sdk/examples/web.ts b/packages/vertex-sdk/examples/web.ts new file mode 100644 index 00000000..2060a147 --- /dev/null +++ b/packages/vertex-sdk/examples/web.ts @@ -0,0 +1,31 @@ +#!/usr/bin/env -S npm run tsn -T + +import { AnthropicVertexWeb } from '@anthropic-ai/vertex-sdk/web'; + +// Reads from the `CLOUD_ML_REGION` & `ANTHROPIC_VERTEX_PROJECT_ID` +// environment variables. +const client = new AnthropicVertexWeb({ + region: '', + projectId: '', + clientEmail: '', + privateKey: '', +}); + +async function main() { + const result = await client.messages.create({ + messages: [ + { + role: 'user', + content: 'Hello!', + }, + ], + model: 'claude-3-sonnet@20240229', + max_tokens: 300, + }); + console.log(JSON.stringify(result, null, 2)); +} + +main().catch((err) => { + console.error(err); + process.exit(1); +}); diff --git a/packages/vertex-sdk/package.json b/packages/vertex-sdk/package.json index 6e3c3c53..68d6e4a9 100644 --- a/packages/vertex-sdk/package.json +++ b/packages/vertex-sdk/package.json @@ -23,7 +23,8 @@ }, "dependencies": { "@anthropic-ai/sdk": "file:../../dist/", - "google-auth-library": "^9.4.2" + "google-auth-library": "^9.4.2", + "jose": "^5.7.0" }, "devDependencies": { "@types/jest": "^29.4.0", @@ -54,6 +55,14 @@ "types": "./dist/index.d.mts", "default": "./dist/index.mjs" }, + "./web": { + "require": { + "types": "./dist/web.d.ts", + "default": "./dist/web.js" + }, + "types": "./dist/web.d.ts", + "default": "./dist/web.mjs" + }, "./*.mjs": { "types": "./dist/*.d.ts", "default": "./dist/*.mjs" diff --git a/packages/vertex-sdk/src/index.ts b/packages/vertex-sdk/src/index.ts index 2853ecec..5c81374d 100644 --- a/packages/vertex-sdk/src/index.ts +++ b/packages/vertex-sdk/src/index.ts @@ -1 +1,2 @@ export { AnthropicVertex } from './client'; +export { AnthropicVertexWeb } from './web'; diff --git a/packages/vertex-sdk/src/utils/authenticate.ts b/packages/vertex-sdk/src/utils/authenticate.ts new file mode 100644 index 00000000..ff3183df --- /dev/null +++ b/packages/vertex-sdk/src/utils/authenticate.ts @@ -0,0 +1,62 @@ +import { SignJWT, importPKCS8 } from 'jose'; + +type Token = { + access_token: string; + expires_in: number; + token_type: string; +}; + +type TokenWithExpiration = Token & { + expires_at: number; +}; + +let token: TokenWithExpiration | null = null; + +async function createToken(options: { clientEmail: string; privateKey: string }) { + const rawPrivateKey = options.privateKey.replace(/\\n/g, '\n'); + const privateKey = await importPKCS8(rawPrivateKey, 'RS256'); + + const payload = { + iss: options.clientEmail, + scope: 'https://www.googleapis.com/auth/cloud-platform', + aud: 'https://www.googleapis.com/oauth2/v4/token', + exp: Math.floor(Date.now() / 1000) + 60 * 60, + iat: Math.floor(Date.now() / 1000), + }; + const token = await new SignJWT(payload) + .setProtectedHeader({ alg: 'RS256' }) + .setIssuedAt() + .setIssuer(options.clientEmail) + .setAudience('https://www.googleapis.com/oauth2/v4/token') + .setExpirationTime('1h') + .sign(privateKey); + + // Form data for the token request + const form = { + grant_type: 'urn:ietf:params:oauth:grant-type:jwt-bearer', + assertion: token, + }; + + // Make the token request + const tokenResponse = await fetch('https://www.googleapis.com/oauth2/v4/token', { + method: 'POST', + body: JSON.stringify(form), + headers: { 'Content-Type': 'application/json' }, + }); + + const json = (await tokenResponse.json()) as Token; + + return { + ...json, + expires_at: Math.floor(Date.now() / 1000) + json.expires_in, + }; +} + +export async function authenticate(options: { clientEmail: string; privateKey: string }): Promise { + if (token === null) { + token = await createToken(options); + } else if (token.expires_at < Math.floor(Date.now() / 1000)) { + token = await createToken(options); + } + return token; +} diff --git a/packages/vertex-sdk/src/web.ts b/packages/vertex-sdk/src/web.ts new file mode 100644 index 00000000..38e61f2c --- /dev/null +++ b/packages/vertex-sdk/src/web.ts @@ -0,0 +1,129 @@ +import * as Core from '@anthropic-ai/sdk/core'; +import * as Resources from '@anthropic-ai/sdk/resources/index'; +import * as API from '@anthropic-ai/sdk/index'; +import { type RequestInit } from '@anthropic-ai/sdk/_shims/index'; +import { authenticate } from './utils/authenticate'; + +const DEFAULT_VERSION = 'vertex-2023-10-16'; + +export type ClientOptions = Omit & { + region?: string | null | undefined; + projectId?: string | null | undefined; + accessToken?: string | null | undefined; + clientEmail?: string | null | undefined; + privateKey?: string | null | undefined; +}; + +export class AnthropicVertexWeb extends Core.APIClient { + region: string; + projectId: string | null; + accessToken: string | null; + + private _options: ClientOptions; + + constructor({ + baseURL = Core.readEnv('ANTHROPIC_VERTEX_BASE_URL'), + region = Core.readEnv('CLOUD_ML_REGION') ?? null, + projectId = Core.readEnv('ANTHROPIC_VERTEX_PROJECT_ID') ?? null, + accessToken = Core.readEnv('ANTHROPIC_VERTEX_ACCESS_TOKEN') ?? null, + clientEmail = Core.readEnv('ANTHROPIC_VERTEX_CLIENT_EMAIL') ?? null, + privateKey = Core.readEnv('ANTHROPIC_VERTEX_PRIVATE_KEY') ?? null, + ...opts + }: ClientOptions = {}) { + if (!region) { + throw new Error( + 'No region was given. The client should be instantiated with the `region` option or the `CLOUD_ML_REGION` environment variable should be set.', + ); + } + + const options: ClientOptions = { + ...opts, + baseURL: baseURL || `https://${region}-aiplatform.googleapis.com/v1`, + clientEmail, + privateKey, + }; + + super({ + baseURL: options.baseURL!, + timeout: options.timeout ?? 600000 /* 10 minutes */, + httpAgent: options.httpAgent, + maxRetries: options.maxRetries, + fetch: options.fetch, + }); + this._options = options; + + this.region = region; + this.projectId = projectId; + this.accessToken = accessToken; + } + + messages: Resources.Messages = new Resources.Messages(this); + + protected override defaultQuery(): Core.DefaultQuery | undefined { + return this._options.defaultQuery; + } + + protected override defaultHeaders(opts: Core.FinalRequestOptions): Core.Headers { + return { + ...super.defaultHeaders(opts), + ...this._options.defaultHeaders, + }; + } + + protected override async prepareOptions(options: Core.FinalRequestOptions): Promise { + if (!this.accessToken) { + if (!this._options.clientEmail || !this._options.privateKey) { + throw new Error( + 'No clientEmail or privateKey was provided. Set it in the constructor or use the ANTHROPIC_VERTEX_CLIENT_EMAIL and ANTHROPIC_VERTEX_PRIVATE_KEY environment variables.', + ); + } + this.accessToken = ( + await authenticate({ + clientEmail: this._options.clientEmail, + privateKey: this._options.privateKey, + }) + ).access_token; + } + + options.headers = { + ...options.headers, + Authorization: `Bearer ${this.accessToken}`, + 'x-goog-user-project': this.projectId, + }; + } + + override buildRequest(options: Core.FinalRequestOptions): { + req: RequestInit; + url: string; + timeout: number; + } { + if (Core.isObj(options.body)) { + if (!options.body['anthropic_version']) { + options.body['anthropic_version'] = DEFAULT_VERSION; + } + } + + if (options.path === '/v1/messages' && options.method === 'post') { + if (!this.projectId) { + throw new Error( + 'No projectId was given and it could not be resolved from credentials. The client should be instantiated with the `projectId` option or the `ANTHROPIC_VERTEX_PROJECT_ID` environment variable should be set.', + ); + } + + if (!Core.isObj(options.body)) { + throw new Error('Expected request body to be an object for post /v1/messages'); + } + + const model = options.body['model']; + options.body['model'] = undefined; + + const stream = options.body['stream'] ?? false; + + const specifier = stream ? 'streamRawPredict' : 'rawPredict'; + + options.path = `/projects/${this.projectId}/locations/${this.region}/publishers/anthropic/models/${model}:${specifier}`; + } + + return super.buildRequest(options); + } +} diff --git a/packages/vertex-sdk/yarn.lock b/packages/vertex-sdk/yarn.lock index 23194603..7158dd0f 100644 --- a/packages/vertex-sdk/yarn.lock +++ b/packages/vertex-sdk/yarn.lock @@ -16,9 +16,7 @@ "@jridgewell/trace-mapping" "^0.3.9" "@anthropic-ai/sdk@file:../../dist": - # x-release-please-start-version version "0.27.0" - # x-release-please-end-version dependencies: "@types/node" "^18.11.18" "@types/node-fetch" "^2.6.4" @@ -27,7 +25,6 @@ form-data-encoder "1.7.2" formdata-node "^4.3.2" node-fetch "^2.6.7" - web-streams-polyfill "^3.2.1" "@babel/code-frame@^7.0.0", "@babel/code-frame@^7.12.13", "@babel/code-frame@^7.22.13", "@babel/code-frame@^7.23.5": version "7.23.5" @@ -2365,6 +2362,11 @@ jest@^29.4.0: import-local "^3.0.2" jest-cli "^29.7.0" +jose@^5.7.0: + version "5.7.0" + resolved "https://registry.yarnpkg.com/jose/-/jose-5.7.0.tgz#5c3a6eb811235e692e1af4891904b7b91b204f57" + integrity sha512-3P9qfTYDVnNn642LCAqIKbTGb9a1TBxZ9ti5zEVEr48aDdflgRjhspWFb6WM4PzAfFbGMJYC4+803v8riCRAKw== + js-tokens@^4.0.0: version "4.0.0" resolved "https://registry.yarnpkg.com/js-tokens/-/js-tokens-4.0.0.tgz#19203fb59991df98e3a287050d4647cdeaf32499" @@ -3257,11 +3259,6 @@ web-streams-polyfill@4.0.0-beta.3: resolved "https://registry.yarnpkg.com/web-streams-polyfill/-/web-streams-polyfill-4.0.0-beta.3.tgz#2898486b74f5156095e473efe989dcf185047a38" integrity sha512-QW95TCTaHmsYfHDybGMwO5IJIM93I/6vTRk+daHTWFPhwh+C8Cg7j7XyKrwrj8Ib6vYXe0ocYNrmzY4xAAN6ug== -web-streams-polyfill@^3.2.1: - version "3.3.2" - resolved "https://registry.yarnpkg.com/web-streams-polyfill/-/web-streams-polyfill-3.3.2.tgz#32e26522e05128203a7de59519be3c648004343b" - integrity sha512-3pRGuxRF5gpuZc0W+EpwQRmCD7gRqcDOMt688KmdlDAgAyaB1XlN0zq2njfDNm44XVdIouE7pZ6GzbdyH47uIQ== - webidl-conversions@^3.0.0: version "3.0.1" resolved "https://registry.yarnpkg.com/webidl-conversions/-/webidl-conversions-3.0.1.tgz#24534275e2a7bc6be7bc86611cc16ae0a5654871"