Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions speciesnet/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ class SpeciesNetClassifier:
MAX_CROP_SIZE = 400

def __init__(
self, model_name: str, target_species_txt: Optional[str] = None
self,
model_name: str,
target_species_txt: Optional[str] = None,
device: Optional[str] = None,
) -> None:
"""Loads the classifier resources.

Expand All @@ -56,19 +59,26 @@ def __init__(
String value identifying the model to be loaded. It can be a Kaggle
identifier (starting with `kaggle:`), a HuggingFace identifier (starting
with `hf:`) or a local folder to load the model from.
device:
Specific device identifier, e.g. "cpu" or "cuda". If None, "cuda"
and "mps" will be used if available.
"""

start_time = time.time()

self.model_info = ModelInfo(model_name)

# Select the best device available.
if torch.cuda.is_available():
self.device = "cuda"
elif torch.backends.mps.is_available():
self.device = "mps"
if device is not None:
print(f"Using caller-supplied device {device}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you use logging.info() instead of print(), for consistency across the repo? logging.info() writes to stderr by default (via logging's last-resort handler, or whatever handler the repo configures), so the message will still be visible on the console.

self.device = device
else:
self.device = "cpu"
if torch.cuda.is_available():
self.device = "cuda"
elif torch.backends.mps.is_available():
self.device = "mps"
else:
self.device = "cpu"

# Load the model.
self.model = torch.load(
Expand Down
Loading