Skip to content

Commit ff80951

Browse files
torch_vision.py: use set device type (#621)
* use device enum to transfer model to cuda * Update libs/infinity_emb/infinity_emb/transformer/vision/torch_vision.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --------- Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
1 parent f84d8e2 commit ff80951

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

libs/infinity_emb/infinity_emb/transformer/vision/torch_vision.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def __init__(self, *, engine_args: "EngineArgs"):
106106
assert hasattr(
107107
self.model, "get_image_features"
108108
), f"AutoModel of {engine_args.model_name_or_path} does not have get_image_features method"
109-
if torch.cuda.is_available():
109+
if device == Device.cuda:
110110
self.model = self.model.cuda()
111111
if engine_args.dtype in (Dtype.float16, Dtype.auto):
112112
self.model = self.model.half()

0 commit comments

Comments
 (0)