|
96 | 96 | "from etils import epath\n", |
97 | 97 | "from IPython.display import display\n", |
98 | 98 | "import ipywidgets as widgets\n", |
| 99 | + "from ml_collections import config_dict\n", |
99 | 100 | "import numpy as np\n", |
| 101 | + "\n", |
100 | 102 | "from perch_hoplite.agile import colab_utils\n", |
101 | 103 | "from perch_hoplite.agile import embed\n", |
102 | 104 | "from perch_hoplite.agile import source_info\n", |
103 | 105 | "from perch_hoplite.db import brutalism\n", |
104 | | - "from perch_hoplite.db import interface" |
| 106 | + "from perch_hoplite.db import interface\n", |
| 107 | + "from perch_hoplite.zoo import taxonomy_model_tf" |
105 | 108 | ] |
106 | 109 | }, |
107 | 110 | { |
|
207 | 210 | "# @markdown For this example, we use the name of the large audio file, but you can use a different name here.\n", |
208 | 211 | "dataset_name = 'Saipan_A_06_151006_091215' # @param {type:'string'}\n", |
209 | 212 | "# @markdown 2. Input the filepath of the folder containing the input audio files.\n", |
210 | | - "dataset_base_path = 'gs://noaa-passive-bioacoustic/pifsc/audio/pipan/saipan/pipan_saipan_06/audio' #@param {type:'string'}\n", |
| 213 | + "dataset_base_path = 'gs://noaa-passive-bioacoustic/pifsc/audio/pipan_10/saipan/pipan_saipan_06/audio' #@param {type:'string'}\n", |
211 | 214 | "# @markdown 3. Input the file pattern for the audio files within that folder that you want to embed. Some example patterns:\n", |
212 | 215 | "# @markdown - All files in the base directory of a specific type (not subdirectories): e.g. `*.wav` (or `*.flac` etc) will generate embeddings for all .wav files (or whichever format) in the dataset_base_path\n", |
213 | 216 | "# @markdown - All files in one level of subdirectories within the base directory: `*/*.flac` will generate embeddings for all .flac files\n", |
|
271 | 274 | }, |
272 | 275 | "outputs": [], |
273 | 276 | "source": [ |
274 | | - "#@title Initialize the hoplite database (DB) { vertical-output: true }\n", |
| 277 | + "# @title Initialize the hoplite database (DB) {vertical-output: true}\n", |
| 278 | + "\n", |
275 | 279 | "global db\n", |
276 | 280 | "db = configs.db_config.load_db()\n", |
277 | 281 | "num_embeddings = db.count_embeddings()\n", |
278 | 282 | "\n", |
279 | | - "print('Initialized DB located at ', configs.db_config.db_config.db_path)\n", |
| 283 | + "print('Initialized DB located at:', configs.db_config.db_config.db_path)\n", |
280 | 284 | "\n", |
281 | 285 | "def drop_and_reload_db(_) -> interface.HopliteDBInterface:\n", |
282 | 286 | " db_path = epath.Path(configs.db_config.db_config.db_path)\n", |
|
286 | 290 | " print('\\n Deleted previous db at: ', configs.db_config.db_config.db_path)\n", |
287 | 291 | " db = configs.db_config.load_db()\n", |
288 | 292 | "\n", |
289 | | - "#@markdown If `drop_existing_db` set to True, when the database already exists and contains embeddings,\n", |
290 | | - "#@markdown then those existing embeddings will be erased. You will be prompted to confirm you wish to delete those existing\n", |
291 | | - "#@markdown embeddings. If you want to keep existing embeddings in the database, then set to False, which will append the new\n", |
292 | | - "#@markdown embeddings to the database.\n", |
293 | | - "drop_existing_db = False #@param {type:'boolean'}\n", |
| 293 | + "# @markdown If `drop_existing_db` is set to True and the database already exists and contains\n", |
| 294 | + "# @markdown embeddings, then those existing embeddings will be erased. You will be prompted\n", |
| 295 | + "# @markdown to confirm that you wish to delete those existing embeddings. If you want to keep\n", |
| 296 | + "# @markdown existing embeddings in the database, then set it to False, which will append the new\n", |
| 297 | + "# @markdown embeddings to the database.\n", |
| 298 | + "drop_existing_db = False # @param {type: 'boolean'}\n", |
294 | 299 | "\n", |
295 | 300 | "if num_embeddings > 0 and drop_existing_db:\n", |
296 | | - " print('Existing DB contains datasets: ', db.get_dataset_names())\n", |
| 301 | + " print('Existing DB contains projects: ', db.get_all_projects())\n", |
297 | 302 | " print('num embeddings: ', num_embeddings)\n", |
298 | 303 | " print('\\n\\nClick the button below to confirm you really want to drop the database at ')\n", |
299 | 304 | " print(f'{configs.db_config.db_config.db_path}\\n')\n", |
|
353 | 358 | "#@title Per project statistics { vertical-output: true }\n", |
354 | 359 | "#@markdown This tells us how many unique segments are embedded in the database.\n", |
355 | 360 | "\n", |
356 | | - "for dataset in db.get_dataset_names():\n", |
357 | | - " print(f'\\nDataset \\'{dataset}\\':')\n", |
358 | | - " print('\\tnum embeddings: ', db.get_embeddings_by_source(dataset, source_id=None).shape[0])" |
| 361 | + "\n", |
| 362 | + "# @title Per project statistics {vertical-output: true}\n", |
| 363 | + "\n", |
| 364 | + "for project in db.get_all_projects():\n", |
| 365 | + " window_ids = db.match_window_ids(\n", |
| 366 | + " deployments_filter=config_dict.create(eq=dict(project=project))\n", |
| 367 | + " )\n", |
| 368 | + " print('Project:', project)\n", |
| 369 | + " print('>>> num embeddings:', len(window_ids))\n", |
| 370 | + " print()" |
359 | 371 | ] |
360 | 372 | }, |
361 | 373 | { |
|
367 | 379 | }, |
368 | 380 | "outputs": [], |
369 | 381 | "source": [ |
370 | | - "#@title Show example embedding search\n", |
371 | | - "#@markdown As an example (and to show that the embedding process worked), this\n", |
372 | | - "#@markdown selects a single embedding from the database and outputs the embedding ids of the\n", |
373 | | - "#@markdown top-K (k = 128) nearest neighbors in the database.\n", |
| 382 | + "# @title Show example embedding search\n", |
| 383 | + "# @markdown As an example (and to show that the embedding process worked), this selects a single\n", |
| 384 | + "# @markdown embedding from the database and outputs the embedding ids of the top-k (k = 128)\n", |
| 385 | + "# @markdown nearest neighbors in the database.\n", |
374 | 386 | "\n", |
375 | | - "q = db.get_embedding(db.get_one_embedding_id())\n", |
| 387 | + "q = db.get_embedding(db.match_window_ids(limit=1)[0])\n", |
376 | 388 | "%time results, scores = brutalism.brute_search(worker.db, query_embedding=q, search_list_size=128, score_fn=np.dot)\n", |
377 | | - "print([int(r.embedding_id) for r in results])" |
| 389 | + "print([int(r.window_id) for r in results])" |
378 | 390 | ] |
379 | 391 | }, |
380 | 392 | { |
|
411 | 423 | "from perch_hoplite.db import score_functions\n", |
412 | 424 | "from perch_hoplite.db import search_results\n", |
413 | 425 | "from perch_hoplite.db import sqlite_usearch_impl\n", |
414 | | - "from perch_hoplite.zoo import model_configs" |
| 426 | + "from perch_hoplite.zoo import model_configs\n", |
| 427 | + "from perch_hoplite.zoo import taxonomy_model_tf" |
415 | 428 | ] |
416 | 429 | }, |
417 | 430 | { |
|
438 | 451 | "#@markdown but note that the model sample rates will be different from this rate.\n", |
439 | 452 | "#@markdown If left blank, then the sample rate will be taken from the model's\n", |
440 | 453 | "#@markdown sample rate.\n", |
441 | | - "audio_loader_sample_rate_hz = 10_000 #@param {type:'number'}\n", |
| 454 | + "audio_loader_sample_rate_hz = None #@param {type:'number'}\n", |
442 | 455 | "\n", |
443 | | - "db = sqlite_usearch_impl.SQLiteUsearchDB.create(db_path)\n", |
| 456 | + "db = sqlite_usearch_impl.SQLiteUSearchDB.create(db_path)\n", |
444 | 457 | "db_model_config = db.get_metadata('model_config')\n", |
445 | 458 | "embed_config = db.get_metadata('audio_sources')\n", |
446 | 459 | "model_class = model_configs.get_model_class(db_model_config.model_key)\n", |
|
449 | 462 | "\n", |
450 | 463 | "if audio_loader_sample_rate_hz == None:\n", |
451 | 464 | " audio_loader_sample_rate_hz = embedding_model.sample_rate\n", |
452 | | - "\n", |
453 | 465 | "if hasattr(embedding_model, 'window_size_s'):\n", |
454 | 466 | " window_size_s = embedding_model.window_size_s\n", |
455 | 467 | "else:\n", |
|
479 | 491 | "query_uri = 'gs://bioacoustics-www1/multispecies_blog_media/Be_example3.wav' #@param {type:'string'}\n", |
480 | 492 | "query_label = 'Be_biotwang' #@param {type:'string'}\n", |
481 | 493 | "\n", |
482 | | - "\n", |
483 | 494 | "query = embedding_display.QueryDisplay(\n", |
484 | | - " uri=query_uri, offset_s=0.0, window_size_s=5.0)\n", |
| 495 | + " uri=query_uri,\n", |
| 496 | + " offset_s=0.0,\n", |
| 497 | + " window_size_s=5.0,\n", |
| 498 | + " sample_rate_hz=audio_loader_sample_rate_hz)\n", |
485 | 499 | "_ = query.display_interactive()" |
486 | 500 | ] |
487 | 501 | }, |
|
567 | 581 | "cell_type": "code", |
568 | 582 | "execution_count": null, |
569 | 583 | "metadata": { |
570 | | - "id": "G3sIkOqlXzKB" |
| 584 | + "id": "_Z1zFDksuC05" |
571 | 585 | }, |
572 | 586 | "outputs": [], |
573 | 587 | "source": [ |
574 | 588 | "#@title Save data labels. { vertical-output: true }\n", |
575 | 589 | "#@markdown Saves the harvested labels to the database and prints the total annotation count before and after.\n", |
| 590 | + "print(\"Annotations before saving new labels:\", len(db.get_all_annotations()))\n", |
| 591 | + "\n", |
| 592 | + "for ann in display_results.harvest_labels(annotator_id):\n", |
| 593 | + " db.insert_annotation(\n", |
| 594 | + " recording_id=ann.recording_id,\n", |
| 595 | + " offsets=ann.offsets,\n", |
| 596 | + " label=ann.label,\n", |
| 597 | + " label_type=ann.label_type,\n", |
| 598 | + " provenance=ann.provenance,\n", |
| 599 | + " skip_duplicates=True,\n", |
| 600 | + " )\n", |
576 | 601 | "\n", |
577 | | - "prev_lbls, new_lbls = 0, 0\n", |
578 | | - "for lbl in display_results.harvest_labels(annotator_id):\n", |
579 | | - " check = db.insert_label(lbl, skip_duplicates=True)\n", |
580 | | - " new_lbls += check\n", |
581 | | - " prev_lbls += (1 - check)\n", |
582 | | - "print('\\nNew labels added: ', new_lbls)\n", |
583 | | - "print('\\nLabeled query results that already existed: ', prev_lbls)" |
| 602 | + "print(\"Annotations after saving new labels:\", len(db.get_all_annotations()))" |
584 | 603 | ] |
585 | 604 | }, |
586 | 605 | { |
587 | 606 | "cell_type": "code", |
588 | 607 | "execution_count": null, |
589 | 608 | "metadata": { |
590 | | - "id": "ouMfqh0KnZS4" |
| 609 | + "id": "1TF6_7DouC05" |
591 | 610 | }, |
592 | 611 | "outputs": [], |
593 | 612 | "source": [ |
594 | 613 | "#@title Check how many labels of each class exist in the data\n", |
595 | | - "print('\\nTotal positive labels per class: ', db.get_class_counts())\n", |
596 | | - "print('\\nTotal negative labels per class: ', db.get_class_counts(label_type = interface.LabelType.NEGATIVE))" |
| 614 | + "print('\\nTotal positive labels per class: ', db.count_each_label(label_type = interface.LabelType.POSITIVE))\n", |
| 615 | + "print('\\nTotal negative labels per class: ', db.count_each_label(label_type = interface.LabelType.NEGATIVE))" |
597 | 616 | ] |
598 | 617 | }, |
599 | 618 | { |
|
740 | 759 | "cell_type": "code", |
741 | 760 | "execution_count": null, |
742 | 761 | "metadata": { |
743 | | - "id": "IMXI3vdfmX48" |
| 762 | + "id": "AluZWMMmwE5K" |
744 | 763 | }, |
745 | 764 | "outputs": [], |
746 | 765 | "source": [ |
747 | 766 | "#@title Save data labels. { vertical-output: true }\n", |
748 | 767 | "#@markdown This will save the labels to the database, attached to the embedded examples.\n", |
749 | 768 | "\n", |
750 | | - "prev_lbls, new_lbls = 0, 0\n", |
751 | | - "for lbl in display_results.harvest_labels(annotator_id):\n", |
752 | | - " check = db.insert_label(lbl, skip_duplicates=True)\n", |
753 | | - " new_lbls += check\n", |
754 | | - " prev_lbls += (1 - check)\n", |
755 | | - "print('\\nNew labels added: ', new_lbls)\n", |
756 | | - "print('\\nQuery examples that already existed: ', prev_lbls)" |
| 769 | + "\n", |
| 770 | + "print(\"Annotations before saving new labels:\", len(db.get_all_annotations()))\n", |
| 771 | + "\n", |
| 772 | + "for ann in display_results.harvest_labels(annotator_id):\n", |
| 773 | + " db.insert_annotation(\n", |
| 774 | + " recording_id=ann.recording_id,\n", |
| 775 | + " offsets=ann.offsets,\n", |
| 776 | + " label=ann.label,\n", |
| 777 | + " label_type=ann.label_type,\n", |
| 778 | + " provenance=ann.provenance,\n", |
| 779 | + " skip_duplicates=True,\n", |
| 780 | + " )\n", |
| 781 | + "\n", |
| 782 | + "print(\"Annotations after saving new labels:\", len(db.get_all_annotations()))" |
757 | 783 | ] |
758 | 784 | }, |
759 | 785 | { |
760 | 786 | "cell_type": "code", |
761 | 787 | "execution_count": null, |
762 | 788 | "metadata": { |
763 | | - "id": "N6jOL17UbgMo" |
| 789 | + "id": "iHKlxrpgwE5K" |
764 | 790 | }, |
765 | 791 | "outputs": [], |
766 | 792 | "source": [ |
767 | 793 | "#@title Check how many labels of each class exist in the data\n", |
768 | | - "print('\\nTotal positive labels per class: ', db.get_class_counts())\n", |
769 | | - "print('\\nTotal negative labels per class: ', db.get_class_counts(label_type = interface.LabelType.NEGATIVE))" |
| 794 | + "print('\\nTotal positive labels per class: ', db.count_each_label(label_type = interface.LabelType.POSITIVE))\n", |
| 795 | + "print('\\nTotal negative labels per class: ', db.count_each_label(label_type = interface.LabelType.NEGATIVE))" |
770 | 796 | ] |
771 | 797 | }, |
772 | 798 | { |
|
884 | 910 | }, |
885 | 911 | "name": "agile_modeling_noaa_demo.ipynb", |
886 | 912 | "private_outputs": true, |
887 | | - "provenance": [ |
888 | | - { |
889 | | - "file_id": "1ePT3-fDB3kA3_T7trthFtu8xTJQWQBoQ", |
890 | | - "timestamp": 1723499538314 |
891 | | - } |
892 | | - ], |
| 913 | + "provenance": [], |
893 | 914 | "toc_visible": true |
894 | 915 | }, |
895 | 916 | "kernelspec": { |
|
0 commit comments