Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

.Net: add Amazon Nova support (text only) (WIP) #10021

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,12 @@ internal IBedrockTextGenerationService CreateTextGenerationService(string modelI
case "AMAZON":
if (modelName.StartsWith("titan-", StringComparison.OrdinalIgnoreCase))
{
return new AmazonService();
return new AmazonTitanService();
}

if (modelName.StartsWith("nova-", StringComparison.OrdinalIgnoreCase))
{
return new AmazonNovaService();
}
throw new NotSupportedException($"Unsupported Amazon model: {modelId}");
case "ANTHROPIC":
Expand Down Expand Up @@ -92,7 +97,11 @@ internal IBedrockChatCompletionService CreateChatCompletionService(string modelI
case "AMAZON":
if (modelName.StartsWith("titan-", StringComparison.OrdinalIgnoreCase))
{
return new AmazonService();
return new AmazonTitanService();
}
if (modelName.StartsWith("nova-", StringComparison.OrdinalIgnoreCase))
{
return new AmazonNovaService();
}
throw new NotSupportedException($"Unsupported Amazon model: {modelId}");
case "ANTHROPIC":
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using System.IO;
using System.Text.Json;
using System.Text.Json.Nodes;
using Amazon.BedrockRuntime.Model;
using Amazon.Runtime.Documents;
using Microsoft.SemanticKernel.ChatCompletion;
using static Microsoft.SemanticKernel.Connectors.Amazon.Core.NovaRequest;

namespace Microsoft.SemanticKernel.Connectors.Amazon.Core;

/// <summary>
/// Input-output service for the Amazon Nova foundation models (text modality only).
/// Builds Bedrock Converse / InvokeModel requests and parses their responses.
/// </summary>
internal sealed class AmazonNovaService : IBedrockTextGenerationService, IBedrockChatCompletionService
{
    // Nova responses use camelCase JSON property names; the response DTOs are PascalCase.
    private static readonly JsonSerializerOptions s_jsonSerializerOptions = new()
    { PropertyNameCaseInsensitive = true };

    /// <inheritdoc/>
    ConverseRequest IBedrockChatCompletionService.GetConverseRequest(string modelId, ChatHistory chatHistory, PromptExecutionSettings? settings)
    {
        var messages = BedrockModelUtilities.BuildMessageList(chatHistory);
        var systemMessages = BedrockModelUtilities.GetSystemMessages(chatHistory);
        var (inferenceConfig, additionalModelRequestFields) = GetInferenceSettings(settings);

        return new ConverseRequest
        {
            ModelId = modelId,
            Messages = messages,
            System = systemMessages,
            InferenceConfig = inferenceConfig,
            AdditionalModelRequestFields = additionalModelRequestFields,
            AdditionalModelResponseFieldPaths = new List<string>()
        };
    }

    /// <inheritdoc/>
    /// <remarks>Mirrors the non-streaming Converse request; only the request type differs.</remarks>
    ConverseStreamRequest IBedrockChatCompletionService.GetConverseStreamRequest(string modelId, ChatHistory chatHistory, PromptExecutionSettings? settings)
    {
        var messages = BedrockModelUtilities.BuildMessageList(chatHistory);
        var systemMessages = BedrockModelUtilities.GetSystemMessages(chatHistory);
        var (inferenceConfig, additionalModelRequestFields) = GetInferenceSettings(settings);

        return new ConverseStreamRequest
        {
            ModelId = modelId,
            Messages = messages,
            System = systemMessages,
            InferenceConfig = inferenceConfig,
            AdditionalModelRequestFields = additionalModelRequestFields,
            AdditionalModelResponseFieldPaths = new List<string>()
        };
    }

    /// <inheritdoc/>
    object IBedrockTextGenerationService.GetInvokeModelRequestBody(string modelId, string prompt, PromptExecutionSettings? executionSettings)
    {
        var settings = AmazonNovaExecutionSettings.FromExecutionSettings(executionSettings);
        // Raw extension-data keys take precedence over the strongly-typed execution settings.
        var schemaVersion = BedrockModelUtilities.GetExtensionDataValue<string?>(executionSettings?.ExtensionData, "schemaVersion") ?? settings.SchemaVersion;
        var maxNewTokens = BedrockModelUtilities.GetExtensionDataValue<int?>(executionSettings?.ExtensionData, "max_new_tokens") ?? settings.MaxNewTokens;
        var topP = BedrockModelUtilities.GetExtensionDataValue<float?>(executionSettings?.ExtensionData, "top_p") ?? settings.TopP;
        var topK = BedrockModelUtilities.GetExtensionDataValue<int?>(executionSettings?.ExtensionData, "top_k") ?? settings.TopK;
        var temperature = BedrockModelUtilities.GetExtensionDataValue<float?>(executionSettings?.ExtensionData, "temperature") ?? settings.Temperature;
        // NOTE(review): stop sequences are not part of NovaTextGenerationRequest yet, so they are
        // intentionally not read here; wire them through once the request DTO supports them.

        return new NovaRequest.NovaTextGenerationRequest
        {
            InferenceConfig = new NovaRequest.NovaTextGenerationConfig
            {
                MaxNewTokens = maxNewTokens,
                Temperature = temperature,
                TopK = topK,
                TopP = topP
            },
            Messages = new List<NovaRequest.NovaUserMessage>
            {
                new() { Role = AuthorRole.User.Label, Content = new List<NovaUserMessageContent> { new() { Text = prompt } } }
            },
            SchemaVersion = schemaVersion ?? "messages-v1",
        };
    }

    /// <inheritdoc/>
    IReadOnlyList<TextContent> IBedrockTextGenerationService.GetInvokeResponseBody(InvokeModelResponse response)
    {
        using var reader = new StreamReader(response.Body);
        var responseBody = JsonSerializer.Deserialize<NovaTextResponse>(reader.ReadToEnd(), s_jsonSerializerOptions);
        if (responseBody?.Output?.Message?.Contents is not { Count: > 0 })
        {
            return [];
        }
        // Nova returns a list of content blocks; only the first text block is surfaced.
        return [new TextContent(responseBody.Output.Message.Contents[0].Text, innerContent: responseBody)];
    }

    /// <inheritdoc/>
    IEnumerable<StreamingTextContent> IBedrockTextGenerationService.GetTextStreamOutput(JsonNode chunk)
    {
        // Streaming deltas arrive as contentBlockDelta events; in a full message payload
        // "content" is an array, so index element 0 rather than treating it as an object.
        // TODO confirm the exact streamed event shapes against the Nova documentation.
        var text = chunk["contentBlockDelta"]?["delta"]?["text"]?.ToString()
            ?? chunk["output"]?["message"]?["content"]?[0]?["text"]?.ToString();
        if (!string.IsNullOrEmpty(text))
        {
            yield return new StreamingTextContent(text, innerContent: chunk);
        }
    }

    /// <summary>
    /// Builds the inference configuration shared by the streaming and non-streaming chat paths.
    /// Raw extension-data keys take precedence over the strongly-typed execution settings.
    /// </summary>
    private static (InferenceConfiguration InferenceConfig, Document AdditionalModelRequestFields) GetInferenceSettings(PromptExecutionSettings? settings)
    {
        var executionSettings = AmazonNovaExecutionSettings.FromExecutionSettings(settings);
        var maxNewTokens = BedrockModelUtilities.GetExtensionDataValue<int?>(settings?.ExtensionData, "max_new_tokens") ?? executionSettings.MaxNewTokens;
        var topP = BedrockModelUtilities.GetExtensionDataValue<float?>(settings?.ExtensionData, "top_p") ?? executionSettings.TopP;
        var topK = BedrockModelUtilities.GetExtensionDataValue<int?>(settings?.ExtensionData, "top_k") ?? executionSettings.TopK;
        var temperature = BedrockModelUtilities.GetExtensionDataValue<float?>(settings?.ExtensionData, "temperature") ?? executionSettings.Temperature;
        var stopSequences = BedrockModelUtilities.GetExtensionDataValue<List<string>?>(settings?.ExtensionData, "stopSequences") ?? executionSettings.StopSequences;

        var inferenceConfig = new InferenceConfiguration();
        BedrockModelUtilities.SetPropertyIfNotNull(() => temperature, value => inferenceConfig.Temperature = value);
        BedrockModelUtilities.SetPropertyIfNotNull(() => topP, value => inferenceConfig.TopP = value);
        BedrockModelUtilities.SetPropertyIfNotNull(() => maxNewTokens, value => inferenceConfig.MaxTokens = value);
        BedrockModelUtilities.SetNullablePropertyIfNotNull(() => stopSequences, value => inferenceConfig.StopSequences = value);

        // The Converse InferenceConfiguration has no TopK property; Nova reads topK from
        // additionalModelRequestFields.inferenceConfig.topK instead — TODO confirm against the Nova Converse docs.
        var additionalModelRequestFields = topK.HasValue
            ? new Document(new Dictionary<string, Document>
            {
                { "inferenceConfig", new Document(new Dictionary<string, Document> { { "topK", topK.Value } }) }
            })
            : new Document();

        return (inferenceConfig, additionalModelRequestFields);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using System.Text.Json.Serialization;

namespace Microsoft.SemanticKernel.Connectors.Amazon.Core;

internal static class NovaRequest
{
    /// <summary>
    /// The Nova Text Generation Request object.
    /// </summary>
    internal sealed class NovaTextGenerationRequest
    {
        /// <summary>
        /// Schema version for the request, defaulting to "messages-v1".
        /// </summary>
        [JsonPropertyName("schemaVersion")]
        public string SchemaVersion { get; set; } = "messages-v1";

        /// <summary>
        /// System messages providing context for the generation.
        /// Omitted from the payload when null so Nova never receives "system": null.
        /// </summary>
        [JsonPropertyName("system")]
        [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
        public IList<NovaSystemMessage>? System { get; set; }

        /// <summary>
        /// User messages for text generation. Omitted from the payload when null.
        /// </summary>
        [JsonPropertyName("messages")]
        [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
        public IList<NovaUserMessage>? Messages { get; set; }

        /// <summary>
        /// Text generation configurations as required by Nova request body.
        /// </summary>
        [JsonPropertyName("inferenceConfig")]
        [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
        public NovaTextGenerationConfig? InferenceConfig { get; set; }
    }

    /// <summary>
    /// Represents a system message.
    /// </summary>
    internal sealed class NovaSystemMessage
    {
        /// <summary>
        /// The text of the system message.
        /// </summary>
        [JsonPropertyName("text")]
        public string? Text { get; set; }
    }

    /// <summary>
    /// Represents a user message.
    /// </summary>
    internal sealed class NovaUserMessage
    {
        /// <summary>
        /// The role of the message sender.
        /// </summary>
        [JsonPropertyName("role")]
        public string? Role { get; set; }

        /// <summary>
        /// The content of the user message.
        /// </summary>
        [JsonPropertyName("content")]
        public IList<NovaUserMessageContent>? Content { get; set; } = new List<NovaUserMessageContent>();
    }

    /// <summary>
    /// Represents the content of a user message.
    /// </summary>
    internal sealed class NovaUserMessageContent
    {
        /// <summary>
        /// The text of the user message content.
        /// </summary>
        [JsonPropertyName("text")]
        public string? Text { get; set; }
    }

    /// <summary>
    /// Nova Text Generation Configurations.
    /// </summary>
    internal sealed class NovaTextGenerationConfig
    {
        /// <summary>
        /// Maximum new tokens to generate in the response.
        /// </summary>
        [JsonPropertyName("max_new_tokens")]
        [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
        public int? MaxNewTokens { get; set; }

        /// <summary>
        /// Top P controls token choices, based on the probability of the potential choices. The range is 0 to 1. The default is 1.
        /// </summary>
        [JsonPropertyName("top_p")]
        [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
        public float? TopP { get; set; }

        /// <summary>
        /// Top K limits the number of token options considered at each generation step. The default is 20.
        /// </summary>
        [JsonPropertyName("top_k")]
        [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
        public int? TopK { get; set; }

        /// <summary>
        /// The Temperature value ranges from 0 to 1, with 0 being the most deterministic and 1 being the most creative.
        /// </summary>
        [JsonPropertyName("temperature")]
        [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
        public float? Temperature { get; set; }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using System.Text.Json.Serialization;

namespace Microsoft.SemanticKernel.Connectors.Amazon.Core;

/// <summary>
/// A message returned in a Nova Invoke Model response.
/// </summary>
internal sealed class NovaMessage
{
    /// <summary>
    /// A single content block within a message.
    /// </summary>
    internal sealed class Content
    {
        /// <summary>The text of the content block.</summary>
        public string? Text { get; set; }
    }

    /// <summary>
    /// The content blocks of the message, mapped from the JSON "content" array.
    /// </summary>
    [JsonPropertyName("content")]
    public List<Content>? Contents { get; set; }

    /// <summary>
    /// The role of the message author (no explicit JSON name; relies on
    /// case-insensitive deserialization to match the camelCase payload).
    /// </summary>
    public string? Role { get; set; }
}

/// <summary>
/// The "output" element of a Nova Invoke Model response, wrapping the generated message.
/// </summary>
internal sealed class Output
{
    /// <summary>The generated message.</summary>
    public NovaMessage? Message { get; set; }
}

/// <summary>
/// Token usage statistics reported in a Nova Invoke Model response.
/// </summary>
internal sealed class Usage
{
    /// <summary>Number of tokens in the input prompt.</summary>
    public int InputTokens { get; set; }

    /// <summary>Number of tokens generated in the output.</summary>
    public int OutputTokens { get; set; }

    /// <summary>Total tokens consumed by the request (input plus output).</summary>
    public int TotalTokens { get; set; }
}

/// <summary>
/// The Amazon Nova text response object when deserialized from an Invoke Model call.
/// </summary>
internal sealed class NovaTextResponse
{
    /// <summary>The model output containing the generated message.</summary>
    public Output? Output { get; set; }

    /// <summary>Token usage statistics for the request.</summary>
    public Usage? Usage { get; set; }

    /// <summary>
    /// The reason generation stopped — presumably values such as "end_turn" or
    /// "max_tokens"; TODO confirm against the Nova response schema.
    /// </summary>
    public string? StopReason { get; set; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace Microsoft.SemanticKernel.Connectors.Amazon.Core;
/// <summary>
/// Input-output service for Amazon Titan model.
/// </summary>
internal sealed class AmazonService : IBedrockTextGenerationService, IBedrockChatCompletionService
internal sealed class AmazonTitanService : IBedrockTextGenerationService, IBedrockChatCompletionService
{
/// <inheritdoc/>
public object GetInvokeModelRequestBody(string modelId, string prompt, PromptExecutionSettings? executionSettings)
Expand Down
Loading
Loading