-
-
Notifications
You must be signed in to change notification settings - Fork 1.4k
refactor: use Immer for immutable state updates #527
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| import React, { useState, useRef, useEffect } from "react"; | ||
| import { useTranslation } from "react-i18next"; | ||
| import { listen } from "@tauri-apps/api/event"; | ||
| import { produce } from "immer"; | ||
| import { commands, type ModelInfo } from "@/bindings"; | ||
| import { getTranslatedModelName } from "../../lib/utils/modelTranslation"; | ||
| import ModelStatusButton from "./ModelStatusButton"; | ||
|
|
@@ -48,15 +49,15 @@ const ModelSelector: React.FC<ModelSelectorProps> = ({ onError }) => { | |
| const [modelStatus, setModelStatus] = useState<ModelStatus>("unloaded"); | ||
| const [modelError, setModelError] = useState<string | null>(null); | ||
| const [modelDownloadProgress, setModelDownloadProgress] = useState< | ||
| Map<string, DownloadProgress> | ||
| >(new Map()); | ||
| Record<string, DownloadProgress> | ||
| >({}); | ||
| const [showModelDropdown, setShowModelDropdown] = useState(false); | ||
| const [downloadStats, setDownloadStats] = useState< | ||
| Map<string, DownloadStats> | ||
| >(new Map()); | ||
| const [extractingModels, setExtractingModels] = useState<Set<string>>( | ||
| new Set(), | ||
| ); | ||
| Record<string, DownloadStats> | ||
| >({}); | ||
| const [extractingModels, setExtractingModels] = useState< | ||
| Record<string, true> | ||
| >({}); | ||
|
|
||
| const dropdownRef = useRef<HTMLDivElement>(null); | ||
|
|
||
|
|
@@ -97,53 +98,52 @@ const ModelSelector: React.FC<ModelSelectorProps> = ({ onError }) => { | |
| "model-download-progress", | ||
| (event) => { | ||
| const progress = event.payload; | ||
| setModelDownloadProgress((prev) => { | ||
| const newMap = new Map(prev); | ||
| newMap.set(progress.model_id, progress); | ||
| return newMap; | ||
| }); | ||
| setModelDownloadProgress( | ||
| produce((downloadProgress) => { | ||
| downloadProgress[progress.model_id] = progress; | ||
| }), | ||
| ); | ||
| setModelStatus("downloading"); | ||
|
|
||
| // Update download stats for speed calculation | ||
| const now = Date.now(); | ||
| setDownloadStats((prev) => { | ||
| const current = prev.get(progress.model_id); | ||
| const newStats = new Map(prev); | ||
|
|
||
| if (!current) { | ||
| // First progress update - initialize | ||
| newStats.set(progress.model_id, { | ||
| startTime: now, | ||
| lastUpdate: now, | ||
| totalDownloaded: progress.downloaded, | ||
| speed: 0, | ||
| }); | ||
| } else { | ||
| // Calculate speed over last few seconds | ||
| const timeDiff = (now - current.lastUpdate) / 1000; // seconds | ||
| const bytesDiff = progress.downloaded - current.totalDownloaded; | ||
|
|
||
| if (timeDiff > 0.5) { | ||
| // Update speed every 500ms | ||
| const currentSpeed = bytesDiff / (1024 * 1024) / timeDiff; // MB/s | ||
| // Smooth the speed with exponential moving average, but ensure positive values | ||
| const validCurrentSpeed = Math.max(0, currentSpeed); | ||
| const smoothedSpeed = | ||
| current.speed > 0 | ||
| ? current.speed * 0.8 + validCurrentSpeed * 0.2 | ||
| : validCurrentSpeed; | ||
|
|
||
| newStats.set(progress.model_id, { | ||
| startTime: current.startTime, | ||
| setDownloadStats( | ||
| produce((stats) => { | ||
| const current = stats[progress.model_id]; | ||
|
|
||
| if (!current) { | ||
| // First progress update - initialize | ||
| stats[progress.model_id] = { | ||
| startTime: now, | ||
| lastUpdate: now, | ||
| totalDownloaded: progress.downloaded, | ||
| speed: Math.max(0, smoothedSpeed), | ||
| }); | ||
| speed: 0, | ||
| }; | ||
| } else { | ||
| // Calculate speed over last few seconds | ||
| const timeDiff = (now - current.lastUpdate) / 1000; // seconds | ||
| const bytesDiff = progress.downloaded - current.totalDownloaded; | ||
|
|
||
| if (timeDiff > 0.5) { | ||
| // Update speed every 500ms | ||
| const currentSpeed = bytesDiff / (1024 * 1024) / timeDiff; // MB/s | ||
| // Smooth the speed with exponential moving average, but ensure positive values | ||
| const validCurrentSpeed = Math.max(0, currentSpeed); | ||
| const smoothedSpeed = | ||
| current.speed > 0 | ||
| ? current.speed * 0.8 + validCurrentSpeed * 0.2 | ||
| : validCurrentSpeed; | ||
|
|
||
| stats[progress.model_id] = { | ||
| startTime: current.startTime, | ||
| lastUpdate: now, | ||
| totalDownloaded: progress.downloaded, | ||
| speed: Math.max(0, smoothedSpeed), | ||
| }; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| return newStats; | ||
| }); | ||
| }), | ||
| ); | ||
| }, | ||
| ); | ||
|
|
||
|
|
@@ -152,16 +152,16 @@ const ModelSelector: React.FC<ModelSelectorProps> = ({ onError }) => { | |
| "model-download-complete", | ||
| (event) => { | ||
| const modelId = event.payload; | ||
| setModelDownloadProgress((prev) => { | ||
| const newMap = new Map(prev); | ||
| newMap.delete(modelId); | ||
| return newMap; | ||
| }); | ||
| setDownloadStats((prev) => { | ||
| const newStats = new Map(prev); | ||
| newStats.delete(modelId); | ||
| return newStats; | ||
| }); | ||
| setModelDownloadProgress( | ||
| produce((progress) => { | ||
| delete progress[modelId]; | ||
| }), | ||
| ); | ||
| setDownloadStats( | ||
| produce((stats) => { | ||
| delete stats[modelId]; | ||
| }), | ||
| ); | ||
| loadModels(); // Refresh models list | ||
|
|
||
| // Auto-select the newly downloaded model (skip if recording in progress) | ||
|
|
@@ -181,7 +181,11 @@ const ModelSelector: React.FC<ModelSelectorProps> = ({ onError }) => { | |
| "model-extraction-started", | ||
| (event) => { | ||
| const modelId = event.payload; | ||
| setExtractingModels((prev) => new Set(prev.add(modelId))); | ||
| setExtractingModels( | ||
| produce((extracting) => { | ||
| extracting[modelId] = true; | ||
| }), | ||
| ); | ||
| setModelStatus("extracting"); | ||
| }, | ||
| ); | ||
|
|
@@ -190,11 +194,11 @@ const ModelSelector: React.FC<ModelSelectorProps> = ({ onError }) => { | |
| "model-extraction-completed", | ||
| (event) => { | ||
| const modelId = event.payload; | ||
| setExtractingModels((prev) => { | ||
| const next = new Set(prev); | ||
| next.delete(modelId); | ||
| return next; | ||
| }); | ||
| setExtractingModels( | ||
| produce((extracting) => { | ||
| delete extracting[modelId]; | ||
| }), | ||
| ); | ||
| loadModels(); // Refresh models list | ||
|
|
||
| // Auto-select the newly extracted model (skip if recording in progress) | ||
|
|
@@ -214,11 +218,11 @@ const ModelSelector: React.FC<ModelSelectorProps> = ({ onError }) => { | |
| error: string; | ||
| }>("model-extraction-failed", (event) => { | ||
| const modelId = event.payload.model_id; | ||
| setExtractingModels((prev) => { | ||
| const next = new Set(prev); | ||
| next.delete(modelId); | ||
| return next; | ||
| }); | ||
| setExtractingModels( | ||
| produce((extracting) => { | ||
| delete extracting[modelId]; | ||
| }), | ||
| ); | ||
| setModelError(`Failed to extract model: ${event.payload.error}`); | ||
| setModelStatus("error"); | ||
| }); | ||
|
|
@@ -329,32 +333,34 @@ const ModelSelector: React.FC<ModelSelectorProps> = ({ onError }) => { | |
| }; | ||
|
|
||
| const getModelDisplayText = (): string => { | ||
| if (extractingModels.size > 0) { | ||
| if (extractingModels.size === 1) { | ||
| const [modelId] = Array.from(extractingModels); | ||
| const extractingKeys = Object.keys(extractingModels); | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the only place where I feel I hurt readability, instead of improving it, but I find the original code to not be the most readable. I can make more PRs. We should add the tests first. |
||
| if (extractingKeys.length > 0) { | ||
| if (extractingKeys.length === 1) { | ||
| const modelId = extractingKeys[0]; | ||
| const model = models.find((m) => m.id === modelId); | ||
| const modelName = model | ||
| ? getTranslatedModelName(model, t) | ||
| : t("modelSelector.extractingGeneric").replace("...", ""); | ||
| return t("modelSelector.extracting", { modelName }); | ||
| } else { | ||
| return t("modelSelector.extractingMultiple", { | ||
| count: extractingModels.size, | ||
| count: extractingKeys.length, | ||
| }); | ||
| } | ||
| } | ||
|
|
||
| if (modelDownloadProgress.size > 0) { | ||
| if (modelDownloadProgress.size === 1) { | ||
| const [progress] = Array.from(modelDownloadProgress.values()); | ||
| const progressValues = Object.values(modelDownloadProgress); | ||
| if (progressValues.length > 0) { | ||
| if (progressValues.length === 1) { | ||
| const progress = progressValues[0]; | ||
| const percentage = Math.max( | ||
| 0, | ||
| Math.min(100, Math.round(progress.percentage)), | ||
| ); | ||
| return t("modelSelector.downloading", { percentage }); | ||
| } else { | ||
| return t("modelSelector.downloadingMultiple", { | ||
| count: modelDownloadProgress.size, | ||
| count: progressValues.length, | ||
| }); | ||
| } | ||
| } | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🟢 The "correct" old pattern would have been:
🔴 But instead, the old code did this:
This is illustrative. This isn't stylistic. This is correctness.