@@ -21,9 +21,8 @@ class CTRTrainer(object):
2121 earlystop_patience (int): how long to wait after last time validation auc improved (default=10).
2222 device (str): `"cpu"` or `"cuda:0"`
2323 gpus (list): id of multi gpu (default=[]). If the length >=1, then the model will wrapped by nn.DataParallel.
24- loss_mode (bool): whether the model returns prediction only or prediction with extra loss.
25- ``True`` means ``model(x_dict) -> y_pred``.
26- ``False`` means ``model(x_dict) -> (y_pred, other_loss)``.
24+ loss_mode (bool): whether the model returns only prediction or prediction with extra loss
25+ (`True`: `model(x_dict) -> y_pred`, `False`: `model(x_dict) -> (y_pred, other_loss)`).
2726 model_path (str): the path you want to save the model (default="./"). Note only save the best weight in the validation data.
2827 embedding_l1 (float): L1 regularization coefficient for embedding parameters (default=0.0).
2928 embedding_l2 (float): L2 regularization coefficient for embedding parameters (default=0.0).
@@ -248,62 +247,43 @@ def visualization(self, input_data=None, batch_size=2, seq_length=10, depth=3, s
248247 showing layer connections, tensor shapes, and nested module structures.
249248 It automatically extracts feature information from the model.
250249
251- Parameters
252- ----------
253- input_data : dict, optional
254- Example input dict {feature_name: tensor}.
255- If not provided, dummy inputs will be generated automatically.
256- batch_size : int, default=2
257- Batch size for auto-generated dummy input.
258- seq_length : int, default=10
259- Sequence length for SequenceFeature.
260- depth : int, default=3
261- Visualization depth, higher values show more detail.
262- Set to -1 to show all layers.
263- show_shapes : bool, default=True
264- Whether to display tensor shapes.
265- expand_nested : bool, default=True
266- Whether to expand nested modules.
267- save_path : str, optional
268- Path to save the graph image (.pdf, .svg, .png).
269- If None, displays in Jupyter or opens system viewer.
270- graph_name : str, default="model"
271- Name for the graph.
272- device : str, optional
273- Device for model execution. If None, defaults to 'cpu'.
274- dpi : int, default=300
275- Resolution in dots per inch for output image.
276- Higher values produce sharper images suitable for papers.
277- **kwargs : dict
278- Additional arguments passed to torchview.draw_graph().
250+ Args:
251+ input_data (dict, optional): Example input dict {feature_name: tensor}.
252+ If not provided, dummy inputs will be generated automatically.
253+ batch_size (int): Batch size for auto-generated dummy input (default: 2).
254+ seq_length (int): Sequence length for SequenceFeature (default: 10).
255+ depth (int): Visualization depth, higher values show more detail.
256+ Set to -1 to show all layers (default: 3).
257+ show_shapes (bool): Whether to display tensor shapes (default: True).
258+ expand_nested (bool): Whether to expand nested modules (default: True).
259+ save_path (str, optional): Path to save the graph image (.pdf, .svg, .png).
260+ If None, displays in Jupyter or opens system viewer.
261+ graph_name (str): Name for the graph (default: "model").
262+ device (str, optional): Device for model execution. If None, defaults to 'cpu'.
263+ dpi (int): Resolution in dots per inch for output image.
264+ Higher values produce sharper images suitable for papers (default: 300).
265+ **kwargs: Additional arguments passed to ``torchview.draw_graph()``.
279266
280- Returns
281- -------
282- ComputationGraph
283- A torchview ComputationGraph object.
284-
285- Raises
286- ------
287- ImportError
288- If torchview or graphviz is not installed.
289-
290- Notes
291- -----
292- Default Display Behavior:
293- When `save_path` is None (default):
267+ Returns:
268+ ComputationGraph: A torchview ComputationGraph object.
269+
270+ Raises:
271+ ImportError: If torchview or graphviz is not installed.
272+
273+ Note:
274+ When ``save_path`` is None (default):
294275 - In Jupyter/IPython: automatically displays the graph inline
295276 - In Python script: opens the graph with system default viewer
296277
297- Examples
298- --------
299- >>> trainer = CTRTrainer(model, ...)
300- >>> trainer.fit(train_dl, val_dl)
301- >>>
302- >>> # Auto-display in Jupyter (no save_path needed)
303- >>> trainer.visualization(depth=4)
304- >>>
305- >>> # Save to high-DPI PNG for papers
306- >>> trainer.visualization(save_path="model.png", dpi=300)
278+ Example:
279+ >>> trainer = CTRTrainer(model, ...)
280+ >>> trainer.fit(train_dl, val_dl)
281+ >>>
282+ >>> # Auto-display in Jupyter (no save_path needed)
283+ >>> trainer.visualization(depth=4)
284+ >>>
285+ >>> # Save to high-DPI PNG for papers
286+ >>> trainer.visualization(save_path="model.png", dpi=300)
307287 """
308288 from ..utils .visualization import TORCHVIEW_AVAILABLE , visualize_model
309289
0 commit comments