Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 95 additions & 1 deletion symbolic_regression/src/search_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ where
progress_finished: false,
};

let _ = core.step(dataset, baseline_loss, options, &controller, usize::MAX);
let _ = core.step(dataset, baseline_loss, options, &controller, false, usize::MAX);
core.finish_progress_if_needed();

let SearchCore { hall, pools, .. } = core;
Expand Down Expand Up @@ -233,6 +233,7 @@ where
baseline_loss: Option<T>,
options: &Options<T, D>,
controller: &StopController,
force_single_thread: bool,
n_cycles: usize,
) -> usize {
if n_cycles == 0 {
Expand All @@ -246,6 +247,88 @@ where
return 0;
}

if force_single_thread {
let mut completed_total = 0usize;
while completed_total < n_cycles {
if is_finished(&self.counters) {
break;
}

if controller.should_stop(self.pools.total_evals) {
controller.cancel();
break;
}

if self.next_task >= self.task_order.len() {
self.prepare_iteration_state(options.niterations);
if self.next_task >= self.task_order.len() {
break;
}
}

let mut dispatched = false;
while !dispatched && completed_total < n_cycles && self.next_task < self.task_order.len() {
if is_finished(&self.counters) {
break;
}

if controller.should_stop(self.pools.total_evals) {
controller.cancel();
break;
}

let pop_idx = self.task_order[self.next_task];
self.next_task += 1;

let Some(pop_state) = self.pools.pops[pop_idx].take() else {
continue;
};

let cycles_remaining_start = self.counters.cycles_remaining_start_for_next_dispatch();
let curmaxsize =
warmup::get_cur_maxsize(options, self.counters.total_cycles, cycles_remaining_start);
let mut stats_snapshot = self.stats.clone();
stats_snapshot.normalize();

let full_dataset = TaggedDataset::new(dataset, baseline_loss);
let res = execute_task(
full_dataset,
options,
pop_idx,
curmaxsize,
stats_snapshot,
pop_state,
controller,
);

apply_task_result(
options,
&mut self.counters,
&mut self.stats,
&mut self.hall,
&mut self.progress,
&mut self.pools,
res,
);
completed_total += 1;
dispatched = true;

if controller.should_stop(self.pools.total_evals) {
controller.cancel();
}
}

if !dispatched {
break;
}
}

if is_finished(&self.counters) {
self.finish_progress_if_needed();
}
return completed_total;
}

let usable_threads = usable_rayon_threads();
let need_inline = usable_threads == 0;
let n_workers = usable_threads.min(self.pools.pops.len()).max(1);
Expand Down Expand Up @@ -376,6 +459,7 @@ pub struct SearchEngine<T: Float + AddAssign, Ops, const D: usize> {
dataset: Dataset<T>,
baseline_loss: Option<T>,
options: Options<T, D>,
force_single_thread: bool,
controller: StopController,
core: SearchCore<T, Ops, D>,
}
Expand Down Expand Up @@ -427,11 +511,20 @@ where
dataset,
baseline_loss,
options,
force_single_thread: false,
controller,
core,
}
}

pub fn set_parallelism(&mut self, enabled: bool) {
self.force_single_thread = !enabled;
}

pub fn parallelism_enabled(&self) -> bool {
!self.force_single_thread
}

pub fn total_cycles(&self) -> usize {
self.core.counters.total_cycles
}
Expand Down Expand Up @@ -474,6 +567,7 @@ where
self.baseline_loss,
&self.options,
&self.controller,
self.force_single_thread,
n_cycles,
)
}
Expand Down
10 changes: 5 additions & 5 deletions web/ui/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion web/ui/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
},
"dependencies": {
"papaparse": "^5.5.3",
"plotly.js-dist-min": "^2.35.2",
"plotly.js-basic-dist-min": "2.35.2",
"react": "^18.3.1",
"react-dom": "^18.3.1",
"react-plotly.js": "^2.6.0",
Expand Down
2 changes: 1 addition & 1 deletion web/ui/src/app/panes/EnterData.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import React, { useEffect, useMemo, useRef, useState } from "react";
import Plot from "react-plotly.js";
import Plot from "@/plotly/Plot";
import { useSessionStore } from "../../state/sessionStore";
import { formatSci, plotLayoutBase, usePrefersDark } from "./searchSolutions/plotUtils";

Expand Down
5 changes: 5 additions & 0 deletions web/ui/src/app/panes/SearchSolutions.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ export function SearchSolutions(): React.ReactElement {
const c = useSearchController(SrWorkerClient);
const options = useSessionStore((s) => s.options);
const setOptionsPatch = useSessionStore((s) => s.setOptionsPatch);
const threadsEnabled = useSessionStore((s) => s.threadsEnabled);
const setThreadsEnabled = useSessionStore((s) => s.setThreadsEnabled);

return (
<div className="pane">
Expand All @@ -26,6 +28,9 @@ export function SearchSolutions(): React.ReactElement {
niterations={options?.niterations ?? null}
setNiterations={(n) => setOptionsPatch({ niterations: n })}
canEditNiterations={c.runtime.status === "idle" || c.runtime.status === "error"}
threadsEnabled={threadsEnabled}
setThreadsEnabled={setThreadsEnabled}
canEditThreadsEnabled={c.runtime.status === "idle" || c.runtime.status === "error"}
initSearch={c.initSearch}
start={c.start}
pause={c.pause}
Expand Down
15 changes: 15 additions & 0 deletions web/ui/src/app/panes/searchSolutions/ControlsCard.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ export function ControlsCard(props: {
setNiterations: (n: number) => void;
canEditNiterations: boolean;

threadsEnabled: boolean;
setThreadsEnabled: (enabled: boolean) => void;
canEditThreadsEnabled: boolean;

initSearch: () => void;
start: () => void;
pause: () => void;
Expand Down Expand Up @@ -48,6 +52,17 @@ export function ControlsCard(props: {
/>
</label>

<label className="toolbarField">
<span className="label">threads</span>
<input
type="checkbox"
checked={props.threadsEnabled}
disabled={!props.canEditThreadsEnabled}
onChange={(e) => props.setThreadsEnabled(e.target.checked)}
data-testid="threads-enabled"
/>
</label>

<div className="spacer" />

<div className="statusLine">
Expand Down
2 changes: 1 addition & 1 deletion web/ui/src/app/panes/searchSolutions/FitPlot.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import React from "react";
import Plot from "react-plotly.js";
import Plot from "@/plotly/Plot";
import type { FitPlotMode } from "./types";
import { plotLayoutBase, sortXY } from "./plotUtils";

Expand Down
2 changes: 1 addition & 1 deletion web/ui/src/app/panes/searchSolutions/ParetoPlotCard.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import React from "react";
import Plot from "react-plotly.js";
import Plot from "@/plotly/Plot";
import type { EquationSummary } from "../../../types/srTypes";
import { plotLayoutBase } from "./plotUtils";

Expand Down
7 changes: 5 additions & 2 deletions web/ui/src/app/panes/searchSolutions/useSearchController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type SrWorkerClientLike = {
unary: any;
binary: any;
ternary: any;
threadsEnabled: boolean;
}) => void;
start: () => void;
pause: () => void;
Expand Down Expand Up @@ -61,6 +62,7 @@ export function useSearchController(Client: { new (): SrWorkerClientLike }) {
const unaryOps = useSessionStore((s) => s.unaryOps);
const binaryOps = useSessionStore((s) => s.binaryOps);
const ternaryOps = useSessionStore((s) => s.ternaryOps);
const threadsEnabled = useSessionStore((s) => s.threadsEnabled);

const runtime = useSessionStore((s) => s.runtime);
const setRuntime = useSessionStore((s) => s.setRuntime);
Expand Down Expand Up @@ -167,7 +169,7 @@ export function useSearchController(Client: { new (): SrWorkerClientLike }) {
}
});
return () => c.terminate();
}, [Client, setEvalResult, setFront, setRuntime, setSnapshot]);
}, [Client, threadsEnabled, setEvalResult, setFront, setRuntime, setSnapshot]);

const canInit = Boolean(options) && unaryOps.length + binaryOps.length + ternaryOps.length > 0;

Expand All @@ -192,7 +194,8 @@ export function useSearchController(Client: { new (): SrWorkerClientLike }) {
options,
unary: unaryOps,
binary: binaryOps,
ternary: ternaryOps
ternary: ternaryOps,
threadsEnabled
});
};

Expand Down
6 changes: 6 additions & 0 deletions web/ui/src/plotly/Plot.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import createPlotlyComponent from "react-plotly.js/factory";
import Plotly from "plotly.js-basic-dist-min";

const Plot = createPlotlyComponent(Plotly as any);

export default Plot;
8 changes: 8 additions & 0 deletions web/ui/src/state/sessionStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ type SessionState = {
csvText: string;
parsed: ParsedDataset | null;

threadsEnabled: boolean;

options: WasmSearchOptions | null;
unaryOps: string[];
binaryOps: string[];
Expand All @@ -59,6 +61,8 @@ type SessionState = {
ensureParsedForRuntime: () => boolean;
setOptionsPatch: (patch: Partial<WasmSearchOptions>) => void;

setThreadsEnabled: (enabled: boolean) => void;

toggleOp: (arity: 1 | 2 | 3, name: string) => void;
applyPreset: (preset: "basic" | "trig" | "explog" | "all") => void;

Expand Down Expand Up @@ -123,6 +127,8 @@ export const useSessionStore = create<SessionState>((set, get) => ({
csvText: DEFAULT_CSV,
parsed: null,

threadsEnabled: false,

options: null,
unaryOps: [],
binaryOps: [],
Expand Down Expand Up @@ -228,6 +234,8 @@ export const useSessionStore = create<SessionState>((set, get) => ({
options: s.options ? { ...s.options, ...patch } : s.options
})),

setThreadsEnabled: (threadsEnabled) => set({ threadsEnabled }),

toggleOp: (arity, name) =>
set((s) => {
if (arity === 1) return { unaryOps: toggleInList(s.unaryOps, name) };
Expand Down
2 changes: 1 addition & 1 deletion web/ui/src/test/setup.ts
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ vi.mock("../pkg/symbolic_regression_wasm.js", () => {
});

// Mock Plotly React component (avoid pulling Plotly into jsdom and keep assertions simple).
vi.mock("react-plotly.js", () => {
vi.mock("@/plotly/Plot", () => {
return {
default: (props: any) => {
const xTitle = props?.layout?.xaxis?.title;
Expand Down
9 changes: 9 additions & 0 deletions web/ui/src/worker/protocol.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ export type WorkerInitMsg = {
unary: string[];
binary: string[];
ternary: string[];
threadsEnabled: boolean;
};

export type WorkerStartMsg = { type: "start" };
Expand All @@ -29,6 +30,13 @@ export type WorkerToWorkerMsg =
| WorkerEvaluateMsg;

export type WorkerReadyMsg = { type: "ready"; split: WasmSplitIndices };
export type WorkerThreadStatusMsg = {
type: "thread_status";
crossOriginIsolated: boolean;
sharedArrayBufferAvailable: boolean;
hasSharedMemory: boolean;
bufferType: string;
};
export type WorkerSnapshotMsg = { type: "snapshot"; snap: SearchSnapshot };
export type WorkerFrontUpdateMsg = { type: "front_update"; front: EquationSummary[] };
export type WorkerEvalResultMsg = { type: "eval_result"; requestId: string; result: WasmEvalResult };
Expand All @@ -39,6 +47,7 @@ export type WorkerErrorMsg = { type: "error"; error: string };

export type WorkerFromWorkerMsg =
| WorkerReadyMsg
| WorkerThreadStatusMsg
| WorkerSnapshotMsg
| WorkerFrontUpdateMsg
| WorkerEvalResultMsg
Expand Down
Loading
Loading