diff --git a/pythonCode/med_libs/MEDml/nodes/ModelHandler.py b/pythonCode/med_libs/MEDml/nodes/ModelHandler.py index 66fc4205..f4f5f1a9 100644 --- a/pythonCode/med_libs/MEDml/nodes/ModelHandler.py +++ b/pythonCode/med_libs/MEDml/nodes/ModelHandler.py @@ -25,6 +25,9 @@ def __init__(self, id_: int, global_config_json: json) -> None: """ super().__init__(id_, global_config_json) if self.type == 'train_model': + self.isTuningEnabled = self.config_json['data']['internal']['isTuningEnabled'] + if self.isTuningEnabled: + self.settingsTuning = self.config_json['data']['internal']['settingsTuning'] self.model_id = self.config_json['associated_id'] model_obj = self.global_config_json['nodes'][self.model_id] self.config_json['data']['estimator'] = { @@ -66,7 +69,14 @@ def _execute(self, experiment: dict = None, **kwargs) -> json: "code", f"trained_models = [pycaret_exp.create_model({self.CodeHandler.convert_dict_to_params(settings)})]" ) + if self.isTuningEnabled: + trained_models = [experiment['pycaret_exp'].tune_model(trained_models[0], **self.settingsTuning)] + self.CodeHandler.add_line( + "code", + f"trained_models = [pycaret_exp.tune_model(trained_models[0], {self.CodeHandler.convert_dict_to_params(self.settingsTuning)})]" + ) trained_models_copy = trained_models.copy() + self._info_for_next_node = {'models': trained_models} for model in trained_models_copy: model_copy = copy.deepcopy(model) diff --git a/pythonCode/submodules/MEDimage b/pythonCode/submodules/MEDimage index 11aef7b3..24511d92 160000 --- a/pythonCode/submodules/MEDimage +++ b/pythonCode/submodules/MEDimage @@ -1 +1 @@ -Subproject commit 11aef7b3998694083de48f5274ad6234e3d46f4a +Subproject commit 24511d9217cb62b266b41f73e24fc44601b6bfaf diff --git a/pythonCode/submodules/MEDprofiles b/pythonCode/submodules/MEDprofiles index 23b62a0e..50b20931 160000 --- a/pythonCode/submodules/MEDprofiles +++ b/pythonCode/submodules/MEDprofiles @@ -1 +1 @@ -Subproject commit 23b62a0e610fcfc1389d682d55caf16f27002de0 +Subproject commit 50b2093113a5bd3272157a7160fff68161af250f diff --git a/renderer/components/flow/node.jsx b/renderer/components/flow/node.jsx index c6777e49..32b4bdd9 100644 --- a/renderer/components/flow/node.jsx +++ b/renderer/components/flow/node.jsx @@ -192,12 +192,16 @@ const NodeObject = ({ id, data, nodeSpecific, nodeBody, defaultSettings, onClick />
- {/* here are the default settings of the node. if nothing is specified, nothing is displayed*/} - {defaultSettings} - {/* here are the node specific settings. if nothing is specified, nothing is displayed*/} - {nodeSpecific} - {/* note : quand on va implémenter codeeditor */} - {/* */} +
+ + {/* here are the default settings of the node. if nothing is specified, nothing is displayed*/} + {defaultSettings} + {/* here are the node specific settings. if nothing is specified, nothing is displayed*/} + {nodeSpecific} + {/* note : quand on va implémenter codeeditor */} + {/* */} + +
diff --git a/renderer/components/layout/layoutManager.jsx b/renderer/components/layout/layoutManager.jsx index afb6bfe9..31753dcc 100644 --- a/renderer/components/layout/layoutManager.jsx +++ b/renderer/components/layout/layoutManager.jsx @@ -65,7 +65,6 @@ const LayoutManager = (props) => { toast.error("Go server is not connected !") } ) - } }) } diff --git a/renderer/components/learning/modalSettingsChooser.jsx b/renderer/components/learning/modalSettingsChooser.jsx index 3774bb11..d7d72029 100644 --- a/renderer/components/learning/modalSettingsChooser.jsx +++ b/renderer/components/learning/modalSettingsChooser.jsx @@ -18,8 +18,9 @@ import { FlowFunctionsContext } from "../flow/context/flowFunctionsContext" * This component is used to display a ModalSettingsChooser modal. * it handles the display of the modal and the available options */ -const ModalSettingsChooser = ({ show, onHide, options, id, data }) => { +const ModalSettingsChooser = ({ show, onHide, options, id, data, optionsTuning = null }) => { const [checkedUpdate, setCheckedUpdate] = useState(null) + const [checkedUpdateTuning, setCheckedUpdateTuning] = useState(null) const { updateNode } = useContext(FlowFunctionsContext) // update the node when a setting is checked or unchecked from the modal @@ -38,6 +39,22 @@ const ModalSettingsChooser = ({ show, onHide, options, id, data }) => { } }, [checkedUpdate]) + // update the node when a setting is checked or unchecked from the modal + useEffect(() => { + if (checkedUpdateTuning != null) { + if (checkedUpdateTuning.checked) { + !data.internal.checkedOptionsTuning.includes(checkedUpdateTuning.optionName) && data.internal.checkedOptionsTuning.push(checkedUpdateTuning.optionName) + } else { + data.internal.checkedOptionsTuning = data.internal.checkedOptionsTuning.filter((optionName) => optionName != checkedUpdateTuning.optionName) + delete data.internal.settingsTuning[checkedUpdateTuning.optionName] + } + updateNode({ + id: id, + updatedData: data.internal + }) + } + }, [checkedUpdateTuning]) + return ( // Base modal component built from react-bootstrap @@ -45,10 +62,29 @@ const ModalSettingsChooser = ({ show, onHide, options, id, data }) => { {data.setupParam.title + " options"} {/* Display all the options available for the node */} - + {Object.entries(options).map(([optionName, optionInfos], i) => { - return + return ( + + ) })} + {/* Display all the options available for the tuning */} + {optionsTuning && ( + <> +

Tuning options

+ {Object.entries(optionsTuning).map(([optionName, optionInfos], i) => { + return ( + + ) + })} + + )}
diff --git a/renderer/components/learning/nodesTypes/trainModelNode.jsx b/renderer/components/learning/nodesTypes/trainModelNode.jsx new file mode 100644 index 00000000..329c3169 --- /dev/null +++ b/renderer/components/learning/nodesTypes/trainModelNode.jsx @@ -0,0 +1,191 @@ +import React, { useState, useContext, useEffect } from "react" +import Node from "../../flow/node" +import Input from "../input" +import { Button } from "react-bootstrap" +import ModalSettingsChooser from "../modalSettingsChooser" +import * as Icon from "react-bootstrap-icons" +import { FlowFunctionsContext } from "../../flow/context/flowFunctionsContext" +import { Stack } from "react-bootstrap" +import { Checkbox } from "primereact/checkbox" + +/** + * + * @param {string} id id of the node + * @param {object} data data of the node + * @param {string} type type of the node + * @returns {JSX.Element} A StandardNode node + * + * @description + * This component is used to display a StandardNode node. + * it handles the display of the node and the modal + * + */ +const TrainModelNode = ({ id, data }) => { + const [modalShow, setModalShow] = useState(false) // state of the modal + const { updateNode } = useContext(FlowFunctionsContext) + const [IntegrateTuning, setIntegrateTuning] = useState(data.internal.isTuningEnabled ?? false) + const [modalTuningBody, setModalTuningBody] = useState(null) + + // Check if isTuningEnabled exists in data.internal, if not initialize it + useEffect(() => { + if (!("isTuningEnabled" in data.internal)) { + data.internal.isTuningEnabled = false + updateNode({ + id: id, + updatedData: data.internal + }) + } + console.log(data.internal) + }, []) + + /** + * + * @param {Object} inputUpdate the object containing the name and the value of the input + * @description + * This function is used to update the settings of the node + */ + const onInputChange = (inputUpdate) => { + data.internal.settings[inputUpdate.name] = inputUpdate.value + updateNode({ + id: id, + updatedData: data.internal + }) + } + + /** + * + * @param {Object} inputUpdate the object containing the name and the value of the input + * @description + * This function is used to update the settings of the node + */ + const onInputChangeTuning = (inputUpdate) => { + data.internal.settingsTuning[inputUpdate.name] = inputUpdate.value + updateNode({ + id: id, + updatedData: data.internal + }) + } + + /** + * + * @param {Object} hasWarning an object containing the state of the warning and the tooltip + * @description + * This function is used to handle the warning of the node + */ + const handleWarning = (hasWarning) => { + data.internal.hasWarning = hasWarning + updateNode({ + id: id, + updatedData: data.internal + }) + } + + /** + * + * @param {Object} e the event of the checkbox + * @description + * This function is used to handle the checkbox for enabling the tuning + */ + const handleIntegration = (e) => { + setIntegrateTuning(e.checked) + data.internal.isTuningEnabled = e.checked + updateNode({ + id: id, + updatedData: data.internal + }) + } + + return ( + <> + {/* build on top of the Node component */} + + {"default" in data.setupParam.possibleSettings && ( + <> + + {Object.entries(data.setupParam.possibleSettings.default).map(([settingName, setting]) => { + return ( + + ) + })} + + + )} + + } + // node specific is the body of the node, so optional settings + nodeSpecific={ + <> +
+ handleIntegration(e)} /> + +
+ {/* the button to open the modal (the plus sign)*/} + + {/* the modal component*/} + setModalShow(false)} + options={data.setupParam.possibleSettings.options} + data={data} + id={id} + optionsTuning={data.internal.isTuningEnabled ? data.setupParam.possibleSettingsTuning.options : null} + /> + {/* the inputs for the options */} + {data.internal.checkedOptions.map((optionName) => { + return ( + + ) + })} + {data.internal.isTuningEnabled && data.internal.checkedOptionsTuning && data.internal.checkedOptionsTuning.length > 0 && ( + <> +
+
Tune Model Options
+ {data.internal.checkedOptionsTuning.map((optionName) => { + console.log(data) + return ( + + ) + })} + + )} + + } + // Link to documentation + nodeLink={"https://medomics-udes.gitbook.io/medomicslab-docs/tutorials/development/learning-module"} + /> + + ) +} + +export default TrainModelNode diff --git a/renderer/components/learning/workflow.jsx b/renderer/components/learning/workflow.jsx index 7956e607..ba573b0c 100644 --- a/renderer/components/learning/workflow.jsx +++ b/renderer/components/learning/workflow.jsx @@ -25,6 +25,7 @@ import GroupNode from "../flow/groupNode" import OptimizeIO from "./nodesTypes/optimizeIO" import DatasetNode from "./nodesTypes/datasetNode" import LoadModelNode from "./nodesTypes/loadModelNode" +import TrainModelNode from "./nodesTypes/trainModelNode.jsx" // here are the parameters of the nodes import nodesParams from "../../public/setupVariables/allNodesParams" @@ -83,7 +84,8 @@ const Workflow = ({ setWorkflowType, workflowType }) => { groupNode: GroupNode, optimizeIO: OptimizeIO, datasetNode: DatasetNode, - loadModelNode: LoadModelNode + loadModelNode: LoadModelNode, + trainModelNode: TrainModelNode }), [] ) @@ -145,6 +147,12 @@ const Workflow = ({ setWorkflowType, workflowType }) => { if (!node.id.includes("opt")) { let subworkflowType = node.data.internal.subflowId != "MAIN" ? "optimize" : "learning" node.data.setupParam.possibleSettings = deepCopy(staticNodesParams[subworkflowType][node.data.internal.type]["possibleSettings"][MLType]) + console.log(node.type) + if (node.type == "trainModelNode") { + node.data.setupParam.possibleSettingsTuning = deepCopy(staticNodesParams["optimize"]["tune_model"]["possibleSettings"][MLType]) + node.data.internal.checkedOptionsTuning = [] + node.data.internal.settingsTuning = {} + } node.data.internal.settings = {} node.data.internal.checkedOptions = [] if (node.type == "selectionNode") { @@ -402,6 +410,11 @@ const Workflow = ({ setWorkflowType, workflowType }) => { let subworkflowType = node.data.internal.subflowId != "MAIN" ? "optimize" : "learning" let setupParams = deepCopy(staticNodesParams[subworkflowType][node.data.internal.type]) setupParams.possibleSettings = setupParams["possibleSettings"][newScene.MLType] + console.log(node.type) + if (node.type == "trainModelNode") { + let setupParamsTuning = deepCopy(staticNodesParams["optimize"]["tune_model"]) + setupParams.possibleSettingsTuning = setupParamsTuning["possibleSettings"][newScene.MLType] + } node.data.setupParam = setupParams } }) @@ -461,6 +474,12 @@ const Workflow = ({ setWorkflowType, workflowType }) => { if (!newNode.id.includes("opt")) { setupParams = deepCopy(staticNodesParams[workflowType][newNode.data.internal.type]) setupParams.possibleSettings = setupParams["possibleSettings"][MLType] + if (newNode.type == "trainModelNode") { + let setupParamsTuning = deepCopy(staticNodesParams["optimize"]["tune_model"]) + setupParams.possibleSettingsTuning = setupParamsTuning["possibleSettings"][MLType] + newNode.data.internal.checkedOptionsTuning = [] + newNode.data.internal.settingsTuning = {} + } } newNode.id = `${newNode.id}${associatedNode ? `.${associatedNode}` : ""}` // if the node is a sub-group node, it has the id of the parent node seperated by a dot. useful when processing only ids newNode.hidden = newNode.type == "optimizeIO" @@ -481,6 +500,7 @@ const Workflow = ({ setWorkflowType, workflowType }) => { newNode.data.internal.selection = newNode.type == "selectionNode" && Object.keys(setupParams.possibleSettings)[0] newNode.data.internal.checkedOptions = [] + newNode.data.internal.subflowId = !associatedNode ? groupNodeId.id : associatedNode newNode.data.internal.hasWarning = { state: false } @@ -620,6 +640,7 @@ const Workflow = ({ setWorkflowType, workflowType }) => { const plotDirectoryID = await insertMEDDataObjectIfNotExists(plotsDirectory) // Clean everything before running a new experiment + console.log("sending flow ", flow) let { success, isValid } = await cleanJson2Send(flow, up2Id, plotDirectoryID) if (success) { requestBackendRunExperiment(port, backendMetadataFileID, isValid) @@ -778,10 +799,11 @@ const Workflow = ({ setWorkflowType, workflowType }) => { if (reactFlowInstance && metadataFileID) { const flow = deepCopy(reactFlowInstance.toObject()) flow.MLType = MLType + flow.intersections = intersections + console.log("scene saved", flow) flow.nodes.forEach((node) => { node.data.setupParam = null }) - flow.intersections = intersections let success = await overwriteMEDDataObjectContent(metadataFileID, [flow]) if (success) { toast.success("Scene has been saved successfully") diff --git a/renderer/components/mainPages/dataComponents/wsSelectMultiple.jsx b/renderer/components/mainPages/dataComponents/wsSelectMultiple.jsx index aad7e038..220f8db7 100644 --- a/renderer/components/mainPages/dataComponents/wsSelectMultiple.jsx +++ b/renderer/components/mainPages/dataComponents/wsSelectMultiple.jsx @@ -35,78 +35,81 @@ const WsSelectMultiple = ({ const processData = async () => { if (globalData !== undefined) { let ids = Object.keys(globalData) - let datasetListToShow = await Promise.all(ids.map(async (id) => { - // Only process files in the selected root directory - if (rootDir != undefined) { - if (globalData[globalData[id].parentID]) { - if (rootDir.includes(globalData[globalData[id].parentID].name) || rootDir.includes(globalData[globalData[id].parentID].originalName)) { - if (!(!acceptFolder && globalData[id].type == "directory")) { - if (acceptedExtensions.includes("all") || acceptedExtensions.includes(globalData[id].type)) { - if (!matchRegex || matchRegex.test(globalData[id].name)) { - // Initializations - let columnsTags = {} - let tags = [] - let tagsCollections = await getCollectionTags(id) // Get the tags of the file from db - tagsCollections = await tagsCollections.toArray() // Convert to array - // Process the tags and link them to columns: {column_name: [tags]} - tagsCollections.map((tagCollection) => { - let tempColName = tagCollection.column_name - if (tagCollection.column_name.includes("_|_")) { - tempColName = tagCollection.column_name.split("_|_")[1] + let datasetListToShow = await Promise.all( + ids.map(async (id) => { + // Only process files in the selected root directory + if (rootDir != undefined) { + if (globalData[globalData[id].parentID]) { + if (rootDir.includes(globalData[globalData[id].parentID].name) || rootDir.includes(globalData[globalData[id].parentID].originalName)) { + if (!(!acceptFolder && globalData[id].type == "directory")) { + if (acceptedExtensions.includes("all") || acceptedExtensions.includes(globalData[id].type)) { + if (!matchRegex || matchRegex.test(globalData[id].name)) { + // Initializations + let columnsTags = {} + let tags = [] + let tagsCollections = await getCollectionTags(id) // Get the tags of the file from db + tagsCollections = await tagsCollections.toArray() // Convert to array + // Process the tags and link them to columns: {column_name: [tags]} + tagsCollections.map((tagCollection) => { + let tempColName = tagCollection.column_name + if (tagCollection.column_name.includes("_|_")) { + tempColName = tagCollection.column_name.split("_|_")[1] + } + columnsTags[tempColName] = tagCollection.tags + tags = tags.concat(tagCollection.tags) + }) + tags = [...new Set(tags)] // Remove duplicates + + // Add the file to the list + return { + key: id, + id: id, + name: globalData[id].name, + tags: tags, + columnsTags: columnsTags } - columnsTags[tempColName] = tagCollection.tags - tags = tags.concat(tagCollection.tags) - }) - tags = [...new Set(tags)] // Remove duplicates - - // Add the file to the list - return { - id: id, - name: globalData[id].name, - tags: tags, - columnsTags: columnsTags } } } } } - } - // else, we want to add any file (or folder) from acceptedExtensions - } else { - if (acceptedExtensions.includes(globalData[id].extension) || acceptedExtensions.includes("all")) { - let columnsTags = {} - let tags = [] - let tagsCollections = await getCollectionTags(id) - tagsCollections = await tagsCollections.toArray() - tagsCollections.map((tagCollection) => { - columnsTags[tagCollection.column_name] = tagCollection.tags - tags = tags.concat(tagCollection.tags) - }) - return { - id: id, - name: globalData[id].name, - tags: tags, - columnsTags: columnsTags + // else, we want to add any file (or folder) from acceptedExtensions + } else { + if (acceptedExtensions.includes(globalData[id].extension) || acceptedExtensions.includes("all")) { + let columnsTags = {} + let tags = [] + let tagsCollections = await getCollectionTags(id) + tagsCollections = await tagsCollections.toArray() + tagsCollections.map((tagCollection) => { + columnsTags[tagCollection.column_name] = tagCollection.tags + tags = tags.concat(tagCollection.tags) + }) + return { + key: id, + id: id, + name: globalData[id].name, + tags: tags, + columnsTags: columnsTags + } } } - } - // Return empty list if the item doesn't meet the conditions - return null; - }) - ) + // Return empty list if the item doesn't meet the conditions + return null + }) + ) - // Filter out any null values - datasetListToShow = datasetListToShow.filter((item) => item !== null); + // Filter out any null values + datasetListToShow = datasetListToShow.filter((item) => item !== null) - console.log("datasetListToShow", datasetListToShow); - setDatasetList(datasetListToShow); + console.log("datasetListToShow", datasetListToShow) + setDatasetList(datasetListToShow) - if (datasetListToShow.length === 0) { - setHasWarning({ state: true, tooltip: "No data file found in the workspace" }); + if (datasetListToShow.length === 0) { + setHasWarning({ state: true, tooltip: "No data file found in the workspace" }) + } } } - } - processData() + processData() }, [globalData]) return ( @@ -116,7 +119,14 @@ const WsSelectMultiple = ({ key={key} disabled={disabled} placeholder={placeholder} - value={Array.isArray(selectedPaths) ? selectedPaths : []} + value={ + Array.isArray(selectedPaths) + ? selectedPaths.map((item) => ({ + ...item, + key: item.key || item.id + })) + : [] + } onChange={(e) => onChange(e.value)} options={datasetList} optionLabel="name" diff --git a/renderer/public/setupVariables/learningNodesParams.jsx b/renderer/public/setupVariables/learningNodesParams.jsx index bc1072f2..74a9e2d6 100644 --- a/renderer/public/setupVariables/learningNodesParams.jsx +++ b/renderer/public/setupVariables/learningNodesParams.jsx @@ -48,7 +48,7 @@ const nodesParams = { possibleSettings: { classification: classificationModelsSettings, regression: regressionModelsSettings } }, train_model: { - type: "standardNode", + type: "trainModelNode", classes: "action create_model run", nbInput: 2, nbOutput: 1, diff --git a/renderer/styles/flow/reactFlow.css b/renderer/styles/flow/reactFlow.css index 160653d5..859182d0 100644 --- a/renderer/styles/flow/reactFlow.css +++ b/renderer/styles/flow/reactFlow.css @@ -494,6 +494,11 @@ input.node-editableLabel { padding: 0.5rem; } +.options-overlayPanel-settingsBody { + max-height: 20rem; + overflow-y: auto; +} + .optimize-io { width: 200px; height: 200px; @@ -644,4 +649,4 @@ path { .modal-settings-chooser .modal-body { max-height: 70vh; overflow-y: auto; -} \ No newline at end of file +}