Skip to content

Commit 483acd7

Browse files
committed
fix: Raise an error when the model weights are not found in inference; apply ruff style updates
1 parent 9dd7e04 commit 483acd7

File tree

14 files changed

+611
-367
lines changed

14 files changed

+611
-367
lines changed

config.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
11
data_config = {
2-
"gaze360":
3-
{
4-
"bins": 90,
5-
"binwidth": 4,
6-
"angle": 180 # angle range
7-
},
8-
"mpiigaze":
9-
{
10-
"bins": 28,
11-
"binwidth": 3,
12-
"angle": 42 # angle range
13-
}
14-
15-
}
2+
"gaze360": {
3+
"bins": 90,
4+
"binwidth": 4,
5+
"angle": 180, # angle range
6+
},
7+
"mpiigaze": {
8+
"bins": 28,
9+
"binwidth": 3,
10+
"angle": 42, # angle range
11+
},
12+
}

evaluate.py

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,25 +13,40 @@
1313
from utils.helpers import angular_error, gaze_to_3d, get_dataloader, get_model
1414

1515
import warnings
16+
1617
warnings.filterwarnings("ignore")
1718
# Setup logging
18-
logging.basicConfig(level=logging.INFO, format='%(message)s')
19+
logging.basicConfig(level=logging.INFO, format="%(message)s")
1920

2021

2122
def parse_args():
2223
"""Parse input arguments."""
2324
parser = argparse.ArgumentParser(description="Gaze estimation evaluation")
24-
parser.add_argument("--data", type=str, default="data/Gaze360", help="Directory path for gaze images.")
25-
parser.add_argument("--dataset", type=str, default="gaze360", help="Dataset name, available `gaze360`, `mpiigaze`")
26-
parser.add_argument("--weight", type=str, default="", help="Path to model weight for evaluation.")
25+
parser.add_argument(
26+
"--data",
27+
type=str,
28+
default="data/Gaze360",
29+
help="Directory path for gaze images.",
30+
)
31+
parser.add_argument(
32+
"--dataset",
33+
type=str,
34+
default="gaze360",
35+
help="Dataset name, available `gaze360`, `mpiigaze`",
36+
)
37+
parser.add_argument(
38+
"--weight", type=str, default="", help="Path to model weight for evaluation."
39+
)
2740
parser.add_argument("--batch-size", type=int, default=64, help="Batch size.")
2841
parser.add_argument(
2942
"--arch",
3043
type=str,
3144
default="resnet18",
32-
help="Network architecture, currently available: resnet18/34/50, mobilenetv2, mobileone_s0-s4."
45+
help="Network architecture, currently available: resnet18/34/50, mobilenetv2, mobileone_s0-s4.",
46+
)
47+
parser.add_argument(
48+
"--num-workers", type=int, default=8, help="Number of workers for data loading."
3349
)
34-
parser.add_argument("--num-workers", type=int, default=8, help="Number of workers for data loading.")
3550

3651
args = parser.parse_args()
3752

@@ -42,7 +57,9 @@ def parse_args():
4257
args.binwidth = dataset_config["binwidth"]
4358
args.angle = dataset_config["angle"]
4459
else:
45-
raise ValueError(f"Unknown dataset: {args.dataset}. Available options: {list(data_config.keys())}")
60+
raise ValueError(
61+
f"Unknown dataset: {args.dataset}. Available options: {list(data_config.keys())}"
62+
)
4663

4764
return args
4865

@@ -63,7 +80,9 @@ def evaluate(params, model, data_loader, idx_tensor, device):
6380
average_error = 0
6481
total_samples = 0
6582

66-
for images, labels_gaze, regression_labels_gaze, _ in tqdm(data_loader, total=len(data_loader)):
83+
for images, labels_gaze, regression_labels_gaze, _ in tqdm(
84+
data_loader, total=len(data_loader)
85+
):
6786
total_samples += regression_labels_gaze.size(0)
6887
images = images.to(device)
6988

@@ -79,8 +98,12 @@ def evaluate(params, model, data_loader, idx_tensor, device):
7998
yaw_predicted = F.softmax(yaw, dim=1)
8099

81100
# Map the expected bin index to an angle in degrees: gaze360 bins (0 to 90) -> (-180 to 180), mpiigaze bins (0 to 28) -> (-42 to 42)
82-
pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1) * params.binwidth - params.angle
83-
yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1) * params.binwidth - params.angle
101+
pitch_predicted = (
102+
torch.sum(pitch_predicted * idx_tensor, 1) * params.binwidth - params.angle
103+
)
104+
yaw_predicted = (
105+
torch.sum(yaw_predicted * idx_tensor, 1) * params.binwidth - params.angle
106+
)
84107

85108
pitch_predicted = np.radians(pitch_predicted.cpu())
86109
yaw_predicted = np.radians(yaw_predicted.cpu())
@@ -91,7 +114,7 @@ def evaluate(params, model, data_loader, idx_tensor, device):
91114
logging.info(
92115
f"Dataset: {params.dataset} | "
93116
f"Total Number of Samples: {total_samples} | "
94-
f"Mean Angular Error: {average_error/total_samples}"
117+
f"Mean Angular Error: {average_error / total_samples}"
95118
)
96119

97120

@@ -104,7 +127,9 @@ def main():
104127
model = get_model(params.arch, params.bins, inference_mode=True)
105128

106129
if os.path.exists(params.weight):
107-
model.load_state_dict(torch.load(params.weight, map_location=device, weights_only=True))
130+
model.load_state_dict(
131+
torch.load(params.weight, map_location=device, weights_only=True)
132+
)
108133
else:
109134
raise ValueError(f"Model weight not found at {params.weight}")
110135

@@ -117,5 +142,5 @@ def main():
117142
evaluate(params, model, test_loader, idx_tensor, device)
118143

119144

120-
if __name__ == '__main__':
145+
if __name__ == "__main__":
121146
main()

inference.py

Lines changed: 60 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,23 +14,41 @@
1414
from uniface import RetinaFace
1515

1616
warnings.filterwarnings("ignore")
17-
logging.basicConfig(level=logging.INFO, format='%(message)s')
17+
logging.basicConfig(level=logging.INFO, format="%(message)s")
1818

1919

2020
def parse_args():
2121
parser = argparse.ArgumentParser(description="Gaze estimation inference")
22-
parser.add_argument("--model", type=str, default="resnet34", help="Model name, default `resnet18`")
22+
parser.add_argument(
23+
"--model", type=str, default="resnet34", help="Model name, default `resnet18`"
24+
)
2325
parser.add_argument(
2426
"--weight",
2527
type=str,
2628
default="resnet34.pt",
27-
help="Path to gaze esimation model weights"
29+
help="Path to gaze esimation model weights",
30+
)
31+
parser.add_argument(
32+
"--view",
33+
action="store_true",
34+
default=True,
35+
help="Display the inference results",
36+
)
37+
parser.add_argument(
38+
"--source",
39+
type=str,
40+
default="assets/in_video.mp4",
41+
help="Path to source video file or camera index",
42+
)
43+
parser.add_argument(
44+
"--output", type=str, default="output.mp4", help="Path to save output file"
45+
)
46+
parser.add_argument(
47+
"--dataset",
48+
type=str,
49+
default="gaze360",
50+
help="Dataset name to get dataset related configs",
2851
)
29-
parser.add_argument("--view", action="store_true", default=True, help="Display the inference results")
30-
parser.add_argument("--source", type=str, default="assets/in_video.mp4",
31-
help="Path to source video file or camera index")
32-
parser.add_argument("--output", type=str, default="output.mp4", help="Path to save output file")
33-
parser.add_argument("--dataset", type=str, default="gaze360", help="Dataset name to get dataset related configs")
3452
args = parser.parse_args()
3553

3654
# Override default values based on selected dataset
@@ -40,19 +58,23 @@ def parse_args():
4058
args.binwidth = dataset_config["binwidth"]
4159
args.angle = dataset_config["angle"]
4260
else:
43-
raise ValueError(f"Unknown dataset: {args.dataset}. Available options: {list(data_config.keys())}")
61+
raise ValueError(
62+
f"Unknown dataset: {args.dataset}. Available options: {list(data_config.keys())}"
63+
)
4464

4565
return args
4666

4767

4868
def pre_process(image):
4969
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
50-
transform = transforms.Compose([
51-
transforms.ToPILImage(),
52-
transforms.Resize(448),
53-
transforms.ToTensor(),
54-
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
55-
])
70+
transform = transforms.Compose(
71+
[
72+
transforms.ToPILImage(),
73+
transforms.Resize(448),
74+
transforms.ToTensor(),
75+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
76+
]
77+
)
5678

5779
image = transform(image)
5880
image_batch = image.unsqueeze(0)
@@ -72,13 +94,16 @@ def main(params):
7294
gaze_detector.load_state_dict(state_dict)
7395
logging.info("Gaze Estimation model weights loaded.")
7496
except Exception as e:
75-
logging.info(f"Exception occured while loading pre-trained weights of gaze estimation model. Exception: {e}")
97+
logging.info(
98+
f"Exception occured while loading pre-trained weights of gaze estimation model. Exception: {e}"
99+
)
100+
raise FileNotFoundError(f"Model weights not found at {params.weight}") from e
76101

77102
gaze_detector.to(device)
78103
gaze_detector.eval()
79104

80105
video_source = params.source
81-
if video_source.isdigit() or video_source == '0':
106+
if video_source.isdigit() or video_source == "0":
82107
cap = cv2.VideoCapture(int(video_source))
83108
else:
84109
cap = cv2.VideoCapture(video_source)
@@ -87,7 +112,9 @@ def main(params):
87112
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
88113
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
89114
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
90-
out = cv2.VideoWriter(params.output, fourcc, cap.get(cv2.CAP_PROP_FPS), (width, height))
115+
out = cv2.VideoWriter(
116+
params.output, fourcc, cap.get(cv2.CAP_PROP_FPS), (width, height)
117+
)
91118

92119
if not cap.isOpened():
93120
raise IOError("Cannot open webcam")
@@ -102,7 +129,7 @@ def main(params):
102129

103130
faces = face_detector.detect(frame)
104131
for face in faces:
105-
bbox = face['bbox']
132+
bbox = face["bbox"]
106133
x_min, y_min, x_max, y_max = map(int, bbox[:4])
107134

108135
image = frame[y_min:y_max, x_min:x_max]
@@ -111,11 +138,20 @@ def main(params):
111138

112139
pitch, yaw = gaze_detector(image)
113140

114-
pitch_predicted, yaw_predicted = F.softmax(pitch, dim=1), F.softmax(yaw, dim=1)
141+
pitch_predicted, yaw_predicted = (
142+
F.softmax(pitch, dim=1),
143+
F.softmax(yaw, dim=1),
144+
)
115145

116146
# Map the expected bin index to an angle in degrees: gaze360 bins (0 to 90) -> (-180 to 180), mpiigaze bins (0 to 28) -> (-42 to 42)
117-
pitch_predicted = torch.sum(pitch_predicted * idx_tensor, dim=1) * params.binwidth - params.angle
118-
yaw_predicted = torch.sum(yaw_predicted * idx_tensor, dim=1) * params.binwidth - params.angle
147+
pitch_predicted = (
148+
torch.sum(pitch_predicted * idx_tensor, dim=1) * params.binwidth
149+
- params.angle
150+
)
151+
yaw_predicted = (
152+
torch.sum(yaw_predicted * idx_tensor, dim=1) * params.binwidth
153+
- params.angle
154+
)
119155

120156
# Degrees to Radians
121157
pitch_predicted = np.radians(pitch_predicted.cpu())
@@ -128,8 +164,8 @@ def main(params):
128164
out.write(frame)
129165

130166
if params.view:
131-
cv2.imshow('Demo', frame)
132-
if cv2.waitKey(1) & 0xFF == ord('q'):
167+
cv2.imshow("Demo", frame)
168+
if cv2.waitKey(1) & 0xFF == ord("q"):
133169
break
134170

135171
cap.release()

0 commit comments

Comments
 (0)