transcription_chunker.py
import logging
import os
from io import BytesIO
from typing import Generator

import webvtt
from langchain_text_splitters import RecursiveCharacterTextSplitter

from .base_chunker import BaseChunker


class TranscriptionChunker(BaseChunker):
"""
`TranscriptionChunker` is a class designed to process and chunk transcription text content, specifically from `WebVTT` (Web Video Text Tracks) format files.
It utilizes the `RecursiveCharacterTextSplitter` to segment the transcription into manageable chunks, considering token limits and content structure.
Initialization:
---------------
The `TranscriptionChunker` is initialized with the following parameters:
- `data` (dict): A dictionary containing the transcription text content to be chunked.
    - `max_chunk_size` (int, optional): The maximum size of each chunk in tokens. Defaults to the value of the
      `NUM_TOKENS` environment variable, or 2048 tokens if it is not set.
- `token_overlap` (int, optional): The number of overlapping tokens between consecutive chunks. Defaults to 100 tokens.
Attributes:
-----------
- `max_chunk_size` (int): The maximum allowed size of each chunk in tokens, derived from the `NUM_TOKENS`
environment variable (default is 2048 tokens).
- `token_overlap` (int): Number of overlapping tokens between chunks.
- `document_content` (str): The content of the document.
Methods:
--------
- `get_chunks()`: Processes the transcription text and generates chunks based on the specified chunking parameters.
It first processes the WebVTT file, extracts the text, and then splits the content into chunks. The method includes
a mechanism to summarize the content and attaches this summary to each chunk.
- `_vtt_process()`: Converts the WebVTT content into a continuous text block, retaining speaker changes. It processes
each caption, merging text from the same speaker and separating segments by speaker changes.
    - `_chunk_content()`: Splits the processed document content into chunks using the `RecursiveCharacterTextSplitter`.
      This method yields each chunk, together with its token count, as it is created.
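    Example:
    --------
    Illustrative sketch only; the keys expected in `data` are defined by
    `BaseChunker` and are not documented in this file::

        chunker = TranscriptionChunker(data=transcription_data)  # schema per `BaseChunker`
        chunks = chunker.get_chunks()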
"""
def __init__(
self,
data: dict,
max_chunk_size: int | None = None,
token_overlap: int | None = None,
):
super().__init__(data)
        # Explicit `None` checks so a caller-supplied 0 is not silently replaced by the default
        self.max_chunk_size = (
            max_chunk_size
            if max_chunk_size is not None
            else int(os.getenv("NUM_TOKENS", "2048"))
        )
        self.token_overlap = token_overlap if token_overlap is not None else 100

def get_chunks(self) -> list[dict]:
"""
Processes the transcription text and generates chunks based on the specified chunking parameters.
Returns:
list[dict]: A list of dictionaries, each representing a chunk of the document.
"""
chunks = []
logging.info(
f"[TranscriptionChunker] [get_chunks] [{self.filename}] Running `get_chunks`."
)
# Extract the text from the vtt file
text = self._vtt_process()
logging.debug(
f"[TranscriptionChunker] [get_chunks] [{self.filename}] Transcription text: {text[:100]}... (first 100 characters)."
)
# Get the summary of the text
prompt = f"Provide clearly elaborated summary along with the keypoints and values mentioned for the transcript of a conversation: {text} "
summary = self.aoai_client.get_completion(prompt)
        # `_chunk_content` returns a generator that yields (chunk_text, token_count) tuples
text_chunks = self._chunk_content(text)
skipped_chunks = 0
chunk_id = 0
for text_chunk, num_tokens in text_chunks:
            # Keep only chunks that reach the minimum size threshold;
            # `minimum_chunk_size` is assumed to be provided by `BaseChunker` (not shown in this file)
            if num_tokens >= self.minimum_chunk_size:
chunk_id += 1
chunk_dict = self._create_chunk(
chunk_id=chunk_id,
content=text_chunk,
embedding_text=summary,
summary=summary,
)
chunks.append(chunk_dict)
else:
skipped_chunks += 1
logging.debug(
f"[TranscriptionChunker] [get_chunks] [{self.filename}] {len(chunks)} chunk(s) created."
)
if skipped_chunks > 0:
logging.debug(
f"[TranscriptionChunker] [get_chunks] [{self.filename}] {skipped_chunks} chunk(s) skipped."
)
        return chunks

def _vtt_process(self) -> str:
"""
Converts the WebVTT content into a continuous text block, retaining speaker changes.
This method processes the captions from the WebVTT format, concatenating the text while
preserving speaker information. It merges text from the same speaker and separates segments
by speaker changes.
Returns:
str: A string representing the processed transcription text with speaker annotations.
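        Example (illustrative):
            Cues "<v Alice>Hi there." followed by "<v Bob>Hello." come back
            as roughly "Alice: Hi there.\nBob: Hello.".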
"""
blob_data = self.document_bytes
blob_stream = BytesIO(blob_data)
vtt = webvtt.read_buffer(blob_stream)
data, text, voice = [], "", ""
for caption in vtt:
current_voice = caption.voice or ""
if current_voice != voice:
if text:
data.append(text.replace("\n", " "))
                # Update the speaker before building the new segment so the
                # prefix names the speaker who is actually talking
                voice = current_voice
                text = f"{voice}: {caption.text} " if voice else caption.text + " "
else:
text += caption.text + " "
if text:
data.append(text.replace("\n", " "))
return "\n".join(data).strip()
def _chunk_content(self, content: str) -> Generator[tuple[str, int], None, None]:
"""
Splits the document content into chunks based on the specified format and criteria.
Args:
content (str): The content to chunk.
Yields:
            tuple[str, int]: The chunked content and the number of tokens in the chunk (i.e., the chunk size).
"""
sentence_endings = [".", "!", "?"]
word_breaks = [" ", "\n", "\t"]
splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
separators=sentence_endings + word_breaks,
chunk_size=self.max_chunk_size,
chunk_overlap=self.token_overlap,
)
chunks = splitter.split_text(content)
for chunked_content in chunks:
            # Estimate the chunk's token count
chunk_size = self.token_estimator.estimate_tokens(chunked_content)
if chunk_size > self.max_chunk_size:
logging.warning(
f"[TranscriptionChunker] [_chunk_content] [{self.filename}] Truncating {chunk_size} size chunk to fit within {self.max_chunk_size} tokens."
)
                # If the token count exceeds `max_chunk_size`, truncate the chunk until it fits
                chunked_content = self._truncate_chunk(chunked_content)
                # Re-estimate so the yielded size reflects the truncated content
                chunk_size = self.token_estimator.estimate_tokens(chunked_content)
yield chunked_content, chunk_size
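

# ---------------------------------------------------------------------------
# Minimal standalone sketch (not part of the class) of the splitting strategy
# `_chunk_content` uses above. It is illustrative only: it bypasses
# `BaseChunker`, the Azure OpenAI summary step, and the `data` dict, and simply
# shows how `RecursiveCharacterTextSplitter.from_tiktoken_encoder` segments a
# transcript on sentence endings and word breaks with a token overlap. Because
# of the relative import at the top of this file, run it as a module, e.g.
# `python -m <package>.transcription_chunker` (package name depends on the repo
# layout).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    sample = (
        "Alice: Welcome everyone, let's review the quarterly numbers. "
        "Bob: Revenue grew twelve percent, mostly driven by the new plan. "
        "Alice: Great, let's dig into churn next."
    )
    demo_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        separators=[".", "!", "?", " ", "\n", "\t"],
        chunk_size=24,  # deliberately tiny so the short sample actually splits
        chunk_overlap=4,
    )
    for i, chunk in enumerate(demo_splitter.split_text(sample), start=1):
        print(f"chunk {i}: {chunk!r}")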