Skip to content

Commit c583c12

Browse files
authored
Merge pull request #24 from devoxx/issue-23
Fix #23: Refactoring getModelNames() + Added new model names for Groq…
2 parents 0b71284 + f3d5426 commit c583c12

File tree

12 files changed

+172
-69
lines changed

12 files changed

+172
-69
lines changed

build.gradle.kts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ plugins {
55
}
66

77
group = "com.devoxx.genie"
8-
version = "0.0.9"
8+
version = "0.0.10"
99

1010
repositories {
1111
mavenCentral()

src/main/java/com/devoxx/genie/chatmodel/ChatModelFactory.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import com.devoxx.genie.ui.SettingsState;
66
import dev.langchain4j.model.chat.ChatLanguageModel;
77

8+
import java.util.List;
9+
810
public interface ChatModelFactory {
911

1012
/**
@@ -14,6 +16,12 @@ public interface ChatModelFactory {
1416
*/
1517
ChatLanguageModel createChatModel(ChatModel chatModel);
1618

19+
/**
20+
* List the available model names.
21+
* @return the list of model names
22+
*/
23+
List<String> getModelNames();
24+
1725
/**
1826
* Get the base URL by the model type.
1927
* @param modelProvider the language model provider

src/main/java/com/devoxx/genie/chatmodel/anthropic/AnthropicChatModelFactory.java

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,18 @@
55
import dev.langchain4j.model.anthropic.AnthropicChatModel;
66
import dev.langchain4j.model.chat.ChatLanguageModel;
77

8+
import java.util.List;
9+
10+
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.*;
11+
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_INSTANT_1_2;
12+
813
public class AnthropicChatModelFactory implements ChatModelFactory {
914

10-
private final String apiKey;
11-
private final String modelName;
15+
private String apiKey;
16+
private String modelName;
17+
18+
public AnthropicChatModelFactory() {
19+
}
1220

1321
public AnthropicChatModelFactory(String apiKey, String modelName) {
1422
this.apiKey = apiKey;
@@ -27,4 +35,16 @@ public ChatLanguageModel createChatModel(ChatModel chatModel) {
2735
.build();
2836
}
2937

38+
@Override
39+
public List<String> getModelNames() {
40+
return List.of(
41+
CLAUDE_3_OPUS_20240229.toString(),
42+
CLAUDE_3_SONNET_20240229.toString(),
43+
CLAUDE_3_HAIKU_20240307.toString(),
44+
CLAUDE_2_1.toString(),
45+
CLAUDE_2.toString(),
46+
CLAUDE_INSTANT_1_2.toString()
47+
);
48+
}
49+
3050
}

src/main/java/com/devoxx/genie/chatmodel/deepinfra/DeepInfraChatModelFactory.java

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,24 @@
22

33
import com.devoxx.genie.chatmodel.ChatModelFactory;
44
import com.devoxx.genie.model.ChatModel;
5+
import com.devoxx.genie.model.ollama.OllamaModelEntryDTO;
6+
import com.devoxx.genie.service.OllamaService;
7+
import com.devoxx.genie.ui.util.NotificationUtil;
8+
import com.intellij.openapi.project.ProjectManager;
59
import dev.langchain4j.model.chat.ChatLanguageModel;
610
import dev.langchain4j.model.openai.OpenAiChatModel;
711

12+
import java.io.IOException;
813
import java.time.Duration;
14+
import java.util.ArrayList;
15+
import java.util.List;
916

1017
public class DeepInfraChatModelFactory implements ChatModelFactory {
1118

12-
private final String apiKey;
13-
private final String modelName;
19+
private String apiKey;
20+
private String modelName;
21+
22+
public DeepInfraChatModelFactory() {}
1423

1524
public DeepInfraChatModelFactory(String apiKey, String modelName) {
1625
this.apiKey = apiKey;
@@ -30,4 +39,21 @@ public ChatLanguageModel createChatModel(ChatModel chatModel) {
3039
.topP(chatModel.topP)
3140
.build();
3241
}
42+
43+
@Override
44+
public List<String> getModelNames() {
45+
return List.of(
46+
"meta-llama/Meta-Llama-3-70B-Instruct",
47+
"meta-llama/Meta-Llama-3-8B-Instruct",
48+
"mistralai/Mixtral-8x7B-Instruct-v0.1",
49+
"mistralai/Mixtral-8x22B-Instruct-v0.1",
50+
"microsoft/WizardLM-2-8x22B",
51+
"microsoft/WizardLM-2-7B",
52+
"databricks/dbrx-instruct",
53+
"openchat/openchat_3.5",
54+
"google/gemma-7b-it",
55+
"Phind/Phind-CodeLlama-34B-v2",
56+
"bigcode/starcoder2-15b"
57+
);
58+
}
3359
}

src/main/java/com/devoxx/genie/chatmodel/gpt4all/GPT4AllChatModelFactory.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import dev.langchain4j.model.localai.LocalAiChatModel;
99

1010
import java.time.Duration;
11+
import java.util.List;
1112

1213
public class GPT4AllChatModelFactory implements ChatModelFactory {
1314

@@ -23,4 +24,9 @@ public ChatLanguageModel createChatModel(ChatModel chatModel) {
2324
.topP(chatModel.topP)
2425
.build();
2526
}
27+
28+
@Override
29+
public List<String> getModelNames() {
30+
return List.of();
31+
}
2632
}

src/main/java/com/devoxx/genie/chatmodel/groq/GroqChatModelFactory.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,14 @@
66
import dev.langchain4j.model.openai.OpenAiChatModel;
77

88
import java.time.Duration;
9+
import java.util.List;
910

1011
public class GroqChatModelFactory implements ChatModelFactory {
1112

12-
private final String apiKey;
13-
private final String modelName;
13+
private String apiKey;
14+
private String modelName;
15+
16+
public GroqChatModelFactory() {}
1417

1518
public GroqChatModelFactory(String apiKey, String modelName) {
1619
this.apiKey = apiKey;
@@ -30,4 +33,9 @@ public ChatLanguageModel createChatModel(ChatModel chatModel) {
3033
.topP(chatModel.topP)
3134
.build();
3235
}
36+
37+
@Override
38+
public List<String> getModelNames() {
39+
return List.of("gemma-7b-it", "llama3-8b-8192", "llama3-70b-8192", "llama2-70b-4096", "mixtral-8x7b-32768");
40+
}
3341
}

src/main/java/com/devoxx/genie/chatmodel/lmstudio/LMStudioChatModelFactory.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import dev.langchain4j.model.localai.LocalAiChatModel;
88

99
import java.time.Duration;
10+
import java.util.List;
1011

1112
public class LMStudioChatModelFactory implements ChatModelFactory {
1213

@@ -22,4 +23,9 @@ public ChatLanguageModel createChatModel(ChatModel chatModel) {
2223
.timeout(Duration.ofSeconds(chatModel.timeout))
2324
.build();
2425
}
26+
27+
@Override
28+
public List<String> getModelNames() {
29+
return List.of("");
30+
}
2531
}

src/main/java/com/devoxx/genie/chatmodel/mistral/MistralChatModelFactory.java

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,17 @@
66
import dev.langchain4j.model.mistralai.MistralAiChatModel;
77

88
import java.time.Duration;
9+
import java.util.List;
10+
11+
import static dev.langchain4j.model.mistralai.MistralAiChatModelName.*;
912

1013
public class MistralChatModelFactory implements ChatModelFactory {
1114

12-
private final String apiKey;
13-
private final String modelName;
15+
private String apiKey;
16+
private String modelName;
17+
18+
public MistralChatModelFactory() {
19+
}
1420

1521
public MistralChatModelFactory(String apiKey, String modelName) {
1622
this.apiKey = apiKey;
@@ -29,4 +35,14 @@ public ChatLanguageModel createChatModel(ChatModel chatModel) {
2935
.topP(chatModel.topP)
3036
.build();
3137
}
38+
39+
@Override
40+
public List<String> getModelNames() {
41+
return List.of(
42+
OPEN_MISTRAL_7B.toString(),
43+
OPEN_MIXTRAL_8x7B.toString(),
44+
MISTRAL_SMALL_LATEST.toString(),
45+
MISTRAL_MEDIUM_LATEST.toString()
46+
);
47+
}
3248
}

src/main/java/com/devoxx/genie/chatmodel/ollama/OllamaChatModelFactory.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,17 @@
33
import com.devoxx.genie.chatmodel.ChatModelFactory;
44
import com.devoxx.genie.model.ChatModel;
55
import com.devoxx.genie.model.enumarations.ModelProvider;
6+
import com.devoxx.genie.model.ollama.OllamaModelEntryDTO;
7+
import com.devoxx.genie.service.OllamaService;
8+
import com.devoxx.genie.ui.util.NotificationUtil;
9+
import com.intellij.openapi.project.ProjectManager;
610
import dev.langchain4j.model.chat.ChatLanguageModel;
711
import dev.langchain4j.model.ollama.OllamaChatModel;
812

13+
import java.io.IOException;
914
import java.time.Duration;
15+
import java.util.ArrayList;
16+
import java.util.List;
1017

1118
public class OllamaChatModelFactory implements ChatModelFactory {
1219

@@ -21,4 +28,19 @@ public ChatLanguageModel createChatModel(ChatModel chatModel) {
2128
.timeout(Duration.ofSeconds(chatModel.timeout))
2229
.build();
2330
}
31+
32+
@Override
33+
public List<String> getModelNames() {
34+
List<String> modelNames = new ArrayList<>();
35+
try {
36+
OllamaModelEntryDTO[] ollamaModels = new OllamaService().getModels();
37+
for (OllamaModelEntryDTO model : ollamaModels) {
38+
modelNames.add(model.getName());
39+
}
40+
} catch (IOException e) {
41+
NotificationUtil.sendNotification(ProjectManager.getInstance().getDefaultProject(),
42+
"Ollama is not running, please start it.");
43+
}
44+
return modelNames;
45+
}
2446
}

src/main/java/com/devoxx/genie/chatmodel/openai/OpenAIChatModelFactory.java

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,17 @@
44
import com.devoxx.genie.model.ChatModel;
55
import dev.langchain4j.model.chat.ChatLanguageModel;
66
import dev.langchain4j.model.openai.OpenAiChatModel;
7+
import dev.langchain4j.model.openai.OpenAiChatModelName;
78

89
import java.time.Duration;
10+
import java.util.List;
911

1012
public class OpenAIChatModelFactory implements ChatModelFactory {
1113

12-
private final String apiKey;
13-
private final String modelName;
14+
private String apiKey;
15+
private String modelName;
16+
17+
public OpenAIChatModelFactory() {}
1418

1519
public OpenAIChatModelFactory(String apiKey, String modelName) {
1620
this.apiKey = apiKey;
@@ -29,4 +33,14 @@ public ChatLanguageModel createChatModel(ChatModel chatModel) {
2933
.topP(chatModel.topP)
3034
.build();
3135
}
36+
37+
@Override
38+
public List<String> getModelNames() {
39+
return List.of(
40+
OpenAiChatModelName.GPT_4.toString(),
41+
OpenAiChatModelName.GPT_4_32K.toString(),
42+
OpenAiChatModelName.GPT_4_TURBO_PREVIEW.toString(),
43+
OpenAiChatModelName.GPT_3_5_TURBO.toString(),
44+
OpenAiChatModelName.GPT_3_5_TURBO_16K.toString());
45+
}
3246
}

0 commit comments

Comments
 (0)