-
Notifications
You must be signed in to change notification settings - Fork 114
fix: resolve path errors and hard-coded dependencies in semantic segmentation example #371
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -162,14 +162,14 @@ def confidence(self, input_output): | |
| return sum_3 | ||
|
|
||
| def sam_predict_ssa(self, image_name, pred): | ||
| with open('/home/hsj/ianvs/project/cache.pickle', 'rb') as file: | ||
| with open('/ianvs/project/cache.pickle', 'rb') as file: | ||
| cache = pickle.load(file) | ||
| img = mmcv.imread(image_name) | ||
| if image_name in cache.keys(): | ||
| mask = cache[image_name] | ||
| print("load cache") | ||
| else: | ||
| sam = sam_model_registry["vit_h"](checkpoint="/home/hsj/ianvs/project/segment-anything/sam_vit_h_4b8939.pth").to('cuda:1') | ||
| sam = sam_model_registry["vit_h"](checkpoint="/ianvs/project/segment-anything/sam_vit_h_4b8939.pth").to('cuda:1') | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The SAM checkpoint path ( sam_checkpoint = getattr(self.args, "sam_checkpoint", "/ianvs/project/segment-anything/sam_vit_h_4b8939.pth")
device = f"cuda:{self.args.gpu_ids}" if self.args.cuda else "cpu"
sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint).to(device) |
||
| mask_branch_model = SamAutomaticMaskGenerator( | ||
| model=sam, | ||
| #points_per_side=64, | ||
|
|
@@ -184,7 +184,7 @@ def sam_predict_ssa(self, image_name, pred): | |
| print('[Model loaded] Mask branch (SAM) is loaded.') | ||
| mask = mask_branch_model.generate(img) | ||
| cache[image_name] = mask | ||
| with open('/home/hsj/ianvs/project/cache.pickle', 'wb') as file: | ||
| with open('/ianvs/project/cache.pickle', 'wb') as file: | ||
| pickle.dump(cache, file) | ||
| print("save cache") | ||
|
|
||
|
|
@@ -234,14 +234,14 @@ def sam_predict_ssa(self, image_name, pred): | |
| return semantc_mask, mask | ||
|
|
||
| def sam_predict(self, image_name, pred): | ||
| with open('/home/hsj/ianvs/project/cache.pickle', 'rb') as file: | ||
| with open('/ianvs/project/cache.pickle', 'rb') as file: | ||
| cache = pickle.load(file) | ||
|
Comment on lines
+237
to
238
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The cache file path cache_path = getattr(self.args, "cache_path", "/ianvs/project/cache.pickle")
try:
with open(cache_path, 'rb') as file:
cache = pickle.load(file)
except (FileNotFoundError, EOFError):
cache = {} |
||
| img = mmcv.imread(image_name) | ||
| if image_name in cache.keys(): | ||
| mask = cache[image_name] | ||
| print("load cache") | ||
| else: | ||
| sam = sam_model_registry["vit_h"](checkpoint="/home/hsj/ianvs/project/segment-anything/sam_vit_h_4b8939.pth").to('cuda:1') | ||
| sam = sam_model_registry["vit_h"](checkpoint="/ianvs/project/segment-anything/sam_vit_h_4b8939.pth").to('cuda:1') | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The SAM checkpoint path and CUDA device are hardcoded here as well, similar to the sam_checkpoint = getattr(self.args, "sam_checkpoint", "/ianvs/project/segment-anything/sam_vit_h_4b8939.pth")
device = f"cuda:{self.args.gpu_ids}" if self.args.cuda else "cpu"
sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint).to(device) |
||
| mask_branch_model = SamAutomaticMaskGenerator( | ||
| model=sam, | ||
| #points_per_side=64, | ||
|
|
@@ -256,7 +256,7 @@ def sam_predict(self, image_name, pred): | |
| print('[Model loaded] Mask branch (SAM) is loaded.') | ||
| mask = mask_branch_model.generate(img) | ||
| cache[image_name] = mask | ||
| with open('/home/hsj/ianvs/project/cache.pickle', 'wb') as file: | ||
| with open('/ianvs/project/cache.pickle', 'wb') as file: | ||
| pickle.dump(cache, file) | ||
| print("save cache") | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The cache file path
/ianvs/project/cache.pickleis hardcoded. This makes the code less flexible and harder to configure for different environments. It would be better to define this path as a constant or pass it as a parameter. Furthermore, the current implementation will raise aFileNotFoundErrorif the cache file doesn't exist (e.g., on a fresh run). It's better to handle this gracefully, for example with atry...exceptblock: