Skip to content

Commit d6faa72

Browse files
Inference bug fixes
1 parent b68cab6 commit d6faa72

File tree

4 files changed

+470
-21
lines changed

4 files changed

+470
-21
lines changed

notebooks/bayes3d_paper/tester.ipynb

+33-14
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
"name": "stderr",
5050
"output_type": "stream",
5151
"text": [
52-
"100%|██████████| 49/49 [00:03<00:00, 13.47it/s]\n",
52+
"100%|██████████| 49/49 [00:03<00:00, 13.41it/s]\n",
5353
"/home/georgematheos/b3d/.pixi/envs/gpu/lib/python3.12/site-packages/torch/utils/cpp_extension.py:1967: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. \n",
5454
"If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].\n",
5555
" warnings.warn(\n"
@@ -440,7 +440,7 @@
440440
"metadata": {},
441441
"outputs": [],
442442
"source": [
443-
"b3d.rr_init(\"inference_given_gtpose_4\")"
443+
"b3d.rr_init(\"inference_given_gtpose_5\")"
444444
]
445445
},
446446
{
@@ -454,23 +454,23 @@
454454
},
455455
{
456456
"cell_type": "code",
457-
"execution_count": 22,
457+
"execution_count": 34,
458458
"metadata": {},
459459
"outputs": [],
460460
"source": [
461461
"inference_hyperparams = i.InferenceHyperparams(\n",
462462
" n_poses=1500,\n",
463-
" do_stochastic_color_proposals=False,\n",
463+
" do_stochastic_color_proposals=True,\n",
464464
" pose_proposal_std=0.04,\n",
465465
" pose_proposal_conc=1000.,\n",
466466
" prev_color_proposal_laplace_scale=.04,\n",
467-
" obs_color_proposal_laplace_scale=.01,\n",
467+
" obs_color_proposal_laplace_scale=.02,\n",
468468
")"
469469
]
470470
},
471471
{
472472
"cell_type": "code",
473-
"execution_count": 23,
473+
"execution_count": 35,
474474
"metadata": {},
475475
"outputs": [
476476
{
@@ -484,9 +484,7 @@
484484
"name": "stderr",
485485
"output_type": "stream",
486486
"text": [
487-
"/home/georgematheos/b3d/src/b3d/modeling_utils.py:86: UserWarning: RenormalizedLaplace sampling is currently not implemented perfectly.\n",
488-
" warnings.warn(\n",
489-
"100%|██████████| 30/30 [00:32<00:00, 1.09s/it]\n"
487+
"100%|██████████| 30/30 [00:09<00:00, 3.02it/s]\n"
490488
]
491489
}
492490
],
@@ -539,11 +537,11 @@
539537
},
540538
{
541539
"cell_type": "code",
542-
"execution_count": 25,
540+
"execution_count": 26,
543541
"metadata": {},
544542
"outputs": [],
545543
"source": [
546-
"b3d.rr_init(\"real_inference2\")"
544+
"b3d.rr_init(\"real_inference_3\")"
547545
]
548546
},
549547
{
@@ -562,7 +560,11 @@
562560
"name": "stderr",
563561
"output_type": "stream",
564562
"text": [
565-
"100%|██████████| 20/20 [01:51<00:00, 5.57s/it]\n"
563+
"/home/georgematheos/b3d/src/b3d/modeling_utils.py:86: UserWarning: RenormalizedLaplace sampling is currently not implemented perfectly.\n",
564+
" warnings.warn(\n",
565+
"/home/georgematheos/b3d/src/b3d/modeling_utils.py:86: UserWarning: RenormalizedLaplace sampling is currently not implemented perfectly.\n",
566+
" warnings.warn(\n",
567+
"100%|██████████| 20/20 [02:35<00:00, 7.75s/it]\n"
566568
]
567569
}
568570
],
@@ -593,7 +595,7 @@
593595
},
594596
{
595597
"cell_type": "code",
596-
"execution_count": 28,
598+
"execution_count": 29,
597599
"metadata": {},
598600
"outputs": [
599601
{
@@ -607,12 +609,29 @@
607609
"name": "stderr",
608610
"output_type": "stream",
609611
"text": [
610-
"100%|██████████| 29/29 [02:40<00:00, 5.55s/it]\n"
612+
" 14%|█▍ | 4/29 [00:40<04:16, 10.25s/it]\n"
613+
]
614+
},
615+
{
616+
"ename": "KeyboardInterrupt",
617+
"evalue": "",
618+
"output_type": "error",
619+
"traceback": [
620+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
621+
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
622+
"Cell \u001b[0;32mIn[29], line 6\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m T \u001b[38;5;129;01min\u001b[39;00m tqdm(\u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m20\u001b[39m, \u001b[38;5;28mlen\u001b[39m(all_data))):\n\u001b[1;32m 5\u001b[0m key \u001b[38;5;241m=\u001b[39m b3d\u001b[38;5;241m.\u001b[39msplit_key(key)\n\u001b[0;32m----> 6\u001b[0m trace \u001b[38;5;241m=\u001b[39m \u001b[43mi\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minference_step_c2f\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 7\u001b[0m \u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 8\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# number of sequential iterations of the parallel pose proposal to consider\u001b[39;49;00m\n\u001b[1;32m 9\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m5000\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# number of poses to propose in parallel\u001b[39;49;00m\n\u001b[1;32m 10\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;66;43;03m# So the total number of poses considered at each step of C2F is 5000 * 1\u001b[39;49;00m\n\u001b[1;32m 11\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrace\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mall_data\u001b[49m\u001b[43m[\u001b[49m\u001b[43mT\u001b[49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mrgbd\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 12\u001b[0m \u001b[43m \u001b[49m\u001b[43mprev_color_proposal_laplace_scale\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minference_hyperparams\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprev_color_proposal_laplace_scale\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 13\u001b[0m \u001b[43m \u001b[49m\u001b[43mobs_color_proposal_laplace_scale\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minference_hyperparams\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mobs_color_proposal_laplace_scale\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 14\u001b[0m \u001b[43m \u001b[49m\u001b[43mdo_stochastic_color_proposals\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\n\u001b[1;32m 15\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 16\u001b[0m b3d\u001b[38;5;241m.\u001b[39mchisight\u001b[38;5;241m.\u001b[39mgen3d\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39mviz_trace(\n\u001b[1;32m 17\u001b[0m trace,\n\u001b[1;32m 18\u001b[0m T,\n\u001b[1;32m 19\u001b[0m ground_truth_vertices\u001b[38;5;241m=\u001b[39mmeshes[OBJECT_INDEX]\u001b[38;5;241m.\u001b[39mvertices,\n\u001b[1;32m 20\u001b[0m ground_truth_pose\u001b[38;5;241m=\u001b[39mall_data[T][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcamera_pose\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39minv() \u001b[38;5;241m@\u001b[39m all_data[T][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mobject_poses\u001b[39m\u001b[38;5;124m\"\u001b[39m][OBJECT_INDEX]\n\u001b[1;32m 21\u001b[0m )\n",
623+
"File \u001b[0;32m~/b3d/src/b3d/chisight/gen3d/inference.py:100\u001b[0m, in \u001b[0;36minference_step_c2f\u001b[0;34m(key, n_seq, n_poses_per_sequential_step, old_trace, observed_rgbd, *args, **kwargs)\u001b[0m\n\u001b[1;32m 98\u001b[0m k1, k2 \u001b[38;5;241m=\u001b[39m split(key)\n\u001b[1;32m 99\u001b[0m trace \u001b[38;5;241m=\u001b[39m advance_time(k1, old_trace, observed_rgbd)\n\u001b[0;32m--> 100\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minfer_latents_c2f\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 101\u001b[0m \u001b[43m \u001b[49m\u001b[43mk2\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_seq\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_poses_per_sequential_step\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrace\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m 102\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
624+
"File \u001b[0;32m~/b3d/src/b3d/chisight/gen3d/inference.py:121\u001b[0m, in \u001b[0;36minfer_latents_c2f\u001b[0;34m(key, n_seq, n_poses_per_sequential_step, trace, pose_proposal_std_conc_seq, **inference_hyperparam_kwargs)\u001b[0m\n\u001b[1;32m 114\u001b[0m inference_hyperparams \u001b[38;5;241m=\u001b[39m InferenceHyperparams(\n\u001b[1;32m 115\u001b[0m n_poses\u001b[38;5;241m=\u001b[39mn_poses_per_sequential_step,\n\u001b[1;32m 116\u001b[0m pose_proposal_std\u001b[38;5;241m=\u001b[39mstd,\n\u001b[1;32m 117\u001b[0m pose_proposal_conc\u001b[38;5;241m=\u001b[39mconc,\n\u001b[1;32m 118\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39minference_hyperparam_kwargs,\n\u001b[1;32m 119\u001b[0m )\n\u001b[1;32m 120\u001b[0m key, _ \u001b[38;5;241m=\u001b[39m split(key)\n\u001b[0;32m--> 121\u001b[0m trace, _ \u001b[38;5;241m=\u001b[39m \u001b[43minfer_latents_using_sequential_proposals\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 122\u001b[0m \u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_seq\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrace\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minference_hyperparams\u001b[49m\n\u001b[1;32m 123\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 125\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m trace\n",
625+
"File \u001b[0;32m~/b3d/src/b3d/chisight/gen3d/inference.py:153\u001b[0m, in \u001b[0;36minfer_latents_using_sequential_proposals\u001b[0;34m(key, n_seq, trace, inference_hyperparams)\u001b[0m\n\u001b[1;32m 151\u001b[0m k1, k2 \u001b[38;5;241m=\u001b[39m split(key)\n\u001b[1;32m 152\u001b[0m ks \u001b[38;5;241m=\u001b[39m split(k1, n_seq)\n\u001b[0;32m--> 153\u001b[0m weights \u001b[38;5;241m=\u001b[39m [\u001b[43mget_weight\u001b[49m\u001b[43m(\u001b[49m\u001b[43mk\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m k \u001b[38;5;129;01min\u001b[39;00m ks]\n\u001b[1;32m 155\u001b[0m normalized_logps \u001b[38;5;241m=\u001b[39m jax\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mlog_softmax(jnp\u001b[38;5;241m.\u001b[39marray(weights))\n\u001b[1;32m 156\u001b[0m chosen_idx \u001b[38;5;241m=\u001b[39m jax\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39mcategorical(k2, normalized_logps)\n",
626+
"File \u001b[0;32m~/b3d/src/b3d/chisight/gen3d/inference.py:149\u001b[0m, in \u001b[0;36minfer_latents_using_sequential_proposals.<locals>.get_weight\u001b[0;34m(key)\u001b[0m\n\u001b[1;32m 148\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_weight\u001b[39m(key):\n\u001b[0;32m--> 149\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minfer_latents\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mshared_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mget_trace\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mget_metadata\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m[\u001b[38;5;241m0\u001b[39m]\n",
627+
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
611628
]
612629
}
613630
],
614631
"source": [
615632
"## Finish the run\n",
633+
"key = jax.random.PRNGKey(1234)\n",
634+
"trace = trace_20\n",
616635
"for T in tqdm(range(20, len(all_data))):\n",
617636
" key = b3d.split_key(key)\n",
618637
" trace = i.inference_step_c2f(\n",

notebooks/bayes3d_paper/tester2.ipynb

+432
Large diffs are not rendered by default.

src/b3d/chisight/gen3d/image_kernel.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -131,10 +131,11 @@ def logpdf(
131131
)
132132
# Points that don't hit the camera plane should not contribute to the score.
133133
scores = jnp.where(is_unexplained(observed_rgbd_per_point), 0.0, scores)
134+
score_for_pixels_with_points = scores.sum()
134135

135-
# TODO: add scoring for pixels that are not explained by the latent points
136+
# TODO: add scores for pixels that don't get a point
136137

137-
return scores.sum()
138+
return score_for_pixels_with_points
138139

139140
def get_rgbd_vertex_kernel(self) -> PixelRGBDDistribution:
140141
# Note: The distributions were originally defined for per-pixel computation,

src/b3d/chisight/gen3d/inference.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,7 @@
1414
propose_other_latents_given_pose,
1515
propose_pose,
1616
)
17-
from b3d.chisight.gen3d.model import (
18-
get_hypers,
19-
get_prev_state,
20-
)
17+
from b3d.chisight.gen3d.model import get_hypers, get_new_state
2118

2219

2320
@Pytree.dataclass
@@ -60,7 +57,7 @@ def advance_time(key, trace, observed_rgbd):
6057
U.g(
6158
(
6259
Diff.no_change(get_hypers(trace)),
63-
Diff.unknown_change(get_prev_state(trace)),
60+
Diff.unknown_change(get_new_state(trace)),
6461
),
6562
C.kw(rgbd=observed_rgbd),
6663
),

0 commit comments

Comments
 (0)