@@ -72,17 +72,22 @@ def __init__(
7272     f"Model {model_name} not supported. Supported models: {list(MODEL_CONFIGS.keys())}"
7373 )
7474
75- device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
76- num_gpus = torch.cuda.device_count()
77- batch_size = MODEL_CONFIGS[model_name]["batch_size"]
78- batch_size = batch_size * num_gpus if num_gpus > 0 else batch_size
75+ # Use the provided device or default to CUDA
76+ self.device = device or torch.device(
77+     "cuda" if torch.cuda.is_available() else "cpu"
78+ )
79+
80+ # Get device ID for logging
81+ self.device_id = self.device.index if hasattr(self.device, "index") else 0
7982
83+ # We don't need multi-GPU inside this encoder instance since each instance
84+ # will run on a dedicated GPU
8085 self.cfg = EncoderConfig(
8186     model_name=model_name,
8287     model_config=MODEL_CONFIGS[model_name],
83-     device=device,
84-     num_gpus=num_gpus,
85-     batch_size=batch_size,
88+     device=self.device,
89+     num_gpus=1,  # Only use 1 GPU per encoder instance
90+     batch_size=MODEL_CONFIGS[model_name]["batch_size"],
8691     use_default_instruction=use_default_instruction,
8792     use_fp16=use_fp16,
8893     testing_mode=testing_mode,
@@ -91,7 +96,7 @@ def __init__(
9196 self ._initialize_model ()
9297
9398 def _initialize_model (self ) -> None :
94- """Initialize model."""
99+ """Initialize model on the specific GPU."""
95100 home_dir = os.path.expanduser("~")
96101 model_path = os.path.join(
97102     home_dir, ".cache", "instructlab", "models", self.cfg.model_name
@@ -128,11 +133,9 @@ def _initialize_model(self) -> None:
128133 self.model = self.model.half()
129134
130135 self.model = self.model.to(self.cfg.device)
136+ logger.info(f"Model loaded on device: {self.cfg.device}")
131137
132- if self.cfg.num_gpus > 1:
133-     logger.info(f"Using {self.cfg.num_gpus} GPUs")
134-     self.model = torch.nn.DataParallel(self.model)
135-
138+ # No need for DataParallel since we're running one encoder per GPU
136139 self.model.eval()
137140
138141 def _prepare_inputs(