Skip to content

Commit cc6b8f7

Browse files
committed
Fix demo CLI
1 parent 0f17450 commit cc6b8f7

3 files changed

Lines changed: 56 additions & 34 deletions

File tree

kineo/demo/offline/demo.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,7 @@ def main(config_file: str, sequence_name: str, video_paths: list[str], batch_siz
7171
gt_annotations={}
7272
)
7373

74-
75-
if __name__ == "__main__":
74+
def cli():
7675
parser = argparse.ArgumentParser()
7776
parser.add_argument("--config-file", type=str, default="configs/demo/offline/nlf_single_person_sam2.yaml")
7877
parser.add_argument("--sequence-name", type=str, default="offline_demo")
@@ -82,11 +81,16 @@ def main(config_file: str, sequence_name: str, video_paths: list[str], batch_siz
8281
parser.add_argument("--use-cache", action="store_true", default=False)
8382
parser.add_argument("video_paths", type=str, nargs="+")
8483
args = parser.parse_args()
85-
config_file = args.config_file
86-
sequence_name = args.sequence_name
87-
video_paths = args.video_paths
88-
batch_size = args.batch_size
89-
target_fps = args.target_fps
90-
shared_intrinsics = args.shared_intrinsics
91-
use_cache = args.use_cache
92-
main(config_file, sequence_name, video_paths, batch_size, target_fps, shared_intrinsics, use_cache)
84+
85+
main(
86+
args.config_file,
87+
args.sequence_name,
88+
args.video_paths,
89+
args.batch_size,
90+
args.target_fps,
91+
args.shared_intrinsics,
92+
args.use_cache,
93+
)
94+
95+
if __name__ == "__main__":
96+
cli()

kineo/demo/online/demo.py

Lines changed: 40 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
torch.backends.cudnn.benchmark = True
2222

2323

24-
def create_views_from_temp_videos(temp_videos, cam_names):
24+
def create_views_from_temp_videos(temp_videos, cam_names, device: torch.device):
2525
views = []
2626
for i, temp_video in enumerate(temp_videos):
2727
video_path = temp_video.name
@@ -38,7 +38,7 @@ def create_views_from_temp_videos(temp_videos, cam_names):
3838
)
3939
return views
4040

41-
def create_live_views(cam_indices: list[int], cam_names: list[str]):
41+
def create_live_views(cam_indices: list[int], cam_names: list[str], device: torch.device):
4242
views = []
4343
for cam_idx, cam_name in zip(cam_indices, cam_names):
4444
views.append(
@@ -82,43 +82,39 @@ def load_camera_calibrations(cam_ids: list[str], calibration_output_root_dir: st
8282

8383
return cam_intrinsics, cam_extrinsics, world_reconstructed_scene
8484

85-
if __name__ == "__main__":
8685

87-
parser = argparse.ArgumentParser()
88-
parser.add_argument("--target-fps", type=int, default=20)
89-
parser.add_argument("--target-res", type=str, default="640x480")
90-
parser.add_argument("--live-viz-config", type=str, default="configs/demo/realtime/realtime_viz.yaml")
91-
parser.add_argument("--skip-calibration", action="store_true")
92-
args = parser.parse_args()
93-
target_fps = int(args.target_fps)
94-
target_res = tuple(int(x) for x in args.target_res.split("x"))
86+
def main(target_fps, target_res, live_viz_config, skip_calibration):
87+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9588

9689
calibration_config_file = "configs/demo/realtime/calibration.yaml"
97-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9890

99-
if not args.skip_calibration:
91+
if not skip_calibration:
10092
print("Loading calibration pipeline")
10193
calibration_pipeline = Pipeline.build_pipeline_from_config(calibration_config_file, device)
10294
print("Pipelines loaded")
10395

10496
app = QtWidgets.QApplication(sys.argv)
105-
recorder = MultiCamRecorder(target_fps=target_fps, target_res=target_res, api_preference=cv2.CAP_ANY)
97+
recorder = MultiCamRecorder(
98+
target_fps=target_fps,
99+
target_res=target_res,
100+
api_preference=cv2.CAP_ANY
101+
)
106102
recorder.show()
107103
app.exec_()
108104

109105
cam_indices = recorder.camera_indices
110106
cam_names = recorder.camera_ids
111107

112-
views = create_views_from_temp_videos(recorder.temp_videos, cam_names)
113-
annotations = calibration_pipeline.run(
108+
views = create_views_from_temp_videos(recorder.temp_videos, cam_names, device=device)
109+
calibration_pipeline.run(
114110
sequence_name="calibration",
115111
views=views,
116112
annotations={},
117113
gt_annotations={},
118114
)
119115

120116
print("Loading live viz pipeline")
121-
live_viz_pipeline = Pipeline.build_pipeline_from_config(args.live_viz_config, device)
117+
live_viz_pipeline = Pipeline.build_pipeline_from_config(live_viz_config, device)
122118
print("Live viz pipeline loaded")
123119

124120
camera_indices = []
@@ -132,13 +128,14 @@ def load_camera_calibrations(cam_ids: list[str], calibration_output_root_dir: st
132128
if len(camera_indices) <= 1:
133129
raise Exception(f"Expected at least 2 cameras, got {len(camera_indices)}")
134130

135-
views = create_live_views(camera_indices, camera_ids)
131+
views = create_live_views(camera_indices, camera_ids, device=device)
136132

137-
# Load calibration data
138133
calibration_output_root_dir = "./outputs/realtime_demo_calibration"
139-
cam_intrinsics, cam_extrinsics, world_reconstructed_scene = load_camera_calibrations(camera_ids, calibration_output_root_dir)
134+
cam_intrinsics, cam_extrinsics, world_reconstructed_scene = load_camera_calibrations(
135+
camera_ids, calibration_output_root_dir
136+
)
140137

141-
_ = live_viz_pipeline.run(
138+
live_viz_pipeline.run(
142139
sequence_name="realtime_viz",
143140
views=views,
144141
annotations={
@@ -147,4 +144,25 @@ def load_camera_calibrations(cam_ids: list[str], calibration_output_root_dir: st
147144
"world_reconstructed_scene": world_reconstructed_scene,
148145
},
149146
gt_annotations={},
150-
)
147+
)
148+
149+
150+
def cli():
151+
parser = argparse.ArgumentParser()
152+
parser.add_argument("--target-fps", type=int, default=20)
153+
parser.add_argument("--target-res", type=str, default="640x480")
154+
parser.add_argument("--live-viz-config", type=str, default="configs/demo/realtime/realtime_viz.yaml")
155+
parser.add_argument("--skip-calibration", action="store_true")
156+
args = parser.parse_args()
157+
158+
target_res = tuple(int(x) for x in args.target_res.split("x"))
159+
160+
main(
161+
target_fps=args.target_fps,
162+
target_res=target_res,
163+
live_viz_config=args.live_viz_config,
164+
skip_calibration=args.skip_calibration,
165+
)
166+
167+
if __name__ == "__main__":
168+
cli()

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ dev = [
5353
]
5454

5555
[project.scripts]
56-
kineo-offline = "kineo.demo.offline.demo:main"
57-
kineo-online = "kineo.demo.online.demo:main"
56+
kineo-offline = "kineo.demo.offline.demo:cli"
57+
kineo-online = "kineo.demo.online.demo:cli"
5858

5959
[tool.setuptools.packages.find]
6060
where = ["."]

0 commit comments

Comments
 (0)