Skip to content

Commit 23681e3

Browse files
committed
fix(openai): merge custom-provider system prompts
Signed-off-by: pandego <7780875+pandego@users.noreply.github.com>
1 parent 170d4d7 commit 23681e3

File tree

2 files changed

+47
-6
lines changed

2 files changed

+47
-6
lines changed

pkg/model/provider/openai/client.go

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -173,10 +173,16 @@ func (c *Client) Close() {
173173
}
174174
}
175175

176-
// convertMessages converts chat.Message to openai.ChatCompletionMessageParamUnion
177-
// using the shared oaistream implementation.
178-
func convertMessages(messages []chat.Message) []openai.ChatCompletionMessageParamUnion {
179-
return oaistream.ConvertMessages(messages)
176+
// convertMessages converts chat.Message to OpenAI chat-completions message params.
177+
// Custom OpenAI-compatible providers may target local model runners that reject
178+
// consecutive system or user messages, so we normalize those prompts to match
179+
// the DMR provider behavior.
180+
func convertMessages(cfg *latest.ModelConfig, messages []chat.Message) []openai.ChatCompletionMessageParamUnion {
181+
openaiMessages := oaistream.ConvertMessages(messages)
182+
if isCustomProvider(cfg) {
183+
return oaistream.MergeConsecutiveMessages(openaiMessages)
184+
}
185+
return openaiMessages
180186
}
181187

182188
// CreateChatCompletionStream creates a streaming chat completion request
@@ -220,7 +226,7 @@ func (c *Client) CreateChatCompletionStream(
220226

221227
params := openai.ChatCompletionNewParams{
222228
Model: c.ModelConfig.Model,
223-
Messages: convertMessages(messages),
229+
Messages: convertMessages(&c.ModelConfig, messages),
224230
StreamOptions: openai.ChatCompletionStreamOptionsParam{
225231
IncludeUsage: openai.Bool(trackUsage),
226232
},

pkg/model/provider/openai/client_test.go

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"github.com/stretchr/testify/require"
88

99
"github.com/docker/docker-agent/pkg/chat"
10+
"github.com/docker/docker-agent/pkg/config/latest"
1011
"github.com/docker/docker-agent/pkg/tools"
1112
)
1213

@@ -47,7 +48,7 @@ func TestConvertMessagesToResponseInput_OrphanedFunctionCall(t *testing.T) {
4748
}
4849

4950
func TestConvertMessagesToResponseInput_NoOrphans(t *testing.T) {
50-
// All tool calls have matching results no placeholder needed.
51+
// All tool calls have matching results - no placeholder needed.
5152
messages := []chat.Message{
5253
{Role: chat.MessageRoleUser, Content: "hello"},
5354
{
@@ -69,3 +70,37 @@ func TestConvertMessagesToResponseInput_NoOrphans(t *testing.T) {
6970
}
7071
assert.Equal(t, 1, outputCount, "should not inject extra outputs when all calls have results")
7172
}
73+
74+
func TestConvertMessages_MergesConsecutiveSystemMessagesForCustomProviders(t *testing.T) {
75+
cfg := &latest.ModelConfig{
76+
ProviderOpts: map[string]any{"api_type": "openai_chatcompletions"},
77+
}
78+
messages := []chat.Message{
79+
{Role: chat.MessageRoleSystem, Content: "You are Bob, a coding expert"},
80+
{Role: chat.MessageRoleSystem, Content: "## Custom Shell Tools\n\n### execute_command"},
81+
{Role: chat.MessageRoleSystem, Content: "<available_skills>\n <skill>what-time-is-it</skill>\n</available_skills>"},
82+
{Role: chat.MessageRoleUser, Content: "what is your favourite colour?"},
83+
}
84+
85+
result := convertMessages(cfg, messages)
86+
require.Len(t, result, 2)
87+
require.NotNil(t, result[0].OfSystem)
88+
assert.Contains(t, result[0].OfSystem.Content.OfString.Value, "You are Bob, a coding expert")
89+
assert.Contains(t, result[0].OfSystem.Content.OfString.Value, "Custom Shell Tools")
90+
assert.Contains(t, result[0].OfSystem.Content.OfString.Value, "available_skills")
91+
assert.NotNil(t, result[1].OfUser)
92+
}
93+
94+
func TestConvertMessages_PreservesConsecutiveSystemMessagesForOpenAIProvider(t *testing.T) {
95+
messages := []chat.Message{
96+
{Role: chat.MessageRoleSystem, Content: "System 1"},
97+
{Role: chat.MessageRoleSystem, Content: "System 2"},
98+
{Role: chat.MessageRoleUser, Content: "hello"},
99+
}
100+
101+
result := convertMessages(&latest.ModelConfig{}, messages)
102+
require.Len(t, result, 3)
103+
assert.NotNil(t, result[0].OfSystem)
104+
assert.NotNil(t, result[1].OfSystem)
105+
assert.NotNil(t, result[2].OfUser)
106+
}

0 commit comments

Comments
 (0)