Skip to content

Commit 3d4a88e

Browse files
authored
[OPIK-3263] [FE] Add model selector for OpikAI features (Prompt Generator, Prompt Improver) (#4572)
* add a new hook for model selection and reuse it in dataset expansion and prompt improvement * label style * review fix * review fixes - Sasha * review fix
1 parent 28b0070 commit 3d4a88e

File tree

3 files changed

+194
-66
lines changed

3 files changed

+194
-66
lines changed

apps/opik-frontend/src/components/pages/DatasetItemsPage/DatasetExpansionDialog.tsx

Lines changed: 6 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,7 @@ import TooltipWrapper from "@/components/shared/TooltipWrapper/TooltipWrapper";
3131
import useDatasetExpansionMutation from "@/api/datasets/useDatasetExpansionMutation";
3232
import useDatasetItemsList from "@/api/datasets/useDatasetItemsList";
3333
import useAppStore from "@/store/AppStore";
34-
import useLastPickedModel from "@/hooks/useLastPickedModel";
35-
import useLLMProviderModelsData from "@/hooks/useLLMProviderModelsData";
36-
import useProviderKeys from "@/api/provider-keys/useProviderKeys";
34+
import useModelSelection from "@/hooks/useModelSelection";
3735
import useProgressSimulation from "@/hooks/useProgressSimulation";
3836

3937
const DATASET_EXPANSION_PROGRESS_MESSAGES = [
@@ -43,7 +41,6 @@ const DATASET_EXPANSION_PROGRESS_MESSAGES = [
4341
"Finalizing generated data...",
4442
];
4543
import { DatasetExpansionRequest, DatasetItem } from "@/types/datasets";
46-
import { COMPOSED_PROVIDER_TYPE, PROVIDER_MODEL_TYPE } from "@/types/providers";
4744

4845
const DATASET_EXPANSION_LAST_PICKED_MODEL = "opik-dataset-expansion-model";
4946
const SAMPLE_COUNT_MIN = 1;
@@ -65,57 +62,11 @@ const DatasetExpansionDialog: React.FunctionComponent<
6562
> = ({ datasetId: initialDatasetId, open, setOpen, onSamplesGenerated }) => {
6663
const workspaceName = useAppStore((state) => state.activeWorkspaceName);
6764

68-
const [lastPickedModel, setLastPickedModel] = useLastPickedModel({
69-
key: DATASET_EXPANSION_LAST_PICKED_MODEL,
65+
// Model selection with persistence using the reusable hook
66+
const { model, modelSelectProps } = useModelSelection({
67+
persistenceKey: DATASET_EXPANSION_LAST_PICKED_MODEL,
7068
});
7169

72-
const { data: providerKeysData } = useProviderKeys({
73-
workspaceName,
74-
});
75-
76-
const providerKeys = useMemo(() => {
77-
return providerKeysData?.content?.map((c) => c.ui_composed_provider) || [];
78-
}, [providerKeysData]);
79-
80-
const { calculateModelProvider, calculateDefaultModel } =
81-
useLLMProviderModelsData();
82-
83-
const { model, provider } = useMemo(() => {
84-
const calculatedModel = calculateDefaultModel(
85-
lastPickedModel,
86-
providerKeys,
87-
) as PROVIDER_MODEL_TYPE;
88-
const calculatedProvider = calculateModelProvider(calculatedModel);
89-
return {
90-
model: calculatedModel,
91-
provider: calculatedProvider,
92-
};
93-
}, [
94-
calculateDefaultModel,
95-
calculateModelProvider,
96-
lastPickedModel,
97-
providerKeys,
98-
]);
99-
100-
const handleAddProvider = useCallback(
101-
(provider: COMPOSED_PROVIDER_TYPE) => {
102-
if (!model) {
103-
setLastPickedModel(calculateDefaultModel(model, [provider], provider));
104-
}
105-
},
106-
[calculateDefaultModel, model, setLastPickedModel],
107-
);
108-
109-
const handleDeleteProvider = useCallback(
110-
(provider: COMPOSED_PROVIDER_TYPE) => {
111-
const currentProvider = calculateModelProvider(model);
112-
if (currentProvider === provider) {
113-
setLastPickedModel("");
114-
}
115-
},
116-
[calculateModelProvider, model, setLastPickedModel],
117-
);
118-
11970
const [sampleCount, setSampleCount] = useState<number>(5);
12071
const [variationInstructions, setVariationInstructions] =
12172
useState<string>("");
@@ -325,11 +276,6 @@ const DatasetExpansionDialog: React.FunctionComponent<
325276
complete,
326277
]);
327278

328-
const handleModelChange = useCallback(
329-
(model: PROVIDER_MODEL_TYPE) => setLastPickedModel(model),
330-
[setLastPickedModel],
331-
);
332-
333279
return (
334280
<Dialog open={open} onOpenChange={setOpen}>
335281
<DialogContent className="max-w-2xl">
@@ -539,12 +485,9 @@ const DatasetExpansionDialog: React.FunctionComponent<
539485
<div className="mt-6 space-y-2">
540486
<Label htmlFor="model">Model</Label>
541487
<PromptModelSelect
542-
value={model}
543488
workspaceName={workspaceName}
544-
onChange={handleModelChange}
545-
provider={provider}
546-
onAddProvider={handleAddProvider}
547-
onDeleteProvider={handleDeleteProvider}
489+
{...modelSelectProps}
490+
disabled={isPending}
548491
/>
549492
</div>
550493

apps/opik-frontend/src/components/shared/PromptImprovementDialog/PromptImprovementDialog.tsx

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@ import { Textarea } from "@/components/ui/textarea";
2222
import { Description } from "@/components/ui/description";
2323
import { Separator } from "@/components/ui/separator";
2424
import ExplainerDescription from "@/components/shared/ExplainerDescription/ExplainerDescription";
25+
import PromptModelSelect from "@/components/pages-shared/llm/PromptModelSelect/PromptModelSelect";
2526
import usePromptImprovement from "@/hooks/usePromptImprovement";
2627
import useProgressSimulation from "@/hooks/useProgressSimulation";
28+
import useModelSelection from "@/hooks/useModelSelection";
2729
import {
2830
COMPOSED_PROVIDER_TYPE,
2931
LLMPromptConfigsType,
@@ -49,6 +51,8 @@ const PROMPT_IMPROVEMENT_PROGRESS_MESSAGES = [
4951
"Polishing the output...",
5052
];
5153

54+
const PROMPT_IMPROVEMENT_LAST_PICKED_MODEL = "opik-prompt-improvement-model";
55+
5256
interface PromptImprovementDialogProps {
5357
open: boolean;
5458
setOpen: (open: boolean) => void;
@@ -66,9 +70,9 @@ const PromptImprovementDialog: React.FC<PromptImprovementDialogProps> = ({
6670
setOpen,
6771
id,
6872
originalPrompt = "",
69-
model,
70-
provider,
71-
configs,
73+
model: defaultModel,
74+
provider: defaultProvider,
75+
configs: defaultConfigs,
7276
workspaceName,
7377
onAccept,
7478
}) => {
@@ -79,6 +83,14 @@ const PromptImprovementDialog: React.FC<PromptImprovementDialogProps> = ({
7983
const [isEditorFocused, setIsEditorFocused] = useState(false);
8084
const editorViewRef = useRef<EditorView | null>(null);
8185

86+
// Model selection with persistence using the reusable hook
87+
const { model, provider, configs, modelSelectProps } = useModelSelection({
88+
persistenceKey: PROMPT_IMPROVEMENT_LAST_PICKED_MODEL,
89+
defaultModel,
90+
defaultProvider,
91+
defaultConfigs,
92+
});
93+
8294
const { improvePrompt, generatePrompt } = usePromptImprovement({
8395
workspaceName,
8496
});
@@ -443,6 +455,16 @@ const PromptImprovementDialog: React.FC<PromptImprovementDialogProps> = ({
443455
/>
444456
</DialogHeader>
445457
<DialogAutoScrollBody>
458+
<div className="mb-4 flex items-center gap-3">
459+
<div className="comet-body-accented shrink-0">Model</div>
460+
<div className="w-64">
461+
<PromptModelSelect
462+
workspaceName={workspaceName}
463+
{...modelSelectProps}
464+
disabled={isLoading}
465+
/>
466+
</div>
467+
</div>
446468
{error && (
447469
<Alert variant="destructive" className="mb-4">
448470
<AlertTitle>{error}</AlertTitle>
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
import { useCallback, useMemo } from "react";
2+
import useLastPickedModel from "@/hooks/useLastPickedModel";
3+
import useLLMProviderModelsData from "@/hooks/useLLMProviderModelsData";
4+
import useProviderKeys from "@/api/provider-keys/useProviderKeys";
5+
import useAppStore from "@/store/AppStore";
6+
import { getDefaultConfigByProvider } from "@/lib/playground";
7+
import {
8+
COMPOSED_PROVIDER_TYPE,
9+
LLMPromptConfigsType,
10+
PROVIDER_MODEL_TYPE,
11+
} from "@/types/providers";
12+
13+
export interface UseModelSelectionParams {
14+
persistenceKey: string;
15+
defaultModel?: string;
16+
defaultProvider?: COMPOSED_PROVIDER_TYPE | "";
17+
defaultConfigs?: LLMPromptConfigsType;
18+
}
19+
20+
export interface ModelSelectProps {
21+
value: PROVIDER_MODEL_TYPE | "";
22+
provider: COMPOSED_PROVIDER_TYPE | "";
23+
onChange: (
24+
model: PROVIDER_MODEL_TYPE,
25+
provider: COMPOSED_PROVIDER_TYPE,
26+
) => void;
27+
onAddProvider: (provider: COMPOSED_PROVIDER_TYPE) => void;
28+
onDeleteProvider: (provider: COMPOSED_PROVIDER_TYPE) => void;
29+
}
30+
31+
export interface UseModelSelectionResult {
32+
model: PROVIDER_MODEL_TYPE | "";
33+
provider: COMPOSED_PROVIDER_TYPE | "";
34+
configs: LLMPromptConfigsType;
35+
modelSelectProps: ModelSelectProps;
36+
}
37+
38+
const useModelSelection = ({
39+
persistenceKey,
40+
defaultModel,
41+
defaultProvider,
42+
defaultConfigs,
43+
}: UseModelSelectionParams): UseModelSelectionResult => {
44+
const workspaceName = useAppStore((state) => state.activeWorkspaceName);
45+
46+
const [lastPickedModel, setLastPickedModel] = useLastPickedModel({
47+
key: persistenceKey,
48+
});
49+
50+
const { data: providerKeysData } = useProviderKeys({
51+
workspaceName,
52+
});
53+
54+
const providerKeys = useMemo(() => {
55+
return providerKeysData?.content?.map((c) => c.ui_composed_provider) || [];
56+
}, [providerKeysData]);
57+
58+
const { calculateModelProvider, calculateDefaultModel } =
59+
useLLMProviderModelsData();
60+
61+
const { model, provider, configs } = useMemo(() => {
62+
if (lastPickedModel) {
63+
const lastPickedProvider = calculateModelProvider(lastPickedModel);
64+
if (lastPickedProvider && providerKeys.includes(lastPickedProvider)) {
65+
return {
66+
model: lastPickedModel,
67+
provider: lastPickedProvider,
68+
configs: getDefaultConfigByProvider(
69+
lastPickedProvider,
70+
lastPickedModel,
71+
),
72+
};
73+
}
74+
}
75+
76+
if (defaultModel && defaultProvider) {
77+
return {
78+
model: defaultModel as PROVIDER_MODEL_TYPE | "",
79+
provider: defaultProvider,
80+
configs: defaultConfigs ?? getDefaultConfigByProvider(defaultProvider),
81+
};
82+
}
83+
84+
const calculatedModel = calculateDefaultModel(
85+
lastPickedModel,
86+
providerKeys,
87+
) as PROVIDER_MODEL_TYPE | "";
88+
const calculatedProvider = calculateModelProvider(calculatedModel);
89+
return {
90+
model: calculatedModel,
91+
provider: calculatedProvider,
92+
configs: getDefaultConfigByProvider(calculatedProvider, calculatedModel),
93+
};
94+
}, [
95+
lastPickedModel,
96+
providerKeys,
97+
calculateModelProvider,
98+
calculateDefaultModel,
99+
defaultModel,
100+
defaultProvider,
101+
defaultConfigs,
102+
]);
103+
104+
const handleModelChange = useCallback(
105+
(newModel: PROVIDER_MODEL_TYPE) => {
106+
setLastPickedModel(newModel);
107+
},
108+
[setLastPickedModel],
109+
);
110+
111+
const handleAddProvider = useCallback(
112+
(addedProvider: COMPOSED_PROVIDER_TYPE) => {
113+
if (!model) {
114+
setLastPickedModel(
115+
calculateDefaultModel(
116+
model as PROVIDER_MODEL_TYPE | "",
117+
[addedProvider],
118+
addedProvider,
119+
),
120+
);
121+
}
122+
},
123+
[calculateDefaultModel, model, setLastPickedModel],
124+
);
125+
126+
const handleDeleteProvider = useCallback(
127+
(deletedProvider: COMPOSED_PROVIDER_TYPE) => {
128+
const currentProvider = calculateModelProvider(
129+
model as PROVIDER_MODEL_TYPE | "",
130+
);
131+
if (currentProvider === deletedProvider) {
132+
setLastPickedModel("");
133+
}
134+
},
135+
[calculateModelProvider, model, setLastPickedModel],
136+
);
137+
138+
const modelSelectProps: ModelSelectProps = useMemo(
139+
() => ({
140+
value: model as PROVIDER_MODEL_TYPE | "",
141+
provider,
142+
onChange: handleModelChange,
143+
onAddProvider: handleAddProvider,
144+
onDeleteProvider: handleDeleteProvider,
145+
}),
146+
[
147+
model,
148+
provider,
149+
handleModelChange,
150+
handleAddProvider,
151+
handleDeleteProvider,
152+
],
153+
);
154+
155+
return {
156+
model: model as PROVIDER_MODEL_TYPE | "",
157+
provider,
158+
configs,
159+
modelSelectProps,
160+
};
161+
};
162+
163+
export default useModelSelection;

0 commit comments

Comments
 (0)