1818
1919package org .apache .hertzbeat .ai .service .impl ;
2020
21+ import java .nio .charset .StandardCharsets ;
22+ import java .util .HashMap ;
23+ import java .util .Map ;
24+ import java .util .Objects ;
2125import lombok .extern .slf4j .Slf4j ;
2226import org .apache .hertzbeat .ai .sop .model .SopDefinition ;
2327import org .apache .hertzbeat .ai .sop .model .SopParameter ;
2731import org .apache .hertzbeat .ai .service .ChatClientProviderService ;
2832import org .apache .hertzbeat .base .dao .GeneralConfigDao ;
2933import org .apache .hertzbeat .common .entity .manager .GeneralConfig ;
34+ import org .apache .hertzbeat .common .support .event .AiProviderConfigChangeEvent ;
3035import org .apache .hertzbeat .common .util .JsonUtil ;
36+ import org .springframework .ai .chat .prompt .SystemPromptTemplate ;
3137import org .springframework .beans .factory .annotation .Value ;
3238import org .springframework .context .annotation .Lazy ;
39+ import org .springframework .context .event .EventListener ;
3340import org .springframework .core .io .Resource ;
3441import org .springframework .stereotype .Service ;
3542import org .apache .hertzbeat .ai .pojo .dto .ChatRequestContext ;
4451import reactor .core .publisher .Flux ;
4552
4653import java .io .IOException ;
47- import java .nio .charset .StandardCharsets ;
4854import java .util .ArrayList ;
4955import java .util .List ;
5056
5157/**
52- * Implementation of the {@link ChatClientProviderService}.
53- * Provides functionality to interact with the ChatClient for handling chat
54- * messages.
58+ * Implementation of the {@link ChatClientProviderService}. Provides functionality to interact with the ChatClient for
59+ * handling chat messages.
5560 */
5661@ Slf4j
5762@ Service
@@ -64,21 +69,28 @@ public class ChatClientProviderServiceImpl implements ChatClientProviderService
6469
6570 private final GeneralConfigDao generalConfigDao ;
6671
72+ private ModelProviderConfig modelProviderConfig ;
73+
74+
6775 private final SkillRegistry skillRegistry ;
68-
76+
6977 @ Autowired
7078 @ Qualifier ("hertzbeatTools" )
7179 private ToolCallbackProvider toolCallbackProvider ;
72-
80+
7381 private boolean isConfigured = false ;
7482
7583 @ Value ("classpath:/prompt/system-message.st" )
7684 private Resource systemResource ;
7785
86+ @ Value ("classpath:/prompt/extra-message-protected.st" )
87+ private Resource extraResourceProtected ;
88+
89+
7890 @ Autowired
79- public ChatClientProviderServiceImpl (ApplicationContext applicationContext ,
80- GeneralConfigDao generalConfigDao ,
81- @ Lazy SkillRegistry skillRegistry ) {
91+ public ChatClientProviderServiceImpl (ApplicationContext applicationContext ,
92+ GeneralConfigDao generalConfigDao ,
93+ @ Lazy SkillRegistry skillRegistry ) {
8294 this .applicationContext = applicationContext ;
8395 this .generalConfigDao = generalConfigDao ;
8496 this .skillRegistry = skillRegistry ;
@@ -89,7 +101,7 @@ public Flux<String> streamChat(ChatRequestContext context) {
89101 try {
90102 // Get the current (potentially refreshed) ChatClient instance
91103 ChatClient chatClient = applicationContext .getBean ("openAiChatClient" , ChatClient .class );
92-
104+
93105 List <Message > messages = new ArrayList <>();
94106
95107 // Add conversation history if available
@@ -112,13 +124,13 @@ public Flux<String> streamChat(ChatRequestContext context) {
112124 String systemPrompt = buildSystemPrompt (context .getConversationId ());
113125
114126 return chatClient .prompt ()
115- .messages (messages )
116- .system (systemPrompt )
117- .toolCallbacks (toolCallbackProvider )
118- .stream ()
119- .content ()
120- .doOnComplete (() -> log .info ("Streaming completed for conversation: {}" , context .getConversationId ()))
121- .doOnError (error -> log .error ("Error in streaming chat: {}" , error .getMessage (), error ));
127+ .messages (messages )
128+ .system (systemPrompt )
129+ .toolCallbacks (toolCallbackProvider )
130+ .stream ()
131+ .content ()
132+ .doOnComplete (() -> log .info ("Streaming completed for conversation: {}" , context .getConversationId ()))
133+ .doOnError (error -> log .error ("Error in streaming chat: {}" , error .getMessage (), error ));
122134
123135 } catch (Exception e ) {
124136 log .error ("Error setting up streaming chat: {}" , e .getMessage (), e );
@@ -133,30 +145,43 @@ private String buildSystemPrompt(Long conversationId) {
133145 try {
134146 String template = systemResource .getContentAsString (StandardCharsets .UTF_8 );
135147 String skillsList = generateSkillsList ();
136- return template
137- .replace (SKILLS_PLACEHOLDER , skillsList )
138- .replace (CONVERSATION_ID_PLACEHOLDER , String .valueOf (conversationId ));
148+ template = template
149+ .replace (SKILLS_PLACEHOLDER , skillsList )
150+ .replace (CONVERSATION_ID_PLACEHOLDER , String .valueOf (conversationId ));
151+
152+ // add extra prompt for protected model to guide it to use protected tools
153+ if (Objects .equals (modelProviderConfig .getParticipationModel (), "PROTECTED" )) {
154+ Map <String , Object > metadata = new HashMap <>();
155+ metadata .put ("conversationId" , conversationId );
156+ return template + SystemPromptTemplate .builder ().resource (extraResourceProtected ).build ()
157+ .create (metadata )
158+ .getContents ();
159+ } else {
160+ return template ;
161+ }
162+
139163 } catch (IOException e ) {
140164 log .error ("Failed to read system prompt template: {}" , e .getMessage ());
141165 return "" ;
142166 }
143167 }
144168
169+
145170 /**
146171 * Generate a formatted list of available skills for the system prompt.
147172 */
148173 private String generateSkillsList () {
149174 List <SopDefinition > skills = skillRegistry .getAllSkills ();
150-
175+
151176 if (skills .isEmpty ()) {
152177 return "No skills currently available. Use listSkills tool to refresh." ;
153178 }
154-
179+
155180 StringBuilder sb = new StringBuilder ();
156181 for (SopDefinition skill : skills ) {
157182 sb .append ("- **" ).append (skill .getName ()).append ("**: " );
158183 sb .append (skill .getDescription ());
159-
184+
160185 // Add parameter hints
161186 if (skill .getParameters () != null && !skill .getParameters ().isEmpty ()) {
162187 sb .append (" (requires: " );
@@ -171,16 +196,24 @@ private String generateSkillsList() {
171196 }
172197 sb .append ("\n " );
173198 }
174-
199+
175200 return sb .toString ();
176201 }
177202
203+ @ EventListener (AiProviderConfigChangeEvent .class )
204+ public void onAiProviderConfigChange (AiProviderConfigChangeEvent event ) {
205+ GeneralConfig providerConfig = generalConfigDao .findByType ("provider" );
206+ this .modelProviderConfig = JsonUtil .fromJson (providerConfig .getContent (), ModelProviderConfig .class );
207+ }
208+
178209 @ Override
179210 public boolean isConfigured () {
180211 if (!isConfigured ) {
181212 GeneralConfig providerConfig = generalConfigDao .findByType ("provider" );
182- ModelProviderConfig modelProviderConfig = JsonUtil .fromJson (providerConfig .getContent (), ModelProviderConfig .class );
183- isConfigured = modelProviderConfig != null && modelProviderConfig .getApiKey () != null ;
213+ ModelProviderConfig modelProviderConfig = JsonUtil .fromJson (providerConfig .getContent (),
214+ ModelProviderConfig .class );
215+ isConfigured = modelProviderConfig != null && modelProviderConfig .getApiKey () != null ;
216+ this .modelProviderConfig = modelProviderConfig ;
184217 }
185218 return isConfigured ;
186219 }
0 commit comments