Skip to content

Commit 7088a83

Browse files
Merge pull request #40 from mwetzka03/feature/rag-integration
WWI22 P2T Gruppe RAG-Integration
2 parents fa521b0 + 14181f2 commit 7088a83

File tree

4 files changed

+270
-37
lines changed

4 files changed

+270
-37
lines changed

pom.xml

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,9 +178,16 @@
178178
<version>1.0</version>
179179
</dependency>
180180
<dependency>
181-
<groupId>org.jsoup</groupId>
182-
<artifactId>jsoup</artifactId>
183-
<version>1.10.2</version>
181+
<groupId>org.wiremock</groupId>
182+
<artifactId>wiremock-standalone</artifactId>
183+
<version>3.0.0</version>
184+
<scope>test</scope>
185+
</dependency>
186+
<dependency>
187+
<groupId>org.wiremock.integrations</groupId>
188+
<artifactId>wiremock-spring-boot</artifactId>
189+
<version>3.0.0</version>
190+
<scope>test</scope>
184191
</dependency>
185192
<dependency>
186193
<groupId>org.junit.jupiter</groupId>

src/main/java/de/dhbw/woped/process2text/controller/P2TController.java

Lines changed: 72 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,18 @@
1111
import org.slf4j.Logger;
1212
import org.slf4j.LoggerFactory;
1313
import org.springframework.beans.factory.annotation.Autowired;
14+
import org.springframework.http.HttpEntity;
15+
import org.springframework.http.HttpHeaders;
16+
import org.springframework.http.MediaType;
17+
import org.springframework.http.ResponseEntity;
1418
import org.springframework.beans.factory.annotation.Qualifier;
15-
import org.springframework.web.bind.annotation.*;
19+
import org.springframework.web.bind.annotation.CrossOrigin;
20+
import org.springframework.web.bind.annotation.GetMapping;
21+
import org.springframework.web.bind.annotation.PostMapping;
22+
import org.springframework.web.bind.annotation.RequestBody;
23+
import org.springframework.web.bind.annotation.RequestParam;
24+
import org.springframework.web.bind.annotation.RestController;
25+
import org.springframework.web.client.RestTemplate;
1626
import org.springframework.web.server.ResponseStatusException;
1727

1828
/**
@@ -66,6 +76,8 @@ protected String generateText(@RequestBody String body) {
6676
* @param apiKey The API key for OpenAI.
6777
* @param prompt The prompt to guide the translation.
6878
* @param gptModel The GPT model to be used for translation.
79+
* @param provider The provider to use (e.g., "openAi", "lmStudio").
80+
* @param useRag Whether to use RAG (Retrieval-Augmented Generation) to enrich the prompt.
6981
* @return The translated text.
7082
*/
7183
@ApiOperation(
@@ -75,27 +87,66 @@ protected String generateText(@RequestBody String body) {
7587
@PostMapping(value = "/generateTextLLM", consumes = "text/plain", produces = "text/plain")
7688
protected String generateTextLLM(
7789
@RequestBody String body,
78-
@RequestParam(required = true) String apiKey,
90+
@RequestParam(required = false) String apiKey,
7991
@RequestParam(required = true) String prompt,
80-
@RequestParam(required = true) String gptModel) {
81-
httpRequestsTotal.increment();
82-
return httpRequestDuration.record(
83-
() -> {
84-
logger.debug(
85-
"Received request with prompt: {}, gptModel: {}, body: {}",
86-
prompt,
87-
gptModel,
88-
body.replaceAll("[\n\r\t]", "_"));
89-
OpenAiApiDTO openAiApiDTO = new OpenAiApiDTO(apiKey, gptModel, prompt);
90-
try {
91-
String response = llmService.callLLM(body, openAiApiDTO);
92-
logger.debug("LLM Response: " + response);
93-
return response;
94-
} catch (ResponseStatusException e) {
95-
logger.error("Error processing LLM request", e);
96-
throw e;
97-
}
98-
});
92+
@RequestParam(required = true) String gptModel,
93+
@RequestParam(required = true) String provider,
94+
@RequestParam(required = true) boolean useRag) {
95+
logger.debug(
96+
"Received request with apiKey: {}, prompt: {}, gptModel: {}, provider: {}, useRag: {}, body: {}",
97+
apiKey,
98+
prompt,
99+
gptModel,
100+
provider,
101+
useRag,
102+
body.replaceAll("[\n\r\t]", "_"));
103+
104+
String enrichedPrompt = prompt;
105+
106+
if (useRag) {
107+
try {
108+
RestTemplate restTemplate = new RestTemplate();
109+
// JSON body for the RAG service
110+
org.json.JSONObject requestJson = new org.json.JSONObject();
111+
requestJson.put("prompt", prompt);
112+
requestJson.put("diagram", body);
113+
114+
HttpHeaders headers = new HttpHeaders();
115+
headers.setContentType(MediaType.APPLICATION_JSON);
116+
HttpEntity<String> entity = new HttpEntity<>(requestJson.toString(), headers);
117+
118+
// POST to the RAG service
119+
String ragServiceUrl = System.getProperty("rag.service.url", "http://localhost:5000");
120+
ResponseEntity<String> ragResponse =
121+
restTemplate.postForEntity(ragServiceUrl + "/rag/enrich", entity, String.class);
122+
123+
// Expected: {"enriched_prompt": "..."}
124+
org.json.JSONObject responseJson = new org.json.JSONObject(ragResponse.getBody());
125+
enrichedPrompt = responseJson.getString("enriched_prompt");
126+
logger.info("RAG service enriched prompt successfully. Original length: {}, Enriched length: {}",
127+
prompt.length(), enrichedPrompt.length());
128+
logger.debug("Enriched prompt: {}", enrichedPrompt);
129+
} catch (Exception e) {
130+
logger.error("Error calling RAG service, falling back to original prompt", e);
131+
}
132+
}
133+
134+
OpenAiApiDTO openAiApiDTO;
135+
if (provider.equalsIgnoreCase("lmStudio")) {
136+
137+
openAiApiDTO = new OpenAiApiDTO(null, gptModel, enrichedPrompt, provider, useRag);
138+
} else {
139+
openAiApiDTO = new OpenAiApiDTO(apiKey, gptModel, enrichedPrompt, provider, useRag);
140+
}
141+
142+
try {
143+
String response = llmService.callLLM(body, openAiApiDTO);
144+
logger.debug("LLM Response: " + response);
145+
return response;
146+
} catch (ResponseStatusException e) {
147+
logger.error("Error processing LLM request", e);
148+
throw e;
149+
}
99150
}
100151

101152
/**
src/main/java/de/dhbw/woped/process2text/model/process/OpenAiApiDTO.java (filename inferred from the package and class declarations in this hunk)

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,60 @@
11
package de.dhbw.woped.process2text.model.process;
22

3-
import lombok.AllArgsConstructor;
4-
import lombok.Data;
5-
import lombok.Getter;
6-
import lombok.Setter;
7-
83
/** Data Transfer Object to hold OpenAI API related information. */
9-
@Data
10-
@Setter
11-
@Getter
12-
@AllArgsConstructor
134
public class OpenAiApiDTO {
145

6+
public OpenAiApiDTO(
7+
String apiKey, String gptModel, String prompt, String provider, boolean useRAG) {
8+
this.apiKey = apiKey;
9+
this.gptModel = gptModel;
10+
this.prompt = prompt;
11+
this.provider = provider;
12+
this.useRAG = useRAG;
13+
}
14+
1515
private String apiKey;
1616
private String gptModel;
1717
private String prompt;
18+
private String provider;
19+
private boolean useRAG;
20+
21+
public String getApiKey() {
22+
return apiKey;
23+
}
24+
25+
public String getGptModel() {
26+
return gptModel;
27+
}
28+
29+
public String getPrompt() {
30+
return prompt;
31+
}
32+
33+
public void setGptModel(String gptModel) {
34+
this.gptModel = gptModel;
35+
}
36+
37+
public void setApiKey(String apiKey) {
38+
this.apiKey = apiKey;
39+
}
40+
41+
public void setPrompt(String prompt) {
42+
this.prompt = prompt;
43+
}
44+
45+
public String getProvider() {
46+
return provider;
47+
}
48+
49+
public void setProvider(String provider) {
50+
this.provider = provider;
51+
}
52+
53+
public boolean isUseRAG() {
54+
return useRAG;
55+
}
56+
57+
public void setUseRAG(boolean useRAG) {
58+
this.useRAG = useRAG;
59+
}
1860
}
src/test/java/de/dhbw/woped/process2text/Process2textApplicationTests.java (filename inferred from the package and class declarations in this hunk)

Lines changed: 137 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,144 @@
11
package de.dhbw.woped.process2text;
22

3+
import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
4+
import static com.github.tomakehurst.wiremock.client.WireMock.containing;
5+
import static com.github.tomakehurst.wiremock.client.WireMock.equalTo;
6+
import static com.github.tomakehurst.wiremock.client.WireMock.post;
7+
import static com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor;
8+
import static com.github.tomakehurst.wiremock.client.WireMock.stubFor;
9+
import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo;
10+
import static com.github.tomakehurst.wiremock.client.WireMock.verify;
11+
312
import org.junit.jupiter.api.Test;
13+
import org.springframework.beans.factory.annotation.Autowired;
414
import org.springframework.boot.test.context.SpringBootTest;
15+
import org.springframework.boot.test.web.client.TestRestTemplate;
16+
import org.springframework.boot.test.web.server.LocalServerPort;
17+
import org.springframework.http.HttpEntity;
18+
import org.springframework.http.HttpHeaders;
19+
import org.springframework.http.MediaType;
20+
import org.wiremock.spring.ConfigureWireMock;
21+
import org.wiremock.spring.EnableWireMock;
22+
23+
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
24+
@EnableWireMock({
25+
@ConfigureWireMock(name = "rag-service", port = Process2textApplicationTests.RAG_SERVICE_PORT)
26+
})
27+
class Process2textApplicationTests {
28+
29+
public static final int RAG_SERVICE_PORT = 5000;
30+
private static final int HTTP_OK = 200;
531

6-
@SpringBootTest
7-
class Process2TextApplicationTests {
32+
@LocalServerPort
33+
private int port;
834

35+
@Autowired
36+
private TestRestTemplate restTemplate;
37+
38+
/**
39+
* Test: RAG service is called when useRAG=true
40+
*/
941
@Test
10-
void contextLoads() {}
11-
}
42+
void testRAGServiceCalledWhenEnabled() {
43+
// Mock RAG Service on port 5000
44+
stubFor(
45+
post(urlPathEqualTo("/rag/enrich"))
46+
.withHeader("Content-Type", equalTo("application/json"))
47+
.willReturn(
48+
aResponse()
49+
.withStatus(HTTP_OK)
50+
.withHeader("Content-Type", "application/json")
51+
.withBody("{\"enriched_prompt\":\"Enhanced prompt with RAG context\"}")));
52+
53+
// BPMN test content
54+
String bpmnBody = "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n"
55+
+ "<bpmn2:definitions xmlns:bpmn2=\"http://www.omg.org/spec/BPMN/20100524/MODEL\">\n"
56+
+ " <bpmn2:process id=\"test-process\" name=\"Test Process\" />\n"
57+
+ "</bpmn2:definitions>";
58+
59+
// Set system property for RAG service URL
60+
System.setProperty("rag.service.url", "http://localhost:" + RAG_SERVICE_PORT);
61+
62+
// Set HTTP headers
63+
HttpHeaders headers = new HttpHeaders();
64+
headers.setContentType(MediaType.TEXT_PLAIN);
65+
66+
HttpEntity<String> request = new HttpEntity<>(bpmnBody, headers);
67+
68+
// Call the real generateTextLLM method via HTTP with RAG enabled
69+
try {
70+
String url = "http://localhost:" + port + "/p2t/generateTextLLM"
71+
+ "?apiKey=sk-testapikey"
72+
+ "&prompt=Analyze this BPMN process"
73+
+ "&gptModel=gpt-3.5-turbo"
74+
+ "&provider=openAi"
75+
+ "&useRag=true";
76+
77+
String result = restTemplate.postForObject(url, request, String.class);
78+
79+
// Check that a response is returned
80+
System.out.println("Response: " + result);
81+
} catch (Exception e) {
82+
// Ignore OpenAI errors, important is only that RAG was called
83+
System.out.println("Expected error (OpenAI mock): " + e.getMessage());
84+
}
85+
86+
// Verify: RAG Service was called
87+
verify(postRequestedFor(urlPathEqualTo("/rag/enrich"))
88+
.withHeader("Content-Type", equalTo("application/json")));
89+
90+
// Check that the RAG request contains the correct prompt and diagram
91+
verify(postRequestedFor(urlPathEqualTo("/rag/enrich"))
92+
.withRequestBody(containing("Analyze this BPMN process"))
93+
.withRequestBody(containing("test-process")));
94+
}
95+
96+
/**
97+
* Test: RAG service is NOT called when useRAG=false
98+
*/
99+
@Test
100+
void testRAGServiceNotCalledWhenDisabled() {
101+
// Mock RAG Service (but it won't be called)
102+
stubFor(
103+
post(urlPathEqualTo("/rag/enrich"))
104+
.withHeader("Content-Type", equalTo("application/json"))
105+
.willReturn(
106+
aResponse()
107+
.withStatus(HTTP_OK)
108+
.withHeader("Content-Type", "application/json")
109+
.withBody("{\"enriched_prompt\":\"This should not be called\"}")));
110+
111+
// BPMN test content
112+
String bpmnBody = "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n"
113+
+ "<bpmn2:definitions xmlns:bpmn2=\"http://www.omg.org/spec/BPMN/20100524/MODEL\">\n"
114+
+ " <bpmn2:process id=\"test-process\" name=\"Test Process\" />\n"
115+
+ "</bpmn2:definitions>";
116+
117+
// Set HTTP headers
118+
HttpHeaders headers = new HttpHeaders();
119+
headers.setContentType(MediaType.TEXT_PLAIN);
120+
121+
HttpEntity<String> request = new HttpEntity<>(bpmnBody, headers);
122+
123+
// Call the real generateTextLLM method via HTTP with RAG disabled
124+
try {
125+
String url = "http://localhost:" + port + "/p2t/generateTextLLM"
126+
+ "?apiKey=sk-testapikey"
127+
+ "&prompt=Analyze this BPMN process"
128+
+ "&gptModel=gpt-3.5-turbo"
129+
+ "&provider=openAi"
130+
+ "&useRag=false";
131+
132+
String result = restTemplate.postForObject(url, request, String.class);
133+
134+
// Check that a response is returned
135+
System.out.println("Response: " + result);
136+
} catch (Exception e) {
137+
// Ignore OpenAI errors, important is only that RAG was NOT called
138+
System.out.println("Expected error (OpenAI mock): " + e.getMessage());
139+
}
140+
141+
// Verify: RAG Service was NOT called
142+
verify(0, postRequestedFor(urlPathEqualTo("/rag/enrich")));
143+
}
144+
}

0 commit comments

Comments (0)