Skip to content

Commit ffc00fa

Browse files
committed
Add NER utility.
1 parent df80e95 commit ffc00fa

File tree

2 files changed

+127
-0
lines changed

2 files changed

+127
-0
lines changed

Diff for: README.md

+4
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ python EMOTIONAL-DAMAGE.py 1mjcwuaIJtW_9bGAebM3QK8RltWD9bKrjcr3qgMpivog
5050
python zero-shot-thirty.py --candidate-labels "gender,feminism,politics,religion" "/data/The Tucker Carlson Show/vtt" tucker-zero-shot.csv
5151
```
5252

53+
```shell
54+
python entity-matrix.py "/data/Fresh & Fit/vtt" digfemnet.json 1ZTUTmzyko7hTLsiokXoV-eliUujmazElQ1bET_1234
55+
```
56+
5357
### red-pill-visions
5458

5559
Generate [visualizations](https://ruebot.net/visualizations/mano-whisper/) from the transcripts or summaries of one or more podcasts.

Diff for: red-pill-bottles/entity-matrix.py

+123
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
import os
2+
import random
3+
import time
4+
5+
import click
6+
import gspread
7+
import spacy
8+
import webvtt
9+
from alive_progress import alive_bar
10+
from google.oauth2.service_account import Credentials
11+
from gspread.exceptions import APIError
12+
13+
14+
def setup_google_sheets(json_keyfile):
    """Authorize and return a gspread client from a service-account JSON keyfile."""
    oauth_scopes = [
        "https://spreadsheets.google.com/feeds",
        "https://www.googleapis.com/auth/drive",
    ]
    creds = Credentials.from_service_account_file(json_keyfile, scopes=oauth_scopes)
    return gspread.authorize(creds)
23+
24+
25+
def extract_text_from_vtt(vtt_path):
    """Read a WebVTT file and return all caption text joined into one string."""
    caption_texts = [caption.text for caption in webvtt.read(vtt_path)]
    return " ".join(caption_texts)
28+
29+
30+
def retry_on_quota_error(func, *args, max_retries=5, base_delay=2, **kwargs):
    """
    Call ``func(*args, **kwargs)``, retrying Google Sheets quota errors with
    exponential backoff plus jitter.

    Only an ``APIError`` whose message contains "Quota exceeded" is retried;
    any other exception propagates immediately.

    Raises:
        Exception: when the quota error persists after ``max_retries`` attempts
            (the last ``APIError`` is attached as the cause).
    """
    for attempt in range(1, max_retries + 1):
        try:
            return func(*args, **kwargs)
        except APIError as e:
            if "Quota exceeded" not in str(e):
                raise
            if attempt == max_retries:
                # Give up immediately instead of printing "Retrying..." and
                # sleeping one last time; chain the cause for debugging.
                raise Exception(
                    "Maximum retries reached for Google Sheets API request."
                ) from e
            delay = base_delay * (2**attempt) + random.uniform(0, 1)
            print(
                f"Quota exceeded. Retrying in {delay:.2f} seconds... (Attempt {attempt}/{max_retries})"
            )
            time.sleep(delay)
49+
50+
51+
def process_vtt_files(vtt_directory, json_keyfile, spreadsheet_id):
    """
    Extract "PERSON", "NORP", "FAC", "ORG", and "PRODUCT" entities from every
    WebVTT transcript in ``vtt_directory`` and append one row per file to the
    "ner" worksheet of the given Google Sheet.

    Files whose names already appear in the first column of the worksheet are
    skipped, so an interrupted run can be resumed.
    """
    # Load spaCy model.
    nlp = spacy.load("en_core_web_sm")

    # Set up Google Sheets.
    client = setup_google_sheets(json_keyfile)
    spreadsheet = retry_on_quota_error(client.open_by_key, spreadsheet_id)
    worksheet = retry_on_quota_error(spreadsheet.worksheet, "ner")

    # Insert the header row only if it is not already present.
    existing_headers = retry_on_quota_error(worksheet.row_values, 1)
    headers = ["Filename", "PERSON", "NORP", "FAC", "ORG", "PRODUCT"]
    if headers != existing_headers:
        retry_on_quota_error(worksheet.insert_row, headers, 1)

    # Get list of WebVTT files.
    vtt_files = [f for f in os.listdir(vtt_directory) if f.endswith(".vtt")]

    # Filenames already in the worksheet; a set makes the per-file
    # membership test O(1) instead of O(rows).
    existing_filenames = set(retry_on_quota_error(worksheet.col_values, 1))

    labels = ("PERSON", "NORP", "FAC", "ORG", "PRODUCT")

    # Iterate through transcripts and extract entities.
    with alive_bar(len(vtt_files), title="Processing VTT files") as bar:
        for filename in vtt_files:
            if filename in existing_filenames:
                # Fix: the original f-string had no placeholder, so every
                # skip printed the same literal text instead of the filename.
                print(f"Skipping {filename}, already processed.")
                bar()
                continue

            vtt_path = os.path.join(vtt_directory, filename)

            processed_content = extract_text_from_vtt(vtt_path)
            doc = nlp(processed_content)
            entities = {label: [] for label in labels}

            for ent in doc.ents:
                if ent.label_ in entities:
                    entities[ent.label_].append(ent.text)

            # One row per transcript: filename plus the de-duplicated,
            # pipe-delimited entities for each label.
            row = [filename] + ["|".join(set(entities[label])) for label in labels]

            retry_on_quota_error(worksheet.append_row, row)
            # Crude rate limiting to stay under the Sheets write quota.
            time.sleep(1)
            bar()
108+
109+
110+
@click.command()
@click.argument("vtt_directory", type=click.Path(exists=True, file_okay=False))
@click.argument("json_keyfile", type=click.Path(exists=True, dir_okay=False))
@click.argument("spreadsheet_id", type=str)
def main(vtt_directory, json_keyfile, spreadsheet_id):
    """
    Process a directory of WebVTT files and extract entities using spaCy. Write
    the output to a Google Sheet.

    VTT_DIRECTORY is the directory containing .vtt transcripts,
    JSON_KEYFILE is the Google service-account credentials file, and
    SPREADSHEET_ID identifies the target Google Sheet.
    """
    process_vtt_files(vtt_directory, json_keyfile, spreadsheet_id)
120+
121+
122+
# Script entry point: dispatch to the click command.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)