diff --git a/lib/consts.dart b/lib/consts.dart index 628598e47..ae60ba4b2 100644 --- a/lib/consts.dart +++ b/lib/consts.dart @@ -149,6 +149,11 @@ enum ImportFormat { const ImportFormat(this.label); final String label; } +enum LLMProvider { + ollama, + gemini, + openai +} const String kGlobalEnvironmentId = "global"; diff --git a/lib/dashbot/providers/dashbot_providers.dart b/lib/dashbot/providers/dashbot_providers.dart index 2015d6a82..fe5654338 100644 --- a/lib/dashbot/providers/dashbot_providers.dart +++ b/lib/dashbot/providers/dashbot_providers.dart @@ -1,16 +1,18 @@ import 'dart:convert'; import 'package:flutter_riverpod/flutter_riverpod.dart'; import 'package:shared_preferences/shared_preferences.dart'; +import '../../consts.dart'; import '../services/dashbot_service.dart'; +// Chat Messages Provider final chatMessagesProvider = - StateNotifierProvider>>( - (ref) => ChatMessagesNotifier(), -); +StateNotifierProvider>>( + (ref) => ChatMessagesNotifier()); + -final dashBotServiceProvider = Provider((ref) { - return DashBotService(); -}); +final selectedLLMProvider = +StateNotifierProvider( + (ref) => SelectedLLMNotifier()); class ChatMessagesNotifier extends StateNotifier>> { ChatMessagesNotifier() : super([]) { @@ -20,10 +22,14 @@ class ChatMessagesNotifier extends StateNotifier>> { static const _storageKey = 'chatMessages'; Future _loadMessages() async { - final prefs = await SharedPreferences.getInstance(); - final messages = prefs.getString(_storageKey); - if (messages != null) { - state = List>.from(json.decode(messages)); + try { + final prefs = await SharedPreferences.getInstance(); + final messages = prefs.getString(_storageKey); + if (messages != null) { + state = List>.from(json.decode(messages)); + } + } catch (e) { + print("Error loading messages: $e"); } } @@ -42,3 +48,52 @@ class ChatMessagesNotifier extends StateNotifier>> { _saveMessages(); } } + +class SelectedLLMNotifier extends StateNotifier { + SelectedLLMNotifier() : super(LLMProvider.ollama) { + _loadSelectedLLM(); + } + + static const _storageKey = 'selectedLLM'; + + Future _loadSelectedLLM() async { + final prefs = await SharedPreferences.getInstance(); + final savedValue = prefs.getString(_storageKey); + if (savedValue != null) { + state = LLMProvider.values.firstWhere( + (e) => e.toString() == savedValue, + orElse: () => LLMProvider.ollama, + ); + } + } + + Future setSelectedLLM(LLMProvider provider) async { + state = provider; + final prefs = await SharedPreferences.getInstance(); + await prefs.setString(_storageKey, provider.toString()); + } +} +final selectedLLMModelProvider = StateNotifierProvider( + (ref) => SelectedLLMModelNotifier(), +); + +class SelectedLLMModelNotifier extends StateNotifier { + SelectedLLMModelNotifier() : super("mistral") { + _loadSelectedLLMModel(); + } + + static const _storageKey = 'selectedLLMModel'; + + Future _loadSelectedLLMModel() async { + final prefs = await SharedPreferences.getInstance(); + state = prefs.getString(_storageKey) ?? "mistral"; + } + + Future setSelectedLLMModel(String model) async { + state = model; + final prefs = await SharedPreferences.getInstance(); + await prefs.setString(_storageKey, model); + } +} + + diff --git a/lib/dashbot/providers/llm_provider.dart b/lib/dashbot/providers/llm_provider.dart new file mode 100644 index 000000000..5a80dc006 --- /dev/null +++ b/lib/dashbot/providers/llm_provider.dart @@ -0,0 +1,22 @@ +class LLMConfig { + final String model; + String? apiUrl; + String? apiKey; + double? temperature; + + LLMConfig({ + required this.model, + this.apiUrl, + this.apiKey, + this.temperature, + }); + + LLMConfig copyWith({String? apiUrl, String? apiKey, String? model, double? temperature}) { + return LLMConfig( + model: model ?? this.model, + apiUrl: apiUrl ?? this.apiUrl, + apiKey: apiKey ?? this.apiKey, + temperature: temperature ?? this.temperature, + ); + } +} diff --git a/lib/dashbot/services/dashbot_service.dart b/lib/dashbot/services/dashbot_service.dart index 8eb0087c8..be1966ccf 100644 --- a/lib/dashbot/services/dashbot_service.dart +++ b/lib/dashbot/services/dashbot_service.dart @@ -1,26 +1,97 @@ import 'package:apidash/dashbot/features/debug.dart'; import 'package:ollama_dart/ollama_dart.dart'; +import 'package:openai_dart/openai_dart.dart'; +import 'package:flutter_gemini/flutter_gemini.dart'; +import '../../consts.dart'; import '../features/explain.dart'; import 'package:apidash/models/request_model.dart'; +import 'package:flutter_riverpod/flutter_riverpod.dart'; + +import '../providers/dashbot_providers.dart'; +import '../providers/llm_provider.dart'; + +final llmConfigProvider = StateNotifierProvider>((ref) { + return LLMConfigNotifier(); +}); + +class LLMConfigNotifier extends StateNotifier> { + LLMConfigNotifier() + : super({ + LLMProvider.ollama: LLMConfig(apiUrl: "http://127.0.0.1:11434/api", model: "mistral"), + LLMProvider.gemini: LLMConfig(apiKey: "gemini_api_key", model: "gemini-1.5"), + LLMProvider.openai: LLMConfig(apiKey: "openAI_api_key", model: "gpt-4-turbo"), + }); + + void updateConfig(LLMProvider provider, LLMConfig newConfig) { + state = {...state, provider: newConfig}; + } +} + +final dashBotServiceProvider = Provider((ref) => DashBotService(ref, LLMProvider.ollama)); class DashBotService { - final OllamaClient _client; - late final ExplainFeature _explainFeature; - late final DebugFeature _debugFeature; + late OllamaClient _ollamaClient; + late OpenAIClient _openAiClient; + late Gemini _geminiClient; + late ExplainFeature _explainFeature; + late DebugFeature _debugFeature; + final Ref _ref; + - DashBotService() - : _client = OllamaClient(baseUrl: 'http://127.0.0.1:11434/api') { + DashBotService(this._ref, LLMProvider selectedModel) { + _initializeClients(); _explainFeature = ExplainFeature(this); _debugFeature = DebugFeature(this); } + void _initializeClients() { + final config = _ref.read(llmConfigProvider); + + _ollamaClient = OllamaClient(baseUrl: config[LLMProvider.ollama]!.apiUrl,); + _openAiClient = OpenAIClient(apiKey: config[LLMProvider.openai]!.apiKey ?? "",); + _geminiClient = Gemini.init(apiKey: config[LLMProvider.gemini]!.apiKey ?? "", ); + } + Future generateResponse(String prompt) async { - final response = await _client.generateCompletion( - request: GenerateCompletionRequest(model: 'llama3.2:3b', prompt: prompt), - ); - return response.response.toString(); + try { + final selectedProvider = _ref.read(selectedLLMProvider); + final config = _ref.read(llmConfigProvider)[selectedProvider]!; + + switch (selectedProvider) { + case LLMProvider.gemini: + final response = await Gemini.instance.chat( + modelName: config.model, + [ + Content(parts: [Part.text(prompt)], role: 'user',), + ]); + return response?.output ?? "Error: No response from Gemini."; + + case LLMProvider.openai: + final response = await _openAiClient.createChatCompletion( + request: CreateChatCompletionRequest( + model: ChatCompletionModel.modelId(config.model), + messages: [ + ChatCompletionMessage.user( + content: ChatCompletionUserMessageContent.string(prompt), + ), + ], + temperature: config.temperature ?? 0.7, + ), + ); + return response.choices.first.message.content ?? "Error: No response from OpenAI."; + + case LLMProvider.ollama: + final response = await _ollamaClient.generateCompletion( + request: GenerateCompletionRequest(model: config.model, prompt: prompt), + ); + return response.response.toString(); + } + } catch (e) { + return "Error: ${e.toString()}"; + } } + Future handleRequest( String input, RequestModel? requestModel, dynamic responseModel) async { if (input == "Explain API") { diff --git a/lib/dashbot/widgets/dashbot_widget.dart b/lib/dashbot/widgets/dashbot_widget.dart index 200d4c5fa..24b4a7ff5 100644 --- a/lib/dashbot/widgets/dashbot_widget.dart +++ b/lib/dashbot/widgets/dashbot_widget.dart @@ -1,8 +1,10 @@ -// lib/dashbot/widgets/dashbot_widget.dart +import 'package:apidash_core/apidash_core.dart'; import 'package:flutter/material.dart'; import 'package:flutter_riverpod/flutter_riverpod.dart'; import 'package:apidash/dashbot/providers/dashbot_providers.dart'; import 'package:apidash/providers/providers.dart'; +import '../../consts.dart'; +import '../services/dashbot_service.dart'; import 'chat_bubble.dart'; class DashBotWidget extends ConsumerStatefulWidget { @@ -94,6 +96,8 @@ class _DashBotWidgetState extends ConsumerState { children: [ _buildHeader(context), const SizedBox(height: 12), + _buildModelSelector(), + const SizedBox(height: 12), _buildQuickActions(showDebugButton), const SizedBox(height: 12), Expanded(child: _buildChatArea(messages)), @@ -104,7 +108,62 @@ class _DashBotWidgetState extends ConsumerState { ), ); } + Widget _buildModelSelector() { + return Row( + children: [ + DropdownButton( + value: ref.watch(selectedLLMProvider), + onChanged: (LLMProvider? newProvider) { + if (newProvider != null) { + ref.read(selectedLLMProvider.notifier).setSelectedLLM(newProvider); + } + }, + items: LLMProvider.values.map((provider) { + return DropdownMenuItem( + value: provider, + child: Text(provider.toString().split('.').last), + ); + }).toList(), + ), + SizedBox(width: 20,), + Consumer(builder: (context, ref, _) { + final selectedProvider = ref.watch(selectedLLMProvider); + final config = ref.watch(llmConfigProvider)[selectedProvider]!; + List models = []; + switch (selectedProvider) { + case LLMProvider.gemini: + models = ["gemini-1.0", "gemini-1.5", "gemini-pro", "gemini-ultra"]; + break; + case LLMProvider.openai: + models = ["gpt-3.5-turbo", "gpt-4-turbo", "gpt-4"]; + break; + case LLMProvider.ollama: + models = ["mistral", "llama2", "codellama", "gemma"]; + break; + } + + return DropdownButton( + value: config.model, + onChanged: (String? newModel) { + if (newModel != null) { + ref.read(llmConfigProvider.notifier).updateConfig( + selectedProvider, + config.copyWith(model: newModel), + ); + } + }, + items: models.map((model) { + return DropdownMenuItem( + value: model, + child: Text(model), + ); + }).toList(), + ); + }), + ], + ); + } Widget _buildHeader(BuildContext context) { return Row( mainAxisAlignment: MainAxisAlignment.spaceBetween, diff --git a/lib/main.dart b/lib/main.dart index d8a4dd7f3..2a8621b65 100644 --- a/lib/main.dart +++ b/lib/main.dart @@ -1,5 +1,6 @@ import 'package:apidash_design_system/apidash_design_system.dart'; import 'package:flutter/material.dart'; +import 'package:flutter_gemini/flutter_gemini.dart'; import 'package:flutter_riverpod/flutter_riverpod.dart'; import 'models/models.dart'; import 'providers/providers.dart'; diff --git a/pubspec.lock b/pubspec.lock index 9881558db..0d0c8a662 100644 --- a/pubspec.lock +++ b/pubspec.lock @@ -358,6 +358,22 @@ packages: url: "https://pub.dev" source: hosted version: "0.5.0" + dio: + dependency: transitive + description: + name: dio + sha256: "253a18bbd4851fecba42f7343a1df3a9a4c1d31a2c1b37e221086b4fa8c8dbc9" + url: "https://pub.dev" + source: hosted + version: "5.8.0+1" + dio_web_adapter: + dependency: transitive + description: + name: dio_web_adapter + sha256: "7586e476d70caecaf1686d21eee7247ea43ef5c345eab9e0cc3583ff13378d78" + url: "https://pub.dev" + source: hosted + version: "2.1.1" equatable: dependency: transitive description: @@ -528,6 +544,14 @@ packages: description: flutter source: sdk version: "0.0.0" + flutter_gemini: + dependency: "direct main" + description: + name: flutter_gemini + sha256: b7264b1d19acc4b1a5628a0e26c0976aa1fb948f0d3243bc3510ff51e09476b7 + url: "https://pub.dev" + source: hosted + version: "3.0.0" flutter_highlighter: dependency: "direct main" description: @@ -1048,6 +1072,14 @@ packages: url: "https://pub.dev" source: hosted version: "0.2.2+1" + openai_dart: + dependency: "direct main" + description: + name: openai_dart + sha256: "1cc5ed0915fa7572b943de01cfa7a3e5cfe1e6a7f4d0d9a9374d046518e84575" + url: "https://pub.dev" + source: hosted + version: "0.4.5" package_config: dependency: transitive description: diff --git a/pubspec.yaml b/pubspec.yaml index ebc3ee098..7d7073cd0 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -23,6 +23,7 @@ dependencies: extended_text_field: ^16.0.0 file_selector: ^1.0.3 flex_color_scheme: ^8.1.1 + flutter_gemini: ^3.0.0 flutter_highlighter: ^0.1.0 flutter_hooks: ^0.21.2 flutter_markdown: ^0.7.6+2 @@ -50,6 +51,7 @@ dependencies: multi_trigger_autocomplete_plus: path: packages/multi_trigger_autocomplete_plus ollama_dart: ^0.2.2 + openai_dart: ^0.4.5 package_info_plus: ^8.3.0 path: ^1.8.3 path_provider: ^2.1.2