Skip to content

Commit 7677537

Browse files
authored
feat(adk): add SummarizeMessages for on-demand synchronous summarization (#958)
1 parent 0b607bf commit 7677537

2 files changed

Lines changed: 373 additions & 0 deletions

File tree

adk/middlewares/summarization/summarization.go

Lines changed: 51 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -266,6 +266,57 @@ type middleware struct {
266266
cfg *Config
267267
}
268268

269+
// SummarizeOutput contains the output of a synchronous Summarize call.
270+
type SummarizeOutput struct {
271+
// FinalizedMessages is the message list after summarization,
272+
// ready to be used as the new conversation history.
273+
FinalizedMessages []adk.Message
274+
275+
// ModelResponse is the raw response from the summarization model.
276+
ModelResponse adk.Message
277+
}
278+
279+
// SummarizeMessages performs synchronous summarization of the given messages.
280+
// EmitInternalEvents and Trigger are not supported and will return an error if set.
281+
func SummarizeMessages(ctx context.Context, cfg *Config, messages []adk.Message) (*SummarizeOutput, error) {
282+
if cfg.EmitInternalEvents {
283+
return nil, fmt.Errorf("emitInternalEvents is not supported in synchronous summarization")
284+
}
285+
if cfg.Trigger != nil {
286+
return nil, fmt.Errorf("trigger is not supported in synchronous summarization")
287+
}
288+
if err := cfg.check(); err != nil {
289+
return nil, err
290+
}
291+
292+
m := &middleware{cfg: cfg}
293+
294+
rawSummary, modelInput, err := m.summarize(ctx, messages)
295+
if err != nil {
296+
return nil, err
297+
}
298+
299+
ctx = context.WithValue(ctx, ctxKeyModelInput{}, modelInput)
300+
301+
_, finalMsgs, err := m.finalizeSummary(ctx, messages, rawSummary)
302+
if err != nil {
303+
return nil, err
304+
}
305+
306+
if m.cfg.Callback != nil {
307+
beforeState := adk.ChatModelAgentState{Messages: messages}
308+
afterState := adk.ChatModelAgentState{Messages: finalMsgs}
309+
if err = m.cfg.Callback(ctx, beforeState, afterState); err != nil {
310+
return nil, err
311+
}
312+
}
313+
314+
return &SummarizeOutput{
315+
FinalizedMessages: finalMsgs,
316+
ModelResponse: rawSummary,
317+
}, nil
318+
}
319+
269320
func (m *middleware) BeforeModelRewriteState(ctx context.Context, state *adk.ChatModelAgentState,
270321
mtx *adk.ModelContext) (context.Context, *adk.ChatModelAgentState, error) {
271322

adk/middlewares/summarization/summarization_test.go

Lines changed: 322 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1594,3 +1594,325 @@ func TestHelperBranches(t *testing.T) {
15941594
assert.Contains(t, []string{userMessagesReplacedNote, userMessagesReplacedNoteZh}, note)
15951595
})
15961596
}
1597+
1598+
func TestSummarizeMessages(t *testing.T) {
1599+
ctx := context.Background()
1600+
1601+
t.Run("basic summarization", func(t *testing.T) {
1602+
ctrl := gomock.NewController(t)
1603+
cm := mockModel.NewMockBaseChatModel(ctrl)
1604+
cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
1605+
Return(&schema.Message{
1606+
Role: schema.Assistant,
1607+
Content: "Summary content",
1608+
}, nil).Times(1)
1609+
1610+
cfg := &Config{
1611+
Model: cm,
1612+
}
1613+
1614+
messages := []adk.Message{
1615+
schema.SystemMessage("You are a helpful assistant"),
1616+
schema.UserMessage(strings.Repeat("a", 100)),
1617+
schema.AssistantMessage(strings.Repeat("b", 100), nil),
1618+
}
1619+
1620+
output, err := SummarizeMessages(ctx, cfg, messages)
1621+
assert.NoError(t, err)
1622+
assert.NotNil(t, output)
1623+
assert.NotNil(t, output.ModelResponse)
1624+
assert.Equal(t, "Summary content", output.ModelResponse.Content)
1625+
assert.NotEmpty(t, output.FinalizedMessages)
1626+
assert.Equal(t, schema.System, output.FinalizedMessages[0].Role)
1627+
})
1628+
1629+
t.Run("model error propagates", func(t *testing.T) {
1630+
ctrl := gomock.NewController(t)
1631+
cm := mockModel.NewMockBaseChatModel(ctrl)
1632+
cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
1633+
Return(nil, fmt.Errorf("model error")).Times(1)
1634+
1635+
cfg := &Config{
1636+
Model: cm,
1637+
}
1638+
1639+
messages := []adk.Message{
1640+
schema.UserMessage("hello"),
1641+
}
1642+
1643+
output, err := SummarizeMessages(ctx, cfg, messages)
1644+
assert.Error(t, err)
1645+
assert.Nil(t, output)
1646+
})
1647+
1648+
t.Run("retry works in sync call", func(t *testing.T) {
1649+
ctrl := gomock.NewController(t)
1650+
cm := mockModel.NewMockBaseChatModel(ctrl)
1651+
1652+
callCount := 0
1653+
cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
1654+
DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...any) (*schema.Message, error) {
1655+
callCount++
1656+
if callCount == 1 {
1657+
return nil, fmt.Errorf("transient error")
1658+
}
1659+
return &schema.Message{
1660+
Role: schema.Assistant,
1661+
Content: "Summary after retry",
1662+
}, nil
1663+
}).Times(2)
1664+
1665+
cfg := &Config{
1666+
Model: cm,
1667+
Retry: &RetryConfig{
1668+
MaxRetries: intPtr(2),
1669+
BackoffFunc: func(_ context.Context, _ int, _ adk.Message, _ error) time.Duration { return 0 },
1670+
},
1671+
}
1672+
1673+
messages := []adk.Message{
1674+
schema.UserMessage("hello"),
1675+
}
1676+
1677+
output, err := SummarizeMessages(ctx, cfg, messages)
1678+
assert.NoError(t, err)
1679+
assert.NotNil(t, output)
1680+
assert.Equal(t, "Summary after retry", output.ModelResponse.Content)
1681+
assert.Equal(t, 2, callCount)
1682+
})
1683+
1684+
t.Run("failover works in sync call", func(t *testing.T) {
1685+
ctrl := gomock.NewController(t)
1686+
primary := mockModel.NewMockBaseChatModel(ctrl)
1687+
failover := mockModel.NewMockBaseChatModel(ctrl)
1688+
1689+
primary.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
1690+
Return(nil, fmt.Errorf("primary error")).Times(1)
1691+
failover.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
1692+
Return(&schema.Message{
1693+
Role: schema.Assistant,
1694+
Content: "Summary from failover",
1695+
}, nil).Times(1)
1696+
1697+
cfg := &Config{
1698+
Model: primary,
1699+
Failover: &FailoverConfig{
1700+
GetFailoverModel: func(ctx context.Context, failoverCtx *FailoverContext) (model.BaseChatModel, []*schema.Message, error) {
1701+
return failover, []*schema.Message{schema.UserMessage("failover input")}, nil
1702+
},
1703+
},
1704+
}
1705+
1706+
messages := []adk.Message{
1707+
schema.UserMessage("hello"),
1708+
}
1709+
1710+
output, err := SummarizeMessages(ctx, cfg, messages)
1711+
assert.NoError(t, err)
1712+
assert.NotNil(t, output)
1713+
assert.Equal(t, "Summary from failover", output.ModelResponse.Content)
1714+
})
1715+
1716+
t.Run("callback is invoked", func(t *testing.T) {
1717+
ctrl := gomock.NewController(t)
1718+
cm := mockModel.NewMockBaseChatModel(ctrl)
1719+
cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
1720+
Return(&schema.Message{
1721+
Role: schema.Assistant,
1722+
Content: "Summary",
1723+
}, nil).Times(1)
1724+
1725+
callbackCalled := false
1726+
cfg := &Config{
1727+
Model: cm,
1728+
Callback: func(ctx context.Context, before, after adk.ChatModelAgentState) error {
1729+
callbackCalled = true
1730+
assert.Len(t, before.Messages, 1)
1731+
assert.NotEmpty(t, after.Messages)
1732+
return nil
1733+
},
1734+
}
1735+
1736+
messages := []adk.Message{
1737+
schema.UserMessage("hello"),
1738+
}
1739+
1740+
output, err := SummarizeMessages(ctx, cfg, messages)
1741+
assert.NoError(t, err)
1742+
assert.NotNil(t, output)
1743+
assert.True(t, callbackCalled)
1744+
})
1745+
1746+
t.Run("custom finalize is used", func(t *testing.T) {
1747+
ctrl := gomock.NewController(t)
1748+
cm := mockModel.NewMockBaseChatModel(ctrl)
1749+
cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
1750+
Return(&schema.Message{
1751+
Role: schema.Assistant,
1752+
Content: "Summary",
1753+
}, nil).Times(1)
1754+
1755+
cfg := &Config{
1756+
Model: cm,
1757+
Finalize: func(ctx context.Context, originalMessages []adk.Message, summary adk.Message) ([]adk.Message, error) {
1758+
return []adk.Message{
1759+
schema.SystemMessage("custom system"),
1760+
summary,
1761+
}, nil
1762+
},
1763+
}
1764+
1765+
messages := []adk.Message{
1766+
schema.UserMessage("hello"),
1767+
}
1768+
1769+
output, err := SummarizeMessages(ctx, cfg, messages)
1770+
assert.NoError(t, err)
1771+
assert.NotNil(t, output)
1772+
assert.Len(t, output.FinalizedMessages, 2)
1773+
assert.Equal(t, schema.System, output.FinalizedMessages[0].Role)
1774+
assert.Equal(t, "custom system", output.FinalizedMessages[0].Content)
1775+
})
1776+
1777+
t.Run("errors when EmitInternalEvents is true", func(t *testing.T) {
1778+
ctrl := gomock.NewController(t)
1779+
cm := mockModel.NewMockBaseChatModel(ctrl)
1780+
1781+
cfg := &Config{
1782+
Model: cm,
1783+
EmitInternalEvents: true,
1784+
}
1785+
1786+
output, err := SummarizeMessages(ctx, cfg, []adk.Message{schema.UserMessage("hello")})
1787+
assert.Error(t, err)
1788+
assert.Nil(t, output)
1789+
assert.Contains(t, err.Error(), "emitInternalEvents")
1790+
})
1791+
1792+
t.Run("errors when Trigger is set", func(t *testing.T) {
1793+
ctrl := gomock.NewController(t)
1794+
cm := mockModel.NewMockBaseChatModel(ctrl)
1795+
1796+
cfg := &Config{
1797+
Model: cm,
1798+
Trigger: &TriggerCondition{ContextTokens: 1000},
1799+
}
1800+
1801+
output, err := SummarizeMessages(ctx, cfg, []adk.Message{schema.UserMessage("hello")})
1802+
assert.Error(t, err)
1803+
assert.Nil(t, output)
1804+
assert.Contains(t, err.Error(), "trigger")
1805+
})
1806+
1807+
t.Run("nil model returns config check error", func(t *testing.T) {
1808+
cfg := &Config{}
1809+
1810+
output, err := SummarizeMessages(ctx, cfg, []adk.Message{schema.UserMessage("hello")})
1811+
assert.Error(t, err)
1812+
assert.Nil(t, output)
1813+
assert.Contains(t, err.Error(), "model is required")
1814+
})
1815+
1816+
t.Run("callback error propagates", func(t *testing.T) {
1817+
ctrl := gomock.NewController(t)
1818+
cm := mockModel.NewMockBaseChatModel(ctrl)
1819+
cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
1820+
Return(&schema.Message{
1821+
Role: schema.Assistant,
1822+
Content: "Summary",
1823+
}, nil).Times(1)
1824+
1825+
cfg := &Config{
1826+
Model: cm,
1827+
Callback: func(ctx context.Context, before, after adk.ChatModelAgentState) error {
1828+
return fmt.Errorf("callback error")
1829+
},
1830+
}
1831+
1832+
output, err := SummarizeMessages(ctx, cfg, []adk.Message{schema.UserMessage("hello")})
1833+
assert.Error(t, err)
1834+
assert.Nil(t, output)
1835+
assert.Contains(t, err.Error(), "callback error")
1836+
})
1837+
1838+
t.Run("finalize error propagates", func(t *testing.T) {
1839+
ctrl := gomock.NewController(t)
1840+
cm := mockModel.NewMockBaseChatModel(ctrl)
1841+
cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
1842+
Return(&schema.Message{
1843+
Role: schema.Assistant,
1844+
Content: "Summary",
1845+
}, nil).Times(1)
1846+
1847+
cfg := &Config{
1848+
Model: cm,
1849+
Finalize: func(ctx context.Context, originalMessages []adk.Message, summary adk.Message) ([]adk.Message, error) {
1850+
return nil, fmt.Errorf("finalize error")
1851+
},
1852+
}
1853+
1854+
output, err := SummarizeMessages(ctx, cfg, []adk.Message{schema.UserMessage("hello")})
1855+
assert.Error(t, err)
1856+
assert.Nil(t, output)
1857+
assert.Contains(t, err.Error(), "finalize error")
1858+
})
1859+
1860+
t.Run("preserves system messages", func(t *testing.T) {
1861+
ctrl := gomock.NewController(t)
1862+
cm := mockModel.NewMockBaseChatModel(ctrl)
1863+
cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
1864+
Return(&schema.Message{
1865+
Role: schema.Assistant,
1866+
Content: "Summary",
1867+
}, nil).Times(1)
1868+
1869+
cfg := &Config{
1870+
Model: cm,
1871+
}
1872+
1873+
messages := []adk.Message{
1874+
schema.SystemMessage("System 1"),
1875+
schema.SystemMessage("System 2"),
1876+
schema.UserMessage(strings.Repeat("a", 100)),
1877+
}
1878+
1879+
output, err := SummarizeMessages(ctx, cfg, messages)
1880+
assert.NoError(t, err)
1881+
assert.NotNil(t, output)
1882+
assert.Len(t, output.FinalizedMessages, 3)
1883+
assert.Equal(t, schema.System, output.FinalizedMessages[0].Role)
1884+
assert.Equal(t, "System 1", output.FinalizedMessages[0].Content)
1885+
assert.Equal(t, schema.System, output.FinalizedMessages[1].Role)
1886+
assert.Equal(t, "System 2", output.FinalizedMessages[1].Content)
1887+
})
1888+
1889+
t.Run("custom token counter is used", func(t *testing.T) {
1890+
ctrl := gomock.NewController(t)
1891+
cm := mockModel.NewMockBaseChatModel(ctrl)
1892+
cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
1893+
Return(&schema.Message{
1894+
Role: schema.Assistant,
1895+
Content: "Summary\n<all_user_messages>\n - old\n</all_user_messages>",
1896+
}, nil).Times(1)
1897+
1898+
tokenCounterCalled := false
1899+
cfg := &Config{
1900+
Model: cm,
1901+
TokenCounter: func(ctx context.Context, input *TokenCounterInput) (int, error) {
1902+
tokenCounterCalled = true
1903+
return 42, nil
1904+
},
1905+
}
1906+
1907+
messages := []adk.Message{
1908+
schema.UserMessage("msg1"),
1909+
schema.AssistantMessage("resp", nil),
1910+
schema.UserMessage("msg2"),
1911+
}
1912+
1913+
output, err := SummarizeMessages(ctx, cfg, messages)
1914+
assert.NoError(t, err)
1915+
assert.NotNil(t, output)
1916+
assert.True(t, tokenCounterCalled)
1917+
})
1918+
}

0 commit comments

Comments
 (0)