1+ import os
2+ import uuid
3+ from typing import Any , Optional , Union
4+ import logging
5+
6+ import numpy as np
7+ import torch
8+ import torch_pruning as tp
9+ from fedot .core .data .data import InputData
10+ from fedot .core .operations .operation_parameters import OperationParameters
11+
12+ from lmcompress .architecture .computational .devices import default_device , extract_device
13+ from lmcompress .data .data import CompressionInputData
14+ from lmcompress .models .network_impl .base_nn_model import BaseNeuralModel , BaseNeuralForecaster
15+ from lmcompress .models .network_impl .utils .trainer_factory import create_trainer_from_input_data
16+ from torchinfo import summary
17+ from lmcompress .tools .registry .model_registry import ModelRegistry
18+
19+ DEVICE = default_device ('cuda' )
20+
21+
class BaseCompressionModel:
    """Base class for NN compression operations.

    Wraps a ``torch.nn.Module`` together with a :class:`ModelRegistry` so
    that the model state before and after compression can be checkpointed,
    reloaded and compared.

    Attributes:
        batch_size: Mini-batch size read from ``params`` (default 16).
        activation: Activation-function name read from ``params`` (default "ReLU").
        model: The model currently being processed, or ``None``.
        params: The raw operation parameters this instance was built from.

    Example:
        To use this operation you can create a pipeline as follows::

            from fedot.core.pipelines.pipeline_builder import PipelineBuilder
            from examples.fedot.fedot_ex import init_input_data
            from fedot_ind.tools.loader import DataLoader
            from fedot_ind.core.repository.initializer_industrial_models import IndustrialModels

            train_data, test_data = DataLoader(dataset_name='Ham').load_data()
            with IndustrialModels():
                pipeline = PipelineBuilder().add_node('resnet_model').add_node('rf').build()
                input_data = init_input_data(train_data[0], train_data[1])
                pipeline.fit(input_data)
                features = pipeline.predict(input_data)
                print(features)
    """

    def __init__(self, params: Optional[OperationParameters] = None):
        """Initialize the wrapper and attach a ModelRegistry.

        Args:
            params: Operation parameters; ``None`` means "use defaults".
                A fresh empty mapping is substituted per call so instances
                never share a mutable default.
        """
        logger = logging.getLogger(__name__)
        logger.debug("BaseCompressionModel.__init__() called")

        # Avoid the shared-mutable-default pitfall of `params={}`.
        params = {} if params is None else params

        self.batch_size = params.get("batch_size", 16)
        self.activation = params.get("activation", "ReLU")
        self.model = None
        self.params = params

        # Stable identifier grouping all checkpoints of this compression run
        # inside the registry; generated when the caller does not supply one.
        self._lmcompress_id = params.get("lmcompress_id")
        if self._lmcompress_id is None:
            self._lmcompress_id = f"lmcompress_{uuid.uuid4().hex[:8]}"

        logger.debug("lmcompress_id: %s", self._lmcompress_id)

        # Registry bookkeeping for the pre-/post-compression snapshots.
        self._model_id_before = None
        self._model_id_after = None
        self._model_before_cached = None
        self._model_after_cached = None
        self._registry = ModelRegistry()

        logger.debug(
            "BaseCompressionModel initialized with ModelRegistry, auto_cleanup=%s",
            self._registry.auto_cleanup,
        )

    @property
    def model_before(self):
        """Model snapshot taken before compression (memory cache first, then registry)."""
        if self._model_before_cached is not None:
            return self._model_before_cached

        if self._model_id_before is None:
            return None

        loaded_model = self._registry.load_model_from_latest_checkpoint(
            self._lmcompress_id, self._model_id_before, DEVICE
        )

        # Only a real nn.Module is accepted from the registry; anything else
        # (including None) is treated as "not available".
        if isinstance(loaded_model, torch.nn.Module):
            self._model_before_cached = loaded_model
            return self._model_before_cached

        return None

    @model_before.setter
    def model_before(self, value):
        """Cache *value* and register it once as the 'before' snapshot."""
        logger = logging.getLogger(__name__)
        logger.info(
            "model_before setter called: value=%s, _model_id_before=%s",
            "not None" if value else "None",
            self._model_id_before,
        )

        self._model_before_cached = value
        # Register only the first non-None assignment: the "before" state is
        # a one-shot snapshot.
        if value is not None and self._model_id_before is None:
            logger.info("Registering model_before in ModelRegistry")
            self._model_id_before = self._registry.register_model(
                lmcompress_id=self._lmcompress_id,
                model=value,
                stage="before",
                mode=None,
            )
            logger.info("model_before registered with id=%s", self._model_id_before)
        else:
            logger.debug(
                "Skipping registration: value is None=%s, already registered=%s",
                value is None,
                self._model_id_before is not None,
            )

    @property
    def model_after(self):
        """Model snapshot taken after compression (memory cache first, then registry)."""
        if self._model_after_cached is not None:
            return self._model_after_cached

        if self._model_id_after is None:
            return None

        loaded_model = self._registry.load_model_from_latest_checkpoint(
            self._lmcompress_id, self._model_id_after, DEVICE
        )

        if isinstance(loaded_model, torch.nn.Module):
            self._model_after_cached = loaded_model
            return self._model_after_cached

        return None

    @model_after.setter
    def model_after(self, value):
        """Cache *value*; register it on first assignment, record changes afterwards."""
        logger = logging.getLogger(__name__)
        logger.info(
            "model_after setter called: value=%s, _model_id_after=%s",
            "not None" if value else "None",
            self._model_id_after,
        )

        self._model_after_cached = value
        if value is not None:
            if self._model_id_after is None:
                logger.info("Registering new model_after in ModelRegistry")
                self._model_id_after = self._registry.register_model(
                    lmcompress_id=self._lmcompress_id,
                    model=value,
                    stage="after",
                    mode=None,
                )
                logger.info("model_after registered with id=%s", self._model_id_after)
            else:
                # Unlike model_before, the "after" state may evolve, so
                # subsequent assignments are recorded as changes.
                logger.info("Registering changes for model_after")
                self._registry.register_changes(
                    lmcompress_id=self._lmcompress_id,
                    model_id=self._model_id_after,
                    model=value,
                    stage="after",
                    mode=None,
                )
                logger.info("model_after changes registered")
        else:
            logger.debug("Skipping registration: value is None")

    def _save_model_checkpoint(self, model, stage: str):
        """Save a checkpoint of *model* to the registry.

        Args:
            model: Model to save.
            stage: Stage name (e.g., 'before_compression', 'after_compression').

        Returns:
            The registry id assigned to the saved model.
        """
        model_id = self._registry.register_model(
            lmcompress_id=self._lmcompress_id,
            model=model,
            stage=stage,
            mode=None,
        )
        return model_id

    def _init_model(self, input_data, additional_hooks=tuple()):
        """Resolve the input model, snapshot it as ``model_before`` and build a trainer.

        Args:
            input_data: Node input; ``input_data.model`` is either a
                ``torch.nn.Module`` or a filesystem path to a checkpoint.
            additional_hooks: Extra hooks to register on the trainer.

        Returns:
            The resolved ``torch.nn.Module``.

        Raises:
            ValueError: If the input is neither a path nor a module.
        """
        logger = logging.getLogger(__name__)
        logger.info("BaseCompressionModel._init_model() started")

        model = input_data.model
        logger.info("Model type from input_data.target: %s", type(model).__name__)

        # Support passing a filesystem path to a checkpoint/model at the node input.
        if isinstance(model, str):
            logger.info("Loading model from path: %s", model)
            device = default_device()
            # NOTE(review): torch.load unpickles arbitrary objects — load
            # trusted checkpoints only.
            loaded = torch.load(model, map_location=device)
            if isinstance(loaded, dict) and "model" in loaded:
                model = loaded["model"]
            else:
                model = loaded
            logger.info("Model loaded: type=%s", type(model).__name__)

        if not isinstance(model, torch.nn.Module):
            raise ValueError(
                f"Expected model to be either file path or torch.nn.Module, got {type(model)}"
            )

        logger.info("Calling model_before setter")
        self.model_before = model
        logger.info(
            "model_before setter completed, _model_id_before=%s", self._model_id_before
        )

        # Create trainer using factory.
        self.trainer = create_trainer_from_input_data(input_data, self.params)
        self.trainer.register_additional_hooks(additional_hooks)
        self.trainer.model = model

        return model

    def _fit_model(self, ts: CompressionInputData, split_data: bool = False):
        """Hook for subclasses: train/compress the model. Base implementation is a no-op."""
        pass

    def _predict_model(self, x_test, output_mode: str = "default"):
        """Hook for subclasses: run inference. Base implementation is a no-op."""
        pass

    def _get_example_input(self, input_data: InputData):
        """Return one example batch (inputs only) from the validation dataloader."""
        batch = next(iter(input_data.val_dataloader))
        if isinstance(batch, (list, tuple)) and len(batch) == 2:
            # (inputs, targets) pair — return just the inputs.
            return batch[0]
        return batch

    def fit(self, input_data: CompressionInputData):
        """Method for feature generation for all series."""
        self.num_classes = input_data.num_classes
        self.task_type = input_data.task
        self._fit_model(input_data)

    def predict_for_fit(self, input_data: CompressionInputData, output_mode: str = 'lmcompress'):
        """Predict during pipeline fitting; delegates to :meth:`predict`."""
        return self.predict(input_data, output_mode)

    def predict(
        self, input_data: CompressionInputData, output_mode: str = "lmcompress"
    ) -> torch.nn.Module:
        """Run trainer prediction with the compressed or the original model.

        'lmcompress' mode uses the post-compression model; any other mode
        uses the pre-compression one.
        """
        if output_mode == 'lmcompress':
            self.trainer.model = self.model_after
        else:
            self.trainer.model = self.model_before
        return self.trainer.predict(input_data, output_mode)

    def estimate_params(self, example_batch, model_before, model_after):
        """Compare parameter/MAC counts of the model before and after compression.

        Args:
            example_batch: Batch used to trace the models (moved to each
                model's own device).
            model_before: Model prior to compression.
            model_after: Model after compression.

        Returns:
            dict with keys ``params_before``, ``macs_before``,
            ``params_after``, ``macs_after``. For HuggingFace models the
            MAC counts are reported as 0.
        """
        # Heuristic HuggingFace detection: transformers models expose
        # .config.model_type and .base_model.
        is_huggingface = (
            hasattr(model_before, 'config')
            and hasattr(model_before.config, 'model_type')
            and hasattr(model_before, 'base_model')
        )

        if is_huggingface:
            base_nparams = model_before.num_parameters()
            nparams = model_after.num_parameters()
            # torchinfo cannot reliably trace HF models here, so MACs are unknown.
            base_macs, macs = 0, 0
        else:
            base_info = summary(
                model=model_before,
                input_data=example_batch.to(extract_device(model_before)),
                verbose=0,
            )
            base_macs, base_nparams = base_info.total_mult_adds, base_info.total_params

            info = summary(
                model=model_after,
                input_data=example_batch.to(extract_device(model_after)),
                verbose=0,
            )
            macs, nparams = info.total_mult_adds, info.total_params

        return dict(params_before=base_nparams, macs_before=base_macs,
                    params_after=nparams, macs_after=macs)

    # NOTE: kept for future use — do not delete.
    def _estimate_params(self, model, example_batch):
        """Return (MACs, n_params) for *model* traced on *example_batch* via torch_pruning."""
        base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_batch)
        return base_macs, base_nparams

    # NOTE: kept for future use — do not delete.
    def _diagnose(self, model, example_batch, *previous_results, annotation=''):
        """Log the params/MACs change of *model* relative to a baseline.

        Args:
            model: Model to measure.
            example_batch: Batch moved to the model's device for tracing.
            *previous_results: ``(base_macs, base_nparams, ...)`` baseline values.
            annotation: Free-form label logged before the numbers.
        """
        logging.info(annotation)
        base_macs, base_nparams, *_ = previous_results
        macs, nparams = self._estimate_params(model, example_batch.to(extract_device(model)))
        logging.info("Params: %.2f M => %.2f M", base_nparams / 1e6, nparams / 1e6)
        logging.info("MACs: %.2f G => %.2f G", base_macs / 1e9, macs / 1e9)