Skip to content

Commit 7a249e9

Browse files
committed
Update ctr_trainer.py
1 parent 0924b1b commit 7a249e9

1 file changed

Lines changed: 35 additions & 55 deletions

File tree

torch_rechub/trainers/ctr_trainer.py

Lines changed: 35 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,8 @@ class CTRTrainer(object):
2121
earlystop_patience (int): how long to wait after last time validation auc improved (default=10).
2222
device (str): `"cpu"` or `"cuda:0"`
2323
gpus (list): id of multi gpu (default=[]). If the length >=1, then the model will wrapped by nn.DataParallel.
24-
loss_mode (bool): whether the model returns prediction only or prediction with extra loss.
25-
``True`` means ``model(x_dict) -> y_pred``.
26-
``False`` means ``model(x_dict) -> (y_pred, other_loss)``.
24+
loss_mode (bool): whether the model returns only prediction or prediction with extra loss
25+
(`True`: `model(x_dict) -> y_pred`, `False`: `model(x_dict) -> (y_pred, other_loss)`).
2726
model_path (str): the path you want to save the model (default="./"). Note only save the best weight in the validation data.
2827
embedding_l1 (float): L1 regularization coefficient for embedding parameters (default=0.0).
2928
embedding_l2 (float): L2 regularization coefficient for embedding parameters (default=0.0).
@@ -248,62 +247,43 @@ def visualization(self, input_data=None, batch_size=2, seq_length=10, depth=3, s
248247
showing layer connections, tensor shapes, and nested module structures.
249248
It automatically extracts feature information from the model.
250249
251-
Parameters
252-
----------
253-
input_data : dict, optional
254-
Example input dict {feature_name: tensor}.
255-
If not provided, dummy inputs will be generated automatically.
256-
batch_size : int, default=2
257-
Batch size for auto-generated dummy input.
258-
seq_length : int, default=10
259-
Sequence length for SequenceFeature.
260-
depth : int, default=3
261-
Visualization depth, higher values show more detail.
262-
Set to -1 to show all layers.
263-
show_shapes : bool, default=True
264-
Whether to display tensor shapes.
265-
expand_nested : bool, default=True
266-
Whether to expand nested modules.
267-
save_path : str, optional
268-
Path to save the graph image (.pdf, .svg, .png).
269-
If None, displays in Jupyter or opens system viewer.
270-
graph_name : str, default="model"
271-
Name for the graph.
272-
device : str, optional
273-
Device for model execution. If None, defaults to 'cpu'.
274-
dpi : int, default=300
275-
Resolution in dots per inch for output image.
276-
Higher values produce sharper images suitable for papers.
277-
**kwargs : dict
278-
Additional arguments passed to torchview.draw_graph().
250+
Args:
251+
input_data (dict, optional): Example input dict {feature_name: tensor}.
252+
If not provided, dummy inputs will be generated automatically.
253+
batch_size (int): Batch size for auto-generated dummy input (default: 2).
254+
seq_length (int): Sequence length for SequenceFeature (default: 10).
255+
depth (int): Visualization depth, higher values show more detail.
256+
Set to -1 to show all layers (default: 3).
257+
show_shapes (bool): Whether to display tensor shapes (default: True).
258+
expand_nested (bool): Whether to expand nested modules (default: True).
259+
save_path (str, optional): Path to save the graph image (.pdf, .svg, .png).
260+
If None, displays in Jupyter or opens system viewer.
261+
graph_name (str): Name for the graph (default: "model").
262+
device (str, optional): Device for model execution. If None, defaults to 'cpu'.
263+
dpi (int): Resolution in dots per inch for output image.
264+
Higher values produce sharper images suitable for papers (default: 300).
265+
**kwargs: Additional arguments passed to ``torchview.draw_graph()``.
279266
280-
Returns
281-
-------
282-
ComputationGraph
283-
A torchview ComputationGraph object.
284-
285-
Raises
286-
------
287-
ImportError
288-
If torchview or graphviz is not installed.
289-
290-
Notes
291-
-----
292-
Default Display Behavior:
293-
When `save_path` is None (default):
267+
Returns:
268+
ComputationGraph: A torchview ComputationGraph object.
269+
270+
Raises:
271+
ImportError: If torchview or graphviz is not installed.
272+
273+
Note:
274+
When ``save_path`` is None (default):
294275
- In Jupyter/IPython: automatically displays the graph inline
295276
- In Python script: opens the graph with system default viewer
296277
297-
Examples
298-
--------
299-
>>> trainer = CTRTrainer(model, ...)
300-
>>> trainer.fit(train_dl, val_dl)
301-
>>>
302-
>>> # Auto-display in Jupyter (no save_path needed)
303-
>>> trainer.visualization(depth=4)
304-
>>>
305-
>>> # Save to high-DPI PNG for papers
306-
>>> trainer.visualization(save_path="model.png", dpi=300)
278+
Example:
279+
>>> trainer = CTRTrainer(model, ...)
280+
>>> trainer.fit(train_dl, val_dl)
281+
>>>
282+
>>> # Auto-display in Jupyter (no save_path needed)
283+
>>> trainer.visualization(depth=4)
284+
>>>
285+
>>> # Save to high-DPI PNG for papers
286+
>>> trainer.visualization(save_path="model.png", dpi=300)
307287
"""
308288
from ..utils.visualization import TORCHVIEW_AVAILABLE, visualize_model
309289

0 commit comments

Comments
 (0)