Commit fb10a66

vertex-mg-bot authored and copybara-github committed
A fix in the prediction section
PiperOrigin-RevId: 701847844
1 parent 2b2019a commit fb10a66

1 file changed: notebooks/community/model_garden/model_garden_mammut.ipynb (+30 −24 lines)

@@ -107,13 +107,14 @@
 "# @markdown ### Prerequisites\n",
 "# @markdown 1. [Make sure that billing is enabled for your project](https://cloud.google.com/billing/docs/how-to/modify-project).\n",
 "\n",
-"# @markdown 2. [Optional] [Create a Cloud Storage bucket](https://cloud.google.com/storage/docs/creating-buckets) for storing experiment outputs. Set the BUCKET_URI for the experiment environment. The specified Cloud Storage bucket (`BUCKET_URI`) should be located in the same region as where the notebook was launched. Note that a multi-region bucket (eg. \"us\") is not considered a match for a single region covered by the multi-region range (eg. \"us-central1\"). If not set, a unique GCS bucket will be created instead.\n",
+"# @markdown 2. **[Optional]** [Create a Cloud Storage bucket](https://cloud.google.com/storage/docs/creating-buckets) for storing experiment outputs. Set the BUCKET_URI for the experiment environment. The specified Cloud Storage bucket (`BUCKET_URI`) should be located in the same region as where the notebook was launched. Note that a multi-region bucket (eg. \"us\") is not considered a match for a single region covered by the multi-region range (eg. \"us-central1\"). If not set, a unique GCS bucket will be created instead.\n",
 "\n",
 "! git clone https://github.com/GoogleCloudPlatform/vertex-ai-samples.git\n",
 "! pip install -q gradio==4.21.0\n",
 "\n",
 "import importlib\n",
 "import os\n",
+"import uuid\n",
 "from datetime import datetime\n",
 "from typing import Tuple\n",
 "\n",
@@ -132,21 +133,31 @@
 "# Get the default region for launching jobs.\n",
 "REGION = os.environ[\"GOOGLE_CLOUD_REGION\"]\n",
 "\n",
+"# @markdown 3. If you want to run predictions with A100 80GB or H100 GPUs, we recommend using the regions listed below. **NOTE:** Make sure you have associated quota in selected regions. Click the links to see your current quota for each GPU type: [Nvidia A100 80GB](https://console.cloud.google.com/iam-admin/quotas?metric=aiplatform.googleapis.com%2Fcustom_model_serving_nvidia_a100_80gb_gpus), [Nvidia H100 80GB](https://console.cloud.google.com/iam-admin/quotas?metric=aiplatform.googleapis.com%2Fcustom_model_serving_nvidia_h100_gpus).\n",
+"\n",
+"# @markdown > | Machine Type | Accelerator Type | Recommended Regions |\n",
+"# @markdown | ----------- | ----------- | ----------- |\n",
+"# @markdown | a2-ultragpu-1g | 1 NVIDIA_A100_80GB | us-central1, us-east4, europe-west4, asia-southeast1, us-east4 |\n",
+"# @markdown | a3-highgpu-2g | 2 NVIDIA_H100_80GB | us-west1, asia-southeast1, europe-west4 |\n",
+"# @markdown | a3-highgpu-4g | 4 NVIDIA_H100_80GB | us-west1, asia-southeast1, europe-west4 |\n",
+"# @markdown | a3-highgpu-8g | 8 NVIDIA_H100_80GB | us-central1, us-east5, europe-west4, us-west1, asia-southeast1 |\n",
+"\n",
+"\n",
 "# Cloud Storage bucket for storing the experiment artifacts.\n",
 "# A unique GCS bucket will be created for the purpose of this notebook. If you\n",
 "# prefer using your own GCS bucket, change the value yourself below.\n",
 "now = datetime.now().strftime(\"%Y%m%d%H%M%S\")\n",
 "BUCKET_URI = \"gs://\" # @param {type: \"string\"}\n",
+"BUCKET_NAME = \"/\".join(BUCKET_URI.split(\"/\")[:3])\n",
 "assert BUCKET_URI.startswith(\"gs://\"), \"BUCKET_URI must start with `gs://`.\"\n",
 "\n",
 "# Create a unique GCS bucket for this notebook, if not specified by the user.\n",
 "assert BUCKET_URI.startswith(\"gs://\"), \"BUCKET_URI must start with `gs://`.\"\n",
 "if BUCKET_URI is None or BUCKET_URI.strip() == \"\" or BUCKET_URI == \"gs://\":\n",
-"    BUCKET_URI = f\"gs://{PROJECT_ID}-tmp-{now}\"\n",
-"    ! gsutil mb -l {REGION} {BUCKET_URI}\n",
+"    BUCKET_URI = f\"gs://{PROJECT_ID}-tmp-{now}-{str(uuid.uuid4())[:4]}\"\n",
 "    BUCKET_NAME = \"/\".join(BUCKET_URI.split(\"/\")[:3])\n",
+"    ! gsutil mb -l {REGION} {BUCKET_URI}\n",
 "else:\n",
-"    BUCKET_NAME = \"/\".join(BUCKET_URI.split(\"/\")[:3])\n",
 "    shell_output = ! gsutil ls -Lb {BUCKET_NAME} | grep \"Location constraint:\" | sed \"s/Location constraint://\"\n",
 "    bucket_region = shell_output[0].strip().lower()\n",
 "    if bucket_region != REGION:\n",
@@ -186,23 +197,6 @@
 "models, endpoints = {}, {}\n",
 "\n",
 "\n",
-"def resize_image(image: Image.Image, new_width: int = 512) -> Image.Image:\n",
-"    width, height = image.size\n",
-"    new_height = int(height * new_width / width)\n",
-"    new_image = image.resize((new_width, new_height))\n",
-"    return new_image\n",
-"\n",
-"\n",
-"def load_image(image_url):\n",
-"    if image_url.startswith(\"gs://\"):\n",
-"        local_image_path = \"./images/test_image.jpg\"\n",
-"        common_util.download_gcs_file_to_local(image_url, local_image_path)\n",
-"        image = common_util.load_img(local_image_path)\n",
-"    else:\n",
-"        image = common_util.download_image(image_url)\n",
-"    return image\n",
-"\n",
-"\n",
 "def deploy_mammut(\n",
 "    task: str, machine_type: str, accelerator_type: str, accelerator_count: int\n",
 ") -> Tuple[aiplatform.Model, aiplatform.Endpoint]:\n",
@@ -268,7 +262,7 @@
 "):\n",
 "    \"\"\"Generates predictions based on the input image and text using an Endpoint.\"\"\"\n",
 "    # Resize and convert image to base64 string.\n",
-"    resized_image = resize_image(image, new_width)\n",
+"    resized_image = common_util.resize_image(image, new_width=512)\n",
 "    instances = [\n",
 "        {\n",
 "            \"image_bytes\": {\"b64\": common_util.image_to_base64(resized_image)},\n",
@@ -354,6 +348,17 @@
 "# @markdown This can be either a Cloud Storage path (gs://\\<image-path\\>) or a public url (http://\\<image-path\\>)\n",
 "image_url = \"https://images.pexels.com/photos/4012966/pexels-photo-4012966.jpeg\" # @param {type:\"string\"}\n",
 "\n",
+"\n",
+"def load_image(image_url):\n",
+"    if image_url.startswith(\"gs://\"):\n",
+"        local_image_path = \"./images/test_image.jpg\"\n",
+"        common_util.download_gcs_file_to_local(image_url, local_image_path)\n",
+"        image = common_util.load_img(local_image_path)\n",
+"    else:\n",
+"        image = common_util.download_image(image_url)\n",
+"    return image\n",
+"\n",
+"\n",
 "image = load_image(image_url)\n",
 "display(image)\n",
 "\n",
@@ -608,6 +613,7 @@
 "# @markdown This can be either a Cloud Storage path (gs://\\<image-path\\>) or a public url (http://\\<image-path\\>)\n",
 "image_url = \"https://images.pexels.com/photos/20427316/pexels-photo-20427316/free-photo-of-a-moped-parked-in-front-of-a-blue-door.jpeg?auto=compress&cs=tinysrgb&w=630&h=375&dpr=2\" # @param {type:\"string\"}\n",
 "\n",
+"\n",
 "image = load_image(image_url)\n",
 "display(image)\n",
 "\n",
@@ -683,7 +689,7 @@
 "text_embeddings = []\n",
 "image_embeddings = []\n",
 "for image in images:\n",
-"    prediction = predict(retrieval_endpoint, image, text)\n",
+"    prediction = predict(endpoints[\"retrieval\"], image, text)\n",
 "    image_embeddings.append(np.array(prediction[\"normalized_image_embedding\"]))\n",
 "    text_embeddings.append(np.array(prediction[\"normalized_text_embedding\"]))\n",
 "\n",
@@ -730,7 +736,7 @@
 "# Delete Cloud Storage objects that were created.\n",
 "delete_bucket = False # @param {type:\"boolean\"}\n",
 "if delete_bucket:\n",
-"    ! gsutil -m rm -r $BUCKET_URI"
+"    ! gsutil -m rm -r $BUCKET_NAME"
 ]
 }
 ],
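
Deleting `$BUCKET_NAME` instead of `$BUCKET_URI` matters when the user supplied a URI with a subdirectory: `BUCKET_NAME` keeps only the first three `/`-separated components, i.e. the bucket root, so setting `delete_bucket = True` removes the whole bucket rather than just a prefix. A quick check of that split:

```python
BUCKET_URI = "gs://my-bucket/experiments/run-01"  # example user-supplied value
BUCKET_NAME = "/".join(BUCKET_URI.split("/")[:3])
assert BUCKET_NAME == "gs://my-bucket"
```
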
