|
306 | 306 | "print(\"Creating endpoint.\")\n", |
307 | 307 | "\n", |
308 | 308 | "SERVE_DOCKER_URI = \"us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/jax-timesfm-serve:20240828_1036_RC00\"\n", |
| 309 | + "# @markdown Set use_dedicated_endpoint to False if you don't want to use a [dedicated endpoint](https://cloud.google.com/vertex-ai/docs/general/deployment#create-dedicated-endpoint).\n", |
| 310 | + "use_dedicated_endpoint = True # @param {type:\"boolean\"}\n", |
309 | 311 | "\n", |
310 | 312 | "\n", |
311 | 313 | "def deploy_model(\n", |
|
317 | 319 | " accelerator_type: str = \"NVIDIA_L4\",\n", |
318 | 320 | " accelerator_count: int = 1,\n", |
319 | 321 | " deploy_source: str = \"notebook\",\n", |
| 322 | + " use_dedicated_endpoint: bool = False,\n", |
320 | 323 | ") -> Tuple[aiplatform.Model, aiplatform.Endpoint]:\n", |
321 | 324 | " \"\"\"Creates a Vertex AI Endpoint and deploys TimesFM to the endpoint.\"\"\"\n", |
322 | 325 | " model_name_with_time = common_util.get_job_name_with_datetime(model_name)\n", |
323 | 326 | " endpoint = aiplatform.Endpoint.create(\n", |
324 | 327 | " display_name=f\"{model_name_with_time}-endpoint\",\n", |
325 | 328 | " credentials=aiplatform.initializer.global_config.credentials,\n", |
| 329 | + " dedicated_endpoint_enabled=use_dedicated_endpoint,\n", |
326 | 330 | " )\n", |
327 | 331 | "\n", |
328 | 332 | " if accelerator_type == \"ACCELERATOR_TYPE_UNSPECIFIED\":\n", |
|
374 | 378 | " machine_type=machine_type,\n", |
375 | 379 | " accelerator_type=accelerator_type,\n", |
376 | 380 | " accelerator_count=accelerator_count,\n", |
| 381 | + " use_dedicated_endpoint=use_dedicated_endpoint,\n", |
377 | 382 | ")" |
378 | 383 | ] |
379 | 384 | }, |
|
538 | 543 | "]\n", |
539 | 544 | "\n", |
540 | 545 | "# Query the endpoint.\n", |
541 | | - "results = endpoints[\"timesfm\"].predict(instances=instances)\n", |
| 546 | + "results = endpoints[\"timesfm\"].predict(\n", |
| 547 | + " instances=instances,\n", |
| 548 | + " use_dedicated_endpoint=use_dedicated_endpoint,\n", |
| 549 | + ")\n", |
542 | 550 | "\n", |
543 | 551 | "viz = Visualizer(nrows=1, ncols=3)\n", |
544 | 552 | "viz.visualize_forecast(\n", |
|
616 | 624 | " \"timestamp_format\": \"%Y-%m-%d\",\n", |
617 | 625 | " }\n", |
618 | 626 | " for each_input, each_timestamp in zip(inputs, timestamps)\n", |
619 | | - " ]\n", |
| 627 | + " ],\n", |
| 628 | + " use_dedicated_endpoint=use_dedicated_endpoint,\n", |
620 | 629 | ")\n", |
621 | 630 | "\n", |
622 | 631 | "viz = Visualizer(nrows=1, ncols=3)\n", |
|
788 | 797 | " },\n", |
789 | 798 | "]\n", |
790 | 799 | "\n", |
791 | | - "response = endpoints[\"timesfm\"].predict(instances=cov_instances)\n", |
| 800 | + "response = endpoints[\"timesfm\"].predict(\n", |
| 801 | + " instances=cov_instances,\n", |
| 802 | + " use_dedicated_endpoint=use_dedicated_endpoint,\n", |
| 803 | + ")\n", |
792 | 804 | "\n", |
793 | 805 | "no_cov_instances = [{\"input\": task[\"input\"], \"horizon\": 40} for task in cov_instances]\n", |
794 | | - "no_cov_response = endpoints[\"timesfm\"].predict(instances=no_cov_instances)\n", |
| 806 | + "no_cov_response = endpoints[\"timesfm\"].predict(\n", |
| 807 | + " instances=no_cov_instances,\n", |
| 808 | + " use_dedicated_endpoint=use_dedicated_endpoint,\n", |
| 809 | + ")\n", |
795 | 810 | "\n", |
796 | 811 | "viz = Visualizer(nrows=3, ncols=2)\n", |
797 | 812 | "for task_i, (per_input, per_gt) in enumerate(\n", |
|
0 commit comments