|
49 | 49 | "name": "stderr",
|
50 | 50 | "output_type": "stream",
|
51 | 51 | "text": [
|
52 | | - "100%|██████████| 49/49 [00:03<00:00, 13.47it/s]\n", |
| 52 | + "100%|██████████| 49/49 [00:03<00:00, 13.41it/s]\n", |
53 | 53 | "/home/georgematheos/b3d/.pixi/envs/gpu/lib/python3.12/site-packages/torch/utils/cpp_extension.py:1967: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. \n",
|
54 | 54 | "If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].\n",
|
55 | 55 | " warnings.warn(\n"
|
|
440 | 440 | "metadata": {},
|
441 | 441 | "outputs": [],
|
442 | 442 | "source": [
|
443 | | - "b3d.rr_init(\"inference_given_gtpose_4\")" |
| 443 | + "b3d.rr_init(\"inference_given_gtpose_5\")" |
444 | 444 | ]
|
445 | 445 | },
|
446 | 446 | {
|
|
454 | 454 | },
|
455 | 455 | {
|
456 | 456 | "cell_type": "code",
|
457 | | - "execution_count": 22, |
| 457 | + "execution_count": 34, |
458 | 458 | "metadata": {},
|
459 | 459 | "outputs": [],
|
460 | 460 | "source": [
|
461 | 461 | "inference_hyperparams = i.InferenceHyperparams(\n",
|
462 | 462 | " n_poses=1500,\n",
|
463 | | - "    do_stochastic_color_proposals=False,\n", |
| 463 | + " do_stochastic_color_proposals=True,\n", |
464 | 464 | " pose_proposal_std=0.04,\n",
|
465 | 465 | " pose_proposal_conc=1000.,\n",
|
466 | 466 | " prev_color_proposal_laplace_scale=.04,\n",
|
467 | | - "    obs_color_proposal_laplace_scale=.01,\n", |
| 467 | + " obs_color_proposal_laplace_scale=.02,\n", |
468 | 468 | ")"
|
469 | 469 | ]
|
470 | 470 | },
|
471 | 471 | {
|
472 | 472 | "cell_type": "code",
|
473 | | - "execution_count": 23, |
| 473 | + "execution_count": 35, |
474 | 474 | "metadata": {},
|
475 | 475 | "outputs": [
|
476 | 476 | {
|
|
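For readability, the updated inference-hyperparameter cell from this commit is reproduced below as plain Python rather than notebook JSON, so the two parameter changes are easier to see. The alias `i` is assumed to refer to `b3d.chisight.gen3d.inference`, based on the traceback shown later in this diff; treat this as a sketch of the changed cell, not a verified excerpt.

import b3d.chisight.gen3d.inference as i  # assumed alias for the gen3d inference module

inference_hyperparams = i.InferenceHyperparams(
    n_poses=1500,
    do_stochastic_color_proposals=True,    # changed from False in this commit
    pose_proposal_std=0.04,
    pose_proposal_conc=1000.,
    prev_color_proposal_laplace_scale=.04,
    obs_color_proposal_laplace_scale=.02,  # changed from .01 in this commit
)
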
484 | 484 | "name": "stderr",
|
485 | 485 | "output_type": "stream",
|
486 | 486 | "text": [
|
487 | | - "/home/georgematheos/b3d/src/b3d/modeling_utils.py:86: UserWarning: RenormalizedLaplace sampling is currently not implemented perfectly.\n", |
488 | | - "      warnings.warn(\n", |
489 | | - "100%|██████████| 30/30 [00:32<00:00,  1.09s/it]\n" |
| 487 | + "100%|██████████| 30/30 [00:09<00:00, 3.02it/s]\n" |
490 | 488 | ]
|
491 | 489 | }
|
492 | 490 | ],
|
|
539 | 537 | },
|
540 | 538 | {
|
541 | 539 | "cell_type": "code",
|
542 | | - "execution_count": 25, |
| 540 | + "execution_count": 26, |
543 | 541 | "metadata": {},
|
544 | 542 | "outputs": [],
|
545 | 543 | "source": [
|
546 | | - "b3d.rr_init(\"real_inference2\")" |
| 544 | + "b3d.rr_init(\"real_inference_3\")" |
547 | 545 | ]
|
548 | 546 | },
|
549 | 547 | {
|
|
562 | 560 | "name": "stderr",
|
563 | 561 | "output_type": "stream",
|
564 | 562 | "text": [
|
565 | | - "100%|██████████| 20/20 [01:51<00:00,  5.57s/it]\n" |
| 563 | + "/home/georgematheos/b3d/src/b3d/modeling_utils.py:86: UserWarning: RenormalizedLaplace sampling is currently not implemented perfectly.\n", |
| 564 | + " warnings.warn(\n", |
| 565 | + "/home/georgematheos/b3d/src/b3d/modeling_utils.py:86: UserWarning: RenormalizedLaplace sampling is currently not implemented perfectly.\n", |
| 566 | + " warnings.warn(\n", |
| 567 | + "100%|██████████| 20/20 [02:35<00:00, 7.75s/it]\n" |
566 | 568 | ]
|
567 | 569 | }
|
568 | 570 | ],
|
|
593 | 595 | },
|
594 | 596 | {
|
595 | 597 | "cell_type": "code",
|
596 | | - "execution_count": 28, |
| 598 | + "execution_count": 29, |
597 | 599 | "metadata": {},
|
598 | 600 | "outputs": [
|
599 | 601 | {
|
|
607 | 609 | "name": "stderr",
|
608 | 610 | "output_type": "stream",
|
609 | 611 | "text": [
|
610 | | - "100%|██████████| 29/29 [02:40<00:00,  5.55s/it]\n" |
| 612 | + " 14%|█▍ | 4/29 [00:40<04:16, 10.25s/it]\n" |
| 613 | + ] |
| 614 | + }, |
| 615 | + { |
| 616 | + "ename": "KeyboardInterrupt", |
| 617 | + "evalue": "", |
| 618 | + "output_type": "error", |
| 619 | + "traceback": [ |
| 620 | + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", |
| 621 | + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", |
| 622 | + "Cell \u001b[0;32mIn[29], line 6\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m T \u001b[38;5;129;01min\u001b[39;00m tqdm(\u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m20\u001b[39m, \u001b[38;5;28mlen\u001b[39m(all_data))):\n\u001b[1;32m 5\u001b[0m key \u001b[38;5;241m=\u001b[39m b3d\u001b[38;5;241m.\u001b[39msplit_key(key)\n\u001b[0;32m----> 6\u001b[0m trace \u001b[38;5;241m=\u001b[39m \u001b[43mi\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minference_step_c2f\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 7\u001b[0m \u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 8\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# number of sequential iterations of the parallel pose proposal to consider\u001b[39;49;00m\n\u001b[1;32m 9\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m5000\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# number of poses to propose in parallel\u001b[39;49;00m\n\u001b[1;32m 10\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;66;43;03m# So the total number of poses considered at each step of C2F is 5000 * 1\u001b[39;49;00m\n\u001b[1;32m 11\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrace\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mall_data\u001b[49m\u001b[43m[\u001b[49m\u001b[43mT\u001b[49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mrgbd\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 12\u001b[0m \u001b[43m \u001b[49m\u001b[43mprev_color_proposal_laplace_scale\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minference_hyperparams\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprev_color_proposal_laplace_scale\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 13\u001b[0m \u001b[43m \u001b[49m\u001b[43mobs_color_proposal_laplace_scale\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minference_hyperparams\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mobs_color_proposal_laplace_scale\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 14\u001b[0m \u001b[43m \u001b[49m\u001b[43mdo_stochastic_color_proposals\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\n\u001b[1;32m 15\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 16\u001b[0m b3d\u001b[38;5;241m.\u001b[39mchisight\u001b[38;5;241m.\u001b[39mgen3d\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39mviz_trace(\n\u001b[1;32m 17\u001b[0m trace,\n\u001b[1;32m 18\u001b[0m T,\n\u001b[1;32m 19\u001b[0m ground_truth_vertices\u001b[38;5;241m=\u001b[39mmeshes[OBJECT_INDEX]\u001b[38;5;241m.\u001b[39mvertices,\n\u001b[1;32m 20\u001b[0m ground_truth_pose\u001b[38;5;241m=\u001b[39mall_data[T][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcamera_pose\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39minv() \u001b[38;5;241m@\u001b[39m all_data[T][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mobject_poses\u001b[39m\u001b[38;5;124m\"\u001b[39m][OBJECT_INDEX]\n\u001b[1;32m 21\u001b[0m )\n", |
| 623 | + "File \u001b[0;32m~/b3d/src/b3d/chisight/gen3d/inference.py:100\u001b[0m, in \u001b[0;36minference_step_c2f\u001b[0;34m(key, n_seq, n_poses_per_sequential_step, old_trace, observed_rgbd, *args, **kwargs)\u001b[0m\n\u001b[1;32m 98\u001b[0m k1, k2 \u001b[38;5;241m=\u001b[39m split(key)\n\u001b[1;32m 99\u001b[0m trace \u001b[38;5;241m=\u001b[39m advance_time(k1, old_trace, observed_rgbd)\n\u001b[0;32m--> 100\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minfer_latents_c2f\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 101\u001b[0m \u001b[43m \u001b[49m\u001b[43mk2\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_seq\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_poses_per_sequential_step\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrace\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m 102\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", |
| 624 | + "File \u001b[0;32m~/b3d/src/b3d/chisight/gen3d/inference.py:121\u001b[0m, in \u001b[0;36minfer_latents_c2f\u001b[0;34m(key, n_seq, n_poses_per_sequential_step, trace, pose_proposal_std_conc_seq, **inference_hyperparam_kwargs)\u001b[0m\n\u001b[1;32m 114\u001b[0m inference_hyperparams \u001b[38;5;241m=\u001b[39m InferenceHyperparams(\n\u001b[1;32m 115\u001b[0m n_poses\u001b[38;5;241m=\u001b[39mn_poses_per_sequential_step,\n\u001b[1;32m 116\u001b[0m pose_proposal_std\u001b[38;5;241m=\u001b[39mstd,\n\u001b[1;32m 117\u001b[0m pose_proposal_conc\u001b[38;5;241m=\u001b[39mconc,\n\u001b[1;32m 118\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39minference_hyperparam_kwargs,\n\u001b[1;32m 119\u001b[0m )\n\u001b[1;32m 120\u001b[0m key, _ \u001b[38;5;241m=\u001b[39m split(key)\n\u001b[0;32m--> 121\u001b[0m trace, _ \u001b[38;5;241m=\u001b[39m \u001b[43minfer_latents_using_sequential_proposals\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 122\u001b[0m \u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_seq\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrace\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minference_hyperparams\u001b[49m\n\u001b[1;32m 123\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 125\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m trace\n", |
| 625 | + "File \u001b[0;32m~/b3d/src/b3d/chisight/gen3d/inference.py:153\u001b[0m, in \u001b[0;36minfer_latents_using_sequential_proposals\u001b[0;34m(key, n_seq, trace, inference_hyperparams)\u001b[0m\n\u001b[1;32m 151\u001b[0m k1, k2 \u001b[38;5;241m=\u001b[39m split(key)\n\u001b[1;32m 152\u001b[0m ks \u001b[38;5;241m=\u001b[39m split(k1, n_seq)\n\u001b[0;32m--> 153\u001b[0m weights \u001b[38;5;241m=\u001b[39m [\u001b[43mget_weight\u001b[49m\u001b[43m(\u001b[49m\u001b[43mk\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m k \u001b[38;5;129;01min\u001b[39;00m ks]\n\u001b[1;32m 155\u001b[0m normalized_logps \u001b[38;5;241m=\u001b[39m jax\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mlog_softmax(jnp\u001b[38;5;241m.\u001b[39marray(weights))\n\u001b[1;32m 156\u001b[0m chosen_idx \u001b[38;5;241m=\u001b[39m jax\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39mcategorical(k2, normalized_logps)\n", |
| 626 | + "File \u001b[0;32m~/b3d/src/b3d/chisight/gen3d/inference.py:149\u001b[0m, in \u001b[0;36minfer_latents_using_sequential_proposals.<locals>.get_weight\u001b[0;34m(key)\u001b[0m\n\u001b[1;32m 148\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_weight\u001b[39m(key):\n\u001b[0;32m--> 149\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minfer_latents\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mshared_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mget_trace\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mget_metadata\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m[\u001b[38;5;241m0\u001b[39m]\n", |
| 627 | + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " |
611 | 628 | ]
|
612 | 629 | }
|
613 | 630 | ],
|
614 | 631 | "source": [
|
615 | 632 | "## Finish the run\n",
|
| 633 | + "key = jax.random.PRNGKey(1234)\n", |
| 634 | + "trace = trace_20\n", |
616 | 635 | "for T in tqdm(range(20, len(all_data))):\n",
|
617 | 636 | " key = b3d.split_key(key)\n",
|
618 | 637 | " trace = i.inference_step_c2f(\n",
|
|
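The final "Finish the run" cell is truncated in this diff. Combining the new source lines (633-637) with the call visible in the In[29] traceback gives roughly the loop below. `trace_20`, `all_data`, `meshes`, and `OBJECT_INDEX` are defined in earlier cells that are not shown here, so this is a reconstruction under those assumptions rather than the exact committed cell.

## Finish the run, resuming from the frame-20 trace
import jax
from tqdm import tqdm
import b3d
import b3d.chisight.gen3d.inference as i  # assumed alias, as above

key = jax.random.PRNGKey(1234)
trace = trace_20  # trace produced for frame 20 earlier in the notebook
for T in tqdm(range(20, len(all_data))):
    key = b3d.split_key(key)
    trace = i.inference_step_c2f(
        key,
        2,     # number of sequential iterations of the parallel pose proposal to consider
        5000,  # number of poses to propose in parallel
        trace,
        all_data[T]["rgbd"],
        prev_color_proposal_laplace_scale=inference_hyperparams.prev_color_proposal_laplace_scale,
        obs_color_proposal_laplace_scale=inference_hyperparams.obs_color_proposal_laplace_scale,
        do_stochastic_color_proposals=True,
    )
    # Log the updated trace against the ground-truth mesh and object pose for frame T
    b3d.chisight.gen3d.model.viz_trace(
        trace,
        T,
        ground_truth_vertices=meshes[OBJECT_INDEX].vertices,
        ground_truth_pose=all_data[T]["camera_pose"].inv()
        @ all_data[T]["object_poses"][OBJECT_INDEX],
    )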