Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions bun.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
"react-select": "^5.8.0",
"tauri-plugin-macos-permissions-api": "2.3.0",
"i18next": "^25.7.2",
"immer": "^11.1.3",
"lucide-react": "^0.542.0",
"react": "^18.3.1",
"react-dom": "^18.3.1",
Expand Down
15 changes: 7 additions & 8 deletions src/components/model-selector/DownloadProgressDisplay.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ interface DownloadStats {
}

interface DownloadProgressDisplayProps {
downloadProgress: Map<string, DownloadProgress>;
downloadStats: Map<string, DownloadStats>;
downloadProgress: Record<string, DownloadProgress>;
downloadStats: Record<string, DownloadStats>;
className?: string;
}

Expand All @@ -26,14 +26,13 @@ const DownloadProgressDisplay: React.FC<DownloadProgressDisplayProps> = ({
downloadStats,
className = "",
}) => {
if (downloadProgress.size === 0) {
const progressValues = Object.values(downloadProgress);
if (progressValues.length === 0) {
return null;
}

const progressData: ProgressData[] = Array.from(
downloadProgress.values(),
).map((progress) => {
const stats = downloadStats.get(progress.model_id);
const progressData: ProgressData[] = progressValues.map((progress) => {
const stats = downloadStats[progress.model_id];
return {
id: progress.model_id,
percentage: progress.percentage,
Expand All @@ -45,7 +44,7 @@ const DownloadProgressDisplay: React.FC<DownloadProgressDisplayProps> = ({
<ProgressBar
progress={progressData}
className={className}
showSpeed={downloadProgress.size === 1}
showSpeed={progressValues.length === 1}
size="medium"
/>
);
Expand Down
10 changes: 5 additions & 5 deletions src/components/model-selector/ModelDropdown.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ interface DownloadProgress {
interface ModelDropdownProps {
models: ModelInfo[];
currentModelId: string;
downloadProgress: Map<string, DownloadProgress>;
downloadProgress: Record<string, DownloadProgress>;
onModelSelect: (modelId: string) => void;
onModelDownload: (modelId: string) => void;
onModelDelete: (modelId: string) => Promise<void>;
Expand Down Expand Up @@ -52,14 +52,14 @@ const ModelDropdown: React.FC<ModelDropdownProps> = ({
};

const handleModelClick = (modelId: string) => {
if (downloadProgress.has(modelId)) {
if (modelId in downloadProgress) {
return; // Don't allow interaction while downloading
}
onModelSelect(modelId);
};

const handleDownloadClick = (modelId: string) => {
if (downloadProgress.has(modelId)) {
if (modelId in downloadProgress) {
return; // Don't allow interaction while downloading
}
onModelDownload(modelId);
Expand Down Expand Up @@ -158,8 +158,8 @@ const ModelDropdown: React.FC<ModelDropdownProps> = ({
: t("modelSelector.downloadModels")}
</div>
{downloadableModels.map((model) => {
const isDownloading = downloadProgress.has(model.id);
const progress = downloadProgress.get(model.id);
const isDownloading = model.id in downloadProgress;
const progress = downloadProgress[model.id];

return (
<div
Expand Down
158 changes: 82 additions & 76 deletions src/components/model-selector/ModelSelector.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import React, { useState, useRef, useEffect } from "react";
import { useTranslation } from "react-i18next";
import { listen } from "@tauri-apps/api/event";
import { produce } from "immer";
import { commands, type ModelInfo } from "@/bindings";
import { getTranslatedModelName } from "../../lib/utils/modelTranslation";
import ModelStatusButton from "./ModelStatusButton";
Expand Down Expand Up @@ -48,15 +49,15 @@ const ModelSelector: React.FC<ModelSelectorProps> = ({ onError }) => {
const [modelStatus, setModelStatus] = useState<ModelStatus>("unloaded");
const [modelError, setModelError] = useState<string | null>(null);
const [modelDownloadProgress, setModelDownloadProgress] = useState<
Map<string, DownloadProgress>
>(new Map());
Record<string, DownloadProgress>
>({});
const [showModelDropdown, setShowModelDropdown] = useState(false);
const [downloadStats, setDownloadStats] = useState<
Map<string, DownloadStats>
>(new Map());
const [extractingModels, setExtractingModels] = useState<Set<string>>(
new Set(),
);
Record<string, DownloadStats>
>({});
const [extractingModels, setExtractingModels] = useState<
Record<string, true>
>({});

const dropdownRef = useRef<HTMLDivElement>(null);

Expand Down Expand Up @@ -97,53 +98,52 @@ const ModelSelector: React.FC<ModelSelectorProps> = ({ onError }) => {
"model-download-progress",
(event) => {
const progress = event.payload;
setModelDownloadProgress((prev) => {
const newMap = new Map(prev);
newMap.set(progress.model_id, progress);
return newMap;
});
setModelDownloadProgress(
produce((downloadProgress) => {
downloadProgress[progress.model_id] = progress;
}),
);
setModelStatus("downloading");

// Update download stats for speed calculation
const now = Date.now();
setDownloadStats((prev) => {
const current = prev.get(progress.model_id);
const newStats = new Map(prev);

if (!current) {
// First progress update - initialize
newStats.set(progress.model_id, {
startTime: now,
lastUpdate: now,
totalDownloaded: progress.downloaded,
speed: 0,
});
} else {
// Calculate speed over last few seconds
const timeDiff = (now - current.lastUpdate) / 1000; // seconds
const bytesDiff = progress.downloaded - current.totalDownloaded;

if (timeDiff > 0.5) {
// Update speed every 500ms
const currentSpeed = bytesDiff / (1024 * 1024) / timeDiff; // MB/s
// Smooth the speed with exponential moving average, but ensure positive values
const validCurrentSpeed = Math.max(0, currentSpeed);
const smoothedSpeed =
current.speed > 0
? current.speed * 0.8 + validCurrentSpeed * 0.2
: validCurrentSpeed;

newStats.set(progress.model_id, {
startTime: current.startTime,
setDownloadStats(
produce((stats) => {
const current = stats[progress.model_id];

if (!current) {
// First progress update - initialize
stats[progress.model_id] = {
startTime: now,
lastUpdate: now,
totalDownloaded: progress.downloaded,
speed: Math.max(0, smoothedSpeed),
});
speed: 0,
};
} else {
// Calculate speed over last few seconds
const timeDiff = (now - current.lastUpdate) / 1000; // seconds
const bytesDiff = progress.downloaded - current.totalDownloaded;

if (timeDiff > 0.5) {
// Update speed every 500ms
const currentSpeed = bytesDiff / (1024 * 1024) / timeDiff; // MB/s
// Smooth the speed with exponential moving average, but ensure positive values
const validCurrentSpeed = Math.max(0, currentSpeed);
const smoothedSpeed =
current.speed > 0
? current.speed * 0.8 + validCurrentSpeed * 0.2
: validCurrentSpeed;

stats[progress.model_id] = {
startTime: current.startTime,
lastUpdate: now,
totalDownloaded: progress.downloaded,
speed: Math.max(0, smoothedSpeed),
};
}
}
}

return newStats;
});
}),
);
},
);

Expand All @@ -152,16 +152,16 @@ const ModelSelector: React.FC<ModelSelectorProps> = ({ onError }) => {
"model-download-complete",
(event) => {
const modelId = event.payload;
setModelDownloadProgress((prev) => {
const newMap = new Map(prev);
newMap.delete(modelId);
return newMap;
});
setDownloadStats((prev) => {
const newStats = new Map(prev);
newStats.delete(modelId);
return newStats;
});
setModelDownloadProgress(
produce((progress) => {
delete progress[modelId];
}),
);
setDownloadStats(
produce((stats) => {
delete stats[modelId];
}),
);
loadModels(); // Refresh models list

// Auto-select the newly downloaded model (skip if recording in progress)
Expand All @@ -181,7 +181,11 @@ const ModelSelector: React.FC<ModelSelectorProps> = ({ onError }) => {
"model-extraction-started",
(event) => {
const modelId = event.payload;
setExtractingModels((prev) => new Set(prev.add(modelId)));
Copy link
Contributor Author

@joshribakoff joshribakoff Jan 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟢 The "correct" old pattern would have been:

const myNewSet = new Set(myOldSet)
// mutate myNewSet

🔴 But instead, the old code did this:

// mutate myOldSet
return new Set(myOldSet)

This is illustrative. This isn't stylistic. This is correctness.

setExtractingModels(
produce((extracting) => {
extracting[modelId] = true;
}),
);
setModelStatus("extracting");
},
);
Expand All @@ -190,11 +194,11 @@ const ModelSelector: React.FC<ModelSelectorProps> = ({ onError }) => {
"model-extraction-completed",
(event) => {
const modelId = event.payload;
setExtractingModels((prev) => {
const next = new Set(prev);
next.delete(modelId);
return next;
});
setExtractingModels(
produce((extracting) => {
delete extracting[modelId];
}),
);
loadModels(); // Refresh models list

// Auto-select the newly extracted model (skip if recording in progress)
Expand All @@ -214,11 +218,11 @@ const ModelSelector: React.FC<ModelSelectorProps> = ({ onError }) => {
error: string;
}>("model-extraction-failed", (event) => {
const modelId = event.payload.model_id;
setExtractingModels((prev) => {
const next = new Set(prev);
next.delete(modelId);
return next;
});
setExtractingModels(
produce((extracting) => {
delete extracting[modelId];
}),
);
setModelError(`Failed to extract model: ${event.payload.error}`);
setModelStatus("error");
});
Expand Down Expand Up @@ -329,32 +333,34 @@ const ModelSelector: React.FC<ModelSelectorProps> = ({ onError }) => {
};

const getModelDisplayText = (): string => {
if (extractingModels.size > 0) {
if (extractingModels.size === 1) {
const [modelId] = Array.from(extractingModels);
const extractingKeys = Object.keys(extractingModels);
Copy link
Contributor Author

@joshribakoff joshribakoff Jan 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the only place where I feel I hurt readability, instead of improving it, but I find the original code to not be the most readable. I can make more PRs. We should add the tests first.

if (extractingKeys.length > 0) {
if (extractingKeys.length === 1) {
const modelId = extractingKeys[0];
const model = models.find((m) => m.id === modelId);
const modelName = model
? getTranslatedModelName(model, t)
: t("modelSelector.extractingGeneric").replace("...", "");
return t("modelSelector.extracting", { modelName });
} else {
return t("modelSelector.extractingMultiple", {
count: extractingModels.size,
count: extractingKeys.length,
});
}
}

if (modelDownloadProgress.size > 0) {
if (modelDownloadProgress.size === 1) {
const [progress] = Array.from(modelDownloadProgress.values());
const progressValues = Object.values(modelDownloadProgress);
if (progressValues.length > 0) {
if (progressValues.length === 1) {
const progress = progressValues[0];
const percentage = Math.max(
0,
Math.min(100, Math.round(progress.percentage)),
);
return t("modelSelector.downloading", { percentage });
} else {
return t("modelSelector.downloadingMultiple", {
count: modelDownloadProgress.size,
count: progressValues.length,
});
}
}
Expand Down