Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

# Standard imports
import os
import gdown

# Third-party imports
from ament_index_python.packages import get_package_share_directory
Expand All @@ -33,7 +34,8 @@ class SPANetContext(ContextAdapter):

def __init__(
self,
checkpoint: str,
checkpoint_url: str,
checkpoint_path: str,
n_features: int = 2048,
gpu_index: int = 0,
) -> None:
Expand Down Expand Up @@ -63,8 +65,23 @@ def __init__(

# Load Checkpoint
ckpt_file = os.path.join(
get_package_share_directory("ada_feeding_action_select"), "data", checkpoint
get_package_share_directory("ada_feeding_action_select"),
"data",
checkpoint_path,
)
if not os.path.exists(ckpt_file):
logger.info(
f"Checkpoint file not found at {ckpt_file}. Downloading from {checkpoint_url}..."
)

try:
gdown.download(checkpoint_url, ckpt_file, quiet=False)
logger.info(f"Checkpoint file downloaded successfully to {ckpt_file}")
except Exception as e:
raise RuntimeError(f"Error downloading checkpoint: {e}")
else:
logger.info(f"Checkpoint file found at {ckpt_file}. Loading...")

ckpt = torch.load(ckpt_file, map_location=self.device)
self.spanet.load_state_dict(ckpt["net"])
self.spanet.eval()
Expand Down
6 changes: 4 additions & 2 deletions ada_feeding_action_select/config/policies.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ ada_feeding_action_select:

context_class: ada_feeding_action_select.adapters.SPANetContext
context_kws:
- checkpoint # Relative to share data directory
- checkpoint_url
- checkpoint_path # Relative to share data directory
context_kwargs:
checkpoint: checkpoint/adapter/food_spanet_all_rgb_wall_ckpt_best.pth
checkpoint_url: "https://drive.google.com/uc?id=1BsFe3xyex2_e7MWQEA3Q4oZEzrjzLtiH&export=download" # Direct download link
checkpoint_path: "checkpoint/adapter/food_spanet_all_rgb_wall_ckpt_best.pth" # Local path (relative to share dir)

#context_class: ada_feeding_action_select.adapters.ColorContext

Expand Down