Skip to content

Commit 60bf9bc

Browse files
committed
Merge remote-tracking branch 'upstream/dev' into refactor-openai
2 parents 65384d7 + 8c2876f commit 60bf9bc

File tree

4 files changed

+56
-7
lines changed

4 files changed

+56
-7
lines changed

src/main/java/org/devlive/sdk/common/DefaultClient.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ public AssistantsEntity createAssistants(AssistantsEntity configure)
266266
}
267267

268268
public AssistantsFileEntity createAssistantsFile(String fileId,
269-
String assistantId)
269+
String assistantId)
270270
{
271271
String url = String.format(ProviderUtils.getUrl(provider, UrlModel.FETCH_ASSISTANTS_FILES), assistantId);
272272
Map<String, String> configure = Maps.newHashMap();

src/main/java/org/devlive/sdk/openai/OpenAiClient.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,12 @@ public OpenAiClientBuilder model(CompletionModel model)
185185
return this;
186186
}
187187

188+
public OpenAiClientBuilder model(String model)
189+
{
190+
this.model = model;
191+
return this;
192+
}
193+
188194
private String getDefaultHost()
189195
{
190196
if (ObjectUtils.isEmpty(this.provider)) {

src/main/java/org/devlive/sdk/openai/entity/ChatEntity.java

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
import org.devlive.sdk.openai.utils.EnumsUtils;
1414

1515
import java.util.List;
16+
import java.util.Objects;
17+
18+
import static org.devlive.sdk.openai.model.CompletionModel.GPT_35_TURBO;
1619

1720
@Data
1821
@Builder
@@ -47,7 +50,7 @@ public class ChatEntity
4750
private ChatEntity(ChatEntityBuilder builder)
4851
{
4952
if (ObjectUtils.isEmpty(builder.model)) {
50-
builder.model(CompletionModel.GPT_35_TURBO);
53+
builder.model(GPT_35_TURBO);
5154
}
5255
this.model = builder.model;
5356
this.messages = builder.messages;
@@ -73,7 +76,7 @@ public static class ChatEntityBuilder
7376
public ChatEntityBuilder model(CompletionModel model)
7477
{
7578
if (ObjectUtils.isEmpty(model)) {
76-
model = CompletionModel.GPT_35_TURBO;
79+
model = GPT_35_TURBO;
7780
}
7881
switch (model) {
7982
case GPT_35_TURBO:
@@ -96,6 +99,12 @@ public ChatEntityBuilder model(CompletionModel model)
9699
return this;
97100
}
98101

102+
public ChatEntityBuilder model(String model)
103+
{
104+
this.model = model;
105+
return this;
106+
}
107+
99108
public ChatEntityBuilder temperature(Double temperature)
100109
{
101110
if (temperature < 0 || temperature > 2) {
@@ -108,11 +117,20 @@ public ChatEntityBuilder temperature(Double temperature)
108117
public ChatEntityBuilder maxTokens(Integer maxTokens)
109118
{
110119
CompletionModel completionModel = EnumsUtils.getCompleteModel(this.model);
111-
if (ObjectUtils.isNotEmpty(this.model) && maxTokens > completionModel.getMaxTokens()) {
112-
throw new ParamException(String.format("Invalid maxTokens: %s, Cannot be larger than the model default configuration %s", maxTokens, completionModel.getMaxTokens()));
120+
if (Objects.isNull(completionModel)) {
121+
this.maxTokens = maxTokens;
122+
return this;
123+
}
124+
else {
125+
if (ObjectUtils.isNotEmpty(this.model)
126+
&& maxTokens > completionModel.getMaxTokens()) {
127+
throw new ParamException(String.format(
128+
"Invalid maxTokens: %s, Cannot be larger than the model default configuration %s",
129+
maxTokens, completionModel.getMaxTokens()));
130+
}
131+
this.maxTokens = maxTokens;
132+
return this;
113133
}
114-
this.maxTokens = maxTokens;
115-
return this;
116134
}
117135

118136
private ChatEntityBuilder stream()

src/test/java/org/devlive/sdk/openai/OpenAiClientTest.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.devlive.sdk.common.exception.RequestException;
1616
import org.devlive.sdk.openai.model.CompletionModel;
1717
import org.devlive.sdk.openai.model.EditModel;
18+
import org.devlive.sdk.openai.response.ChatResponse;
1819
import org.junit.Assert;
1920
import org.junit.Before;
2021
import org.junit.Test;
@@ -94,6 +95,30 @@ public void testCreateCompletion()
9495
Assert.assertTrue(client.createCompletion(configure).getChoices().size() > 0);
9596
}
9697

98+
@Test
99+
public void testCustomizedModel()
100+
{
101+
client = OpenAiClient.builder()
102+
.apiHost(System.getProperty("proxy.host"))
103+
.apiKey(System.getProperty("openai.token"))
104+
.model("text-davinci-003")
105+
.build();
106+
107+
List<MessageEntity> messages = Lists.newArrayList();
108+
messages.add(MessageEntity.builder()
109+
.content("Hello, please show me a jok!")
110+
.build());
111+
112+
ChatEntity configure = ChatEntity.builder()
113+
.messages(messages)
114+
.model("text-davinci-003")
115+
.build();
116+
ChatResponse chatCompletion = client.createChatCompletion(configure);
117+
String content = chatCompletion.getChoices().get(0).getMessage().getContent();
118+
// System.out.println(content);
119+
Assert.assertNotNull(content);
120+
}
121+
97122
@Test
98123
public void testCreateChatCompletion()
99124
{

0 commit comments

Comments
 (0)