diff --git a/trident/wsi_objects/WSI.py b/trident/wsi_objects/WSI.py index 4cf8ad3..a51de39 100644 --- a/trident/wsi_objects/WSI.py +++ b/trident/wsi_objects/WSI.py @@ -326,7 +326,7 @@ def segment_tissue( for imgs, (xcoords, ycoords) in dataloader: imgs = imgs.to(device, dtype=precision) # Move to device and match dtype - with torch.autocast(device_type=device.split(":")[0], dtype=precision, enabled=(precision != torch.float32)): + with torch.autocast(device_type=str(device).split(":")[0], dtype=precision, enabled=(precision != torch.float32)): preds = segmentation_model(imgs).cpu().numpy() x_starts = np.clip(np.round(xcoords.numpy() * mpp_reduction_factor).astype(int), 0, width - 1) # clip for starts