Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(vertexai): Add repetition penalties to GenerationConfig #17234

Merged
merged 7 commits into from
Mar 28, 2025
43 changes: 43 additions & 0 deletions packages/firebase_vertexai/firebase_vertexai/lib/src/api.dart
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,8 @@ abstract class BaseGenerationConfig {
this.temperature,
this.topP,
this.topK,
this.presencePenalty,
this.frequencyPenalty,
});

/// Number of generated responses to return.
Expand Down Expand Up @@ -700,6 +702,41 @@ abstract class BaseGenerationConfig {
/// Note: The default value varies by model.
final int? topK;

/// The penalty for repeating the same words or phrases already generated in
/// the text.
///
/// Controls the likelihood of repetition. Higher penalty values result in
/// more diverse output.
///
/// When `null`, [toJson] leaves the `presencePenalty` key out of the
/// serialized configuration.
///
/// **Note:** While both [presencePenalty] and [frequencyPenalty] discourage
/// repetition, [presencePenalty] applies the same penalty regardless of how
/// many times the word/phrase has already appeared, whereas
/// [frequencyPenalty] increases the penalty for *each* repetition of a
/// word/phrase.
///
/// **Important:** The range of supported [presencePenalty] values depends on
/// the model; see the
/// [documentation](https://firebase.google.com/docs/vertex-ai/model-parameters?platform=flutter#configure-model-parameters-gemini)
/// for more details.
final double? presencePenalty;

/// The penalty for repeating words or phrases, with the penalty increasing
/// for each repetition.
///
/// Controls the likelihood of repetition. Higher values increase the penalty
/// of repetition, resulting in more diverse output.
///
/// When `null`, [toJson] leaves the `frequencyPenalty` key out of the
/// serialized configuration.
///
/// **Note:** While both [frequencyPenalty] and [presencePenalty] discourage
/// repetition, [frequencyPenalty] increases the penalty for *each* repetition
/// of a word/phrase, whereas [presencePenalty] applies the same penalty
/// regardless of how many times the word/phrase has already appeared.
///
/// **Important:** The range of supported [frequencyPenalty] values depends on
/// the model; see the
/// [documentation](https://firebase.google.com/docs/vertex-ai/model-parameters?platform=flutter#configure-model-parameters-gemini)
/// for more details.
final double? frequencyPenalty;

// ignore: public_member_api_docs
Map<String, Object?> toJson() => {
if (candidateCount case final candidateCount?)
Expand All @@ -709,6 +746,10 @@ abstract class BaseGenerationConfig {
if (temperature case final temperature?) 'temperature': temperature,
if (topP case final topP?) 'topP': topP,
if (topK case final topK?) 'topK': topK,
if (presencePenalty case final presencePenalty?)
'presencePenalty': presencePenalty,
if (frequencyPenalty case final frequencyPenalty?)
'frequencyPenalty': frequencyPenalty,
};
}

Expand All @@ -722,6 +763,8 @@ final class GenerationConfig extends BaseGenerationConfig {
super.temperature,
super.topP,
super.topK,
super.presencePenalty,
super.frequencyPenalty,
this.responseMimeType,
this.responseSchema,
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ final class LiveGenerationConfig extends BaseGenerationConfig {
super.temperature,
super.topP,
super.topK,
super.presencePenalty,
super.frequencyPenalty,
});

/// The speech configuration.
Expand Down
17 changes: 17 additions & 0 deletions packages/firebase_vertexai/firebase_vertexai/test/model_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,23 @@ void main() {
);
});

test('can override GenerationConfig repetition penalties', () async {
  // The client half of the pair is used below to inspect the outgoing
  // request; the model half issues it.
  final (client, model) = createModel();
  const promptText = 'Some prompt';

  // Per-request config carrying only the two repetition penalties.
  final penalties = GenerationConfig(
    presencePenalty: 0.5,
    frequencyPenalty: 0.2,
  );

  // Only the fields that were set are expected under `generationConfig`.
  const expectedConfig = {
    'presencePenalty': 0.5,
    'frequencyPenalty': 0.2,
  };

  await client.checkRequest(
    () => model.generateContent(
      [Content.text(promptText)],
      generationConfig: penalties,
    ),
    verifyRequest: (_, request) {
      expect(request['generationConfig'], expectedConfig);
    },
    response: arbitraryGenerateContentResponse,
  );
});

test('can pass system instructions', () async {
const instructions = 'Do a good job';
final (client, model) = createModel(
Expand Down
Loading