Skip to content

Commit a9cf05d

Browse files
committed
Refactor get_weights_argument
1 parent a73266f commit a9cf05d

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

Diff for: python/cog/predictor.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,6 @@ def get_weights_argument( # pylint: disable=too-many-return-statements
102102
weights_type = get_weights_type(predictor.setup)
103103
if weights_type is None:
104104
return None
105-
weights_url = os.environ.get("COG_WEIGHTS")
106-
weights_path = "weights" # this is the source of a bug isn't it?
107105

108106
# TODO: Cog{File,Path}.validate(...) methods accept either "real"
109107
# paths/files or URLs to those things. In future we can probably tidy this
@@ -113,6 +111,7 @@ def get_weights_argument( # pylint: disable=too-many-return-statements
113111
# this is a breaking change
114112
# previously, CogPath wouldn't be converted in setup(); now it is
115113
# essentially everyone needs to switch from Path to str (or a new URL type)
114+
weights_url = os.environ.get("COG_WEIGHTS")
116115
if weights_url:
117116
if weights_type == CogFile:
118117
return cast(CogFile, CogFile.validate(weights_url))
@@ -122,11 +121,12 @@ def get_weights_argument( # pylint: disable=too-many-return-statements
122121
# allow people to download weights themselves
123122
if weights_type is str:
124123
return weights_url
125-
else:
126-
raise ValueError(
127-
f"Predictor.setup() has an argument 'weights' of type {weights_type}, but only File, Path and str are supported"
128-
)
129-
elif os.path.exists(weights_path):
124+
raise ValueError(
125+
f"Predictor.setup() has an argument 'weights' of type {weights_type}, but only File, Path and str are supported"
126+
)
127+
128+
weights_path = "weights" # this is the source of a bug isn't it?
129+
if os.path.exists(weights_path):
130130
if weights_type == CogFile:
131131
return cast(CogFile, open(weights_path, "rb"))
132132
if weights_type == CogPath:

0 commit comments

Comments
 (0)