This repository was archived by the owner on Oct 10, 2024. It is now read-only.

Commit ac6e138

Merge pull request #79 from mistralai/release/v0.4.0
release 0.4.0: add support for completion
2 parents 6d2639e + 7d8cf44 commit ac6e138

File tree

7 files changed: +219 -14 lines
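This release adds fill-in-the-middle (FIM) completion support to the client. As a minimal sketch of how the new API can be called once 0.4.0 is installed (model name taken from the tests below, prompt/suffix values from the JSDoc in src/client.js; the response shape is assumed to mirror a chat response, and maxTokens is an arbitrary illustrative value):

import MistralClient from '@mistralai/mistralai';

const client = new MistralClient(process.env.MISTRAL_API_KEY);

// Non-streaming fill-in-the-middle completion: the model generates text
// that connects `prompt` to `suffix`.
const response = await client.completion({
  model: 'mistral-small-latest',
  prompt: 'def fibonacci(n: int):',
  suffix: 'n = int(input(\'Enter a number: \'))',
  maxTokens: 64,
});

// Assumed to follow the same shape as a chat response (ChatCompletionResponse).
console.log(response.choices[0].message.content);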

examples/json_format.js
+1 -1

@@ -5,7 +5,7 @@ const apiKey = process.env.MISTRAL_API_KEY;
 const client = new MistralClient(apiKey);
 
 const chatResponse = await client.chat({
-  model: 'mistral-large',
+  model: 'mistral-large-latest',
   messages: [{role: 'user', content: 'What is the best French cheese?'}],
   responseFormat: {type: 'json_object'},
 });
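Since this example requests responseFormat {type: 'json_object'}, the reply content is a JSON string. A small sketch of how it is typically consumed, assuming the usual chat response shape:

// In json_object mode the model is constrained to emit valid JSON,
// so the message content can be parsed directly.
const answer = JSON.parse(chatResponse.choices[0].message.content);
console.log(answer);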

examples/package-lock.json
+1 -1

(Generated file; diff not rendered.)

package.json
+1 -1

@@ -1,6 +1,6 @@
 {
   "name": "@mistralai/mistralai",
-  "version": "0.3.0",
+  "version": "0.4.0",
   "description": "",
   "author": "[email protected]",
   "license": "ISC",

src/client.d.ts
+22

@@ -141,6 +141,17 @@ declare module "@mistralai/mistralai" {
     responseFormat?: ResponseFormat;
   }
 
+  export interface CompletionRequest {
+    model: string;
+    prompt: string;
+    suffix?: string;
+    temperature?: number;
+    maxTokens?: number;
+    topP?: number;
+    randomSeed?: number;
+    stop?: string | string[];
+  }
+
   export interface ChatRequestOptions {
     signal?: AbortSignal;
   }
@@ -170,6 +181,17 @@ declare module "@mistralai/mistralai" {
     options?: ChatRequestOptions
   ): AsyncGenerator<ChatCompletionResponseChunk, void>;
 
+  completion(
+    request: CompletionRequest,
+    options?: ChatRequestOptions
+  ): Promise<ChatCompletionResponse>;
+
+  completionStream(
+    request: CompletionRequest,
+    options?: ChatRequestOptions
+  ): AsyncGenerator<ChatCompletionResponseChunk, void>;
+
   embeddings(options: {
     model: string;
     input: string | string[];
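As the typings show, both new methods accept the same optional ChatRequestOptions as chat(), so a completion can be cancelled with an AbortSignal. A minimal sketch (the five-second timeout is an arbitrary choice for illustration; prompt and model come from the tests in this diff):

// Abort the completion if it takes longer than 5 seconds. Per the JSDoc in
// src/client.js, this signal is combined with the client's default timeout.
const controller = new AbortController();
const timer = setTimeout(() => controller.abort(), 5000);

try {
  const response = await client.completion(
    {model: 'mistral-small-latest', prompt: '# this is a'},
    {signal: controller.signal},
  );
  console.log(response.choices[0].message.content);
} finally {
  clearTimeout(timer);
}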

src/client.js
+170 -1

@@ -161,7 +161,7 @@ class MistralClient {
       } else {
         throw new MistralAPIError(
           `HTTP error! status: ${response.status} ` +
-          `Response: \n${await response.text()}`,
+            `Response: \n${await response.text()}`,
         );
       }
     } catch (error) {
@@ -228,6 +228,47 @@ class MistralClient {
     };
   };
 
+  /**
+   * Creates a completion request
+   * @param {*} model
+   * @param {*} prompt
+   * @param {*} suffix
+   * @param {*} temperature
+   * @param {*} maxTokens
+   * @param {*} topP
+   * @param {*} randomSeed
+   * @param {*} stop
+   * @param {*} stream
+   * @return {Promise<Object>}
+   */
+  _makeCompletionRequest = function(
+    model,
+    prompt,
+    suffix,
+    temperature,
+    maxTokens,
+    topP,
+    randomSeed,
+    stop,
+    stream,
+  ) {
+    // if modelDefault and model are undefined, throw an error
+    if (!model && !this.modelDefault) {
+      throw new MistralAPIError('You must provide a model name');
+    }
+    return {
+      model: model ?? this.modelDefault,
+      prompt: prompt,
+      suffix: suffix ?? undefined,
+      temperature: temperature ?? undefined,
+      max_tokens: maxTokens ?? undefined,
+      top_p: topP ?? undefined,
+      random_seed: randomSeed ?? undefined,
+      stop: stop ?? undefined,
+      stream: stream ?? undefined,
+    };
+  };
+
   /**
    * Returns a list of the available models
    * @return {Promise<Object>}
@@ -401,6 +442,134 @@ class MistralClient {
     const response = await this._request('post', 'v1/embeddings', request);
     return response;
   };
+
+  /**
+   * A completion endpoint without streaming.
+   *
+   * @param {Object} data - The main completion configuration.
+   * @param {*} data.model - the name of the model to chat with,
+   *   e.g. mistral-tiny
+   * @param {*} data.prompt - the prompt to complete,
+   *   e.g. 'def fibonacci(n: int):'
+   * @param {*} data.temperature - the temperature to use for sampling, e.g. 0.5
+   * @param {*} data.maxTokens - the maximum number of tokens to generate,
+   *   e.g. 100
+   * @param {*} data.topP - the cumulative probability of tokens to generate,
+   *   e.g. 0.9
+   * @param {*} data.randomSeed - the random seed to use for sampling, e.g. 42
+   * @param {*} data.stop - the stop sequence to use, e.g. ['\n']
+   * @param {*} data.suffix - the suffix to append to the prompt,
+   *   e.g. 'n = int(input(\'Enter a number: \'))'
+   * @param {Object} options - Additional operational options.
+   * @param {*} [options.signal] - optional AbortSignal instance to control
+   *   the request. The signal will be combined with the default timeout signal
+   * @return {Promise<Object>}
+   */
+  completion = async function(
+    {
+      model,
+      prompt,
+      suffix,
+      temperature,
+      maxTokens,
+      topP,
+      randomSeed,
+      stop,
+    },
+    {signal} = {},
+  ) {
+    const request = this._makeCompletionRequest(
+      model,
+      prompt,
+      suffix,
+      temperature,
+      maxTokens,
+      topP,
+      randomSeed,
+      stop,
+      false,
+    );
+    const response = await this._request(
+      'post',
+      'v1/fim/completions',
+      request,
+      signal,
+    );
+    return response;
+  };
+
+  /**
+   * A completion endpoint that streams responses.
+   *
+   * @param {Object} data - The main completion configuration.
+   * @param {*} data.model - the name of the model to chat with,
+   *   e.g. mistral-tiny
+   * @param {*} data.prompt - the prompt to complete,
+   *   e.g. 'def fibonacci(n: int):'
+   * @param {*} data.temperature - the temperature to use for sampling, e.g. 0.5
+   * @param {*} data.maxTokens - the maximum number of tokens to generate,
+   *   e.g. 100
+   * @param {*} data.topP - the cumulative probability of tokens to generate,
+   *   e.g. 0.9
+   * @param {*} data.randomSeed - the random seed to use for sampling, e.g. 42
+   * @param {*} data.stop - the stop sequence to use, e.g. ['\n']
+   * @param {*} data.suffix - the suffix to append to the prompt,
+   *   e.g. 'n = int(input(\'Enter a number: \'))'
+   * @param {Object} options - Additional operational options.
+   * @param {*} [options.signal] - optional AbortSignal instance to control
+   *   the request. The signal will be combined with the default timeout signal
+   * @return {Promise<Object>}
+   */
+  completionStream = async function* (
+    {
+      model,
+      prompt,
+      suffix,
+      temperature,
+      maxTokens,
+      topP,
+      randomSeed,
+      stop,
+    },
+    {signal} = {},
+  ) {
+    const request = this._makeCompletionRequest(
+      model,
+      prompt,
+      suffix,
+      temperature,
+      maxTokens,
+      topP,
+      randomSeed,
+      stop,
+      true,
+    );
+    const response = await this._request(
+      'post',
+      'v1/fim/completions',
+      request,
+      signal,
+    );
+
+    let buffer = '';
+    const decoder = new TextDecoder();
+    for await (const chunk of response) {
+      buffer += decoder.decode(chunk, {stream: true});
+      let firstNewline;
+      while ((firstNewline = buffer.indexOf('\n')) !== -1) {
+        const chunkLine = buffer.substring(0, firstNewline);
+        buffer = buffer.substring(firstNewline + 1);
+        if (chunkLine.startsWith('data:')) {
+          const json = chunkLine.substring(6).trim();
+          if (json !== '[DONE]') {
+            yield JSON.parse(json);
+          }
+        }
+      }
+    }
+  };
 }
 
 export default MistralClient;
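The streaming variant yields one object per parsed `data:` line of the SSE stream (see the parsing loop above). A minimal sketch of consuming completionStream, assuming the chunks carry incremental text in choices[0].delta.content as the ChatCompletionResponseChunk typing suggests:

// Stream a fill-in-the-middle completion and print tokens as they arrive.
const stream = client.completionStream({
  model: 'mistral-small-latest',
  prompt: 'def fibonacci(n: int):',
  suffix: 'n = int(input(\'Enter a number: \'))',
});

for await (const chunk of stream) {
  // The delta may be empty on the first and last chunks (assumed to match
  // the chat streaming chunk format).
  const delta = chunk.choices[0].delta.content;
  if (delta !== undefined && delta !== null) {
    process.stdout.write(delta);
  }
}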

tests/client.test.js
+20 -6

@@ -23,7 +23,7 @@ describe('Mistral Client', () => {
     client._fetch = mockFetch(200, mockResponse);
 
     const response = await client.chat({
-      model: 'mistral-small',
+      model: 'mistral-small-latest',
       messages: [
         {
           role: 'user',
@@ -40,7 +40,7 @@ describe('Mistral Client', () => {
     client._fetch = mockFetch(200, mockResponse);
 
     const response = await client.chat({
-      model: 'mistral-small',
+      model: 'mistral-small-latest',
       messages: [
         {
           role: 'user',
@@ -58,7 +58,7 @@ describe('Mistral Client', () => {
     client._fetch = mockFetch(200, mockResponse);
 
     const response = await client.chat({
-      model: 'mistral-small',
+      model: 'mistral-small-latest',
       messages: [
         {
           role: 'user',
@@ -78,7 +78,7 @@ describe('Mistral Client', () => {
     client._fetch = mockFetchStream(200, mockResponse);
 
     const response = await client.chatStream({
-      model: 'mistral-small',
+      model: 'mistral-small-latest',
       messages: [
         {
           role: 'user',
@@ -101,7 +101,7 @@ describe('Mistral Client', () => {
     client._fetch = mockFetchStream(200, mockResponse);
 
     const response = await client.chatStream({
-      model: 'mistral-small',
+      model: 'mistral-small-latest',
       messages: [
         {
           role: 'user',
@@ -125,7 +125,7 @@ describe('Mistral Client', () => {
     client._fetch = mockFetchStream(200, mockResponse);
 
     const response = await client.chatStream({
-      model: 'mistral-small',
+      model: 'mistral-small-latest',
       messages: [
         {
           role: 'user',
@@ -176,4 +176,18 @@ describe('Mistral Client', () => {
       expect(response).toEqual(mockResponse);
     });
   });
+
+  describe('completion()', () => {
+    it('should return a chat response object', async() => {
+      // Mock the fetch function
+      const mockResponse = mockChatResponsePayload();
+      client._fetch = mockFetch(200, mockResponse);
+
+      const response = await client.completion({
+        model: 'mistral-small-latest',
+        prompt: '# this is a',
+      });
+      expect(response).toEqual(mockResponse);
+    });
+  });
 });
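The new test exercises only the non-streaming path. As a sketch, a companion completionStream() test could reuse the existing mockFetchStream and mockChatResponseStreamingPayload helpers from tests/utils.js (not part of this commit, and the chunk-count assertion is deliberately loose since it depends on the mock payload):

  describe('completionStream()', () => {
    it('should yield parsed chunk objects', async() => {
      // Mock the fetch function with the streaming chat payload
      const mockResponse = mockChatResponseStreamingPayload();
      client._fetch = mockFetchStream(200, mockResponse);

      const response = await client.completionStream({
        model: 'mistral-small-latest',
        prompt: '# this is a',
      });

      const chunks = [];
      for await (const chunk of response) {
        chunks.push(chunk);
      }
      // Only check that something was parsed from the mocked SSE stream.
      expect(chunks.length).toBeGreaterThan(0);
    });
  });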

tests/utils.js
+4 -4

@@ -78,7 +78,7 @@ export function mockListModels() {
       ],
     },
     {
-      id: 'mistral-small',
+      id: 'mistral-small-latest',
       object: 'model',
       created: 1703186988,
       owned_by: 'mistralai',
@@ -172,7 +172,7 @@ export function mockChatResponsePayload() {
         index: 0,
       },
     ],
-    model: 'mistral-small',
+    model: 'mistral-small-latest',
    usage: {prompt_tokens: 90, total_tokens: 90, completion_tokens: 0},
   };
 }
@@ -187,7 +187,7 @@ export function mockChatResponseStreamingPayload() {
     [encoder.encode('data: ' +
       JSON.stringify({
         id: 'cmpl-8cd9019d21ba490aa6b9740f5d0a883e',
-        model: 'mistral-small',
+        model: 'mistral-small-latest',
         choices: [
           {
             index: 0,
@@ -207,7 +207,7 @@ export function mockChatResponseStreamingPayload() {
       id: 'cmpl-8cd9019d21ba490aa6b9740f5d0a883e',
       object: 'chat.completion.chunk',
       created: 1703168544,
-      model: 'mistral-small',
+      model: 'mistral-small-latest',
       choices: [
         {
           index: i,
