Skip to content

Commit e1b6ac2

Browse files
committed
fixing black version to be same as workflow
1 parent fd88977 commit e1b6ac2

File tree

9 files changed

+168
-32
lines changed

9 files changed

+168
-32
lines changed

convokit/model/backendMapper.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,12 @@ def get_data(
6666

6767
@abstractmethod
6868
def update_data(
69-
self, component_type: str, component_id: str, property_name: str, new_value, index=None,
69+
self,
70+
component_type: str,
71+
component_id: str,
72+
property_name: str,
73+
new_value,
74+
index=None,
7075
):
7176
"""
7277
Set or update the property data for the component of type component_type
@@ -174,7 +179,12 @@ def get_data(
174179
return collection[component_id][property_name]
175180

176181
def update_data(
177-
self, component_type: str, component_id: str, property_name: str, new_value, index=None,
182+
self,
183+
component_type: str,
184+
component_id: str,
185+
property_name: str,
186+
new_value,
187+
index=None,
178188
):
179189
collection = self.get_collection(component_type)
180190
# don't create new collections if the ID is not found; this is supposed to be handled in the
@@ -282,7 +292,12 @@ def get_data(
282292
return result
283293

284294
def update_data(
285-
self, component_type: str, component_id: str, property_name: str, new_value, index=None,
295+
self,
296+
component_type: str,
297+
component_id: str,
298+
property_name: str,
299+
new_value,
300+
index=None,
286301
):
287302
data = self.get_data(component_type, component_id)
288303
if index is not None and index.get(property_name, None) == ["bin"]:

convokit/model/corpus_helpers.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -272,21 +272,37 @@ def unpack_all_binary_data(
272272
):
273273
# unpack binary data for utterances
274274
unpack_binary_data_for_utts(
275-
utterances, filename, meta_index.utterances_index, exclude_utterance_meta, KeyMeta,
275+
utterances,
276+
filename,
277+
meta_index.utterances_index,
278+
exclude_utterance_meta,
279+
KeyMeta,
276280
)
277281
# unpack binary data for speakers
278282
unpack_binary_data(
279-
filename, speakers_data, meta_index.speakers_index, "speaker", exclude_speaker_meta,
283+
filename,
284+
speakers_data,
285+
meta_index.speakers_index,
286+
"speaker",
287+
exclude_speaker_meta,
280288
)
281289

282290
# unpack binary data for conversations
283291
unpack_binary_data(
284-
filename, convos_data, meta_index.conversations_index, "convo", exclude_conversation_meta,
292+
filename,
293+
convos_data,
294+
meta_index.conversations_index,
295+
"convo",
296+
exclude_conversation_meta,
285297
)
286298

287299
# unpack binary data for overall corpus
288300
unpack_binary_data(
289-
filename, meta, meta_index.overall_index, "overall", exclude_overall_meta,
301+
filename,
302+
meta,
303+
meta_index.overall_index,
304+
"overall",
305+
exclude_overall_meta,
290306
)
291307

292308

convokit/model/speaker.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,12 @@ class Speaker(CorpusComponent):
2424
"""
2525

2626
def __init__(
27-
self, owner=None, id: str = None, utts=None, convos=None, meta: Optional[Dict] = None,
27+
self,
28+
owner=None,
29+
id: str = None,
30+
utts=None,
31+
convos=None,
32+
meta: Optional[Dict] = None,
2833
):
2934
super().__init__(obj_type="speaker", owner=owner, id=id, meta=meta)
3035
self.utterances = utts if utts is not None else dict()

convokit/redirection/config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import torch
44

55
DEFAULT_BNB_CONFIG = BitsAndBytesConfig(
6-
load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16,
6+
load_in_4bit=True,
7+
bnb_4bit_quant_type="nf4",
8+
bnb_4bit_compute_dtype=torch.bfloat16,
79
)
810

911
DEFAULT_LORA_CONFIG = LoraConfig(

convokit/redirection/gemmaLikelihoodModel.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,16 @@ def _calculate_likelihood_prob(self, past_context, future_context):
9898
future_context = "\n\n".join(future_context)
9999

100100
context_ids = self.tokenizer.encode(
101-
past_context, truncation=True, max_length=self.max_length, return_tensors="pt",
101+
past_context,
102+
truncation=True,
103+
max_length=self.max_length,
104+
return_tensors="pt",
102105
)
103106
future_ids = self.tokenizer.encode(
104-
future_context, truncation=True, max_length=self.max_length, return_tensors="pt",
107+
future_context,
108+
truncation=True,
109+
max_length=self.max_length,
110+
return_tensors="pt",
105111
)
106112
input_ids = torch.cat([context_ids, future_ids], dim=1)
107113
if input_ids.shape[1] > self.max_length:

convokit/redirection/preprocessing.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,10 @@ def get_chunk_dataset(tokenizer, convos, max_tokens=512, overlap_tokens=50):
5252
chunks = []
5353
for convo in convos:
5454
convo_chunks = chunk_text_with_overlap(
55-
tokenizer, convo, max_tokens=max_tokens, overlap_tokens=overlap_tokens,
55+
tokenizer,
56+
convo,
57+
max_tokens=max_tokens,
58+
overlap_tokens=overlap_tokens,
5659
)
5760
chunks += convo_chunks
5861

convokit/tests/general/fill_missing_convo_ids/fill_missing_convo_ids_helpers.py

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,27 +7,72 @@ def construct_missing_convo_ids_corpus() -> Corpus:
77
# test broken convo where there are multiple conversation_ids
88
corpus = Corpus(
99
utterances=[
10-
Utterance(id="0", reply_to=None, speaker=Speaker(id="alice"), timestamp=0,),
11-
Utterance(id="1", reply_to="0", speaker=Speaker(id="bob"), timestamp=2,),
12-
Utterance(id="2", reply_to="1", speaker=Speaker(id="charlie"), timestamp=1,),
13-
Utterance(id="3", reply_to=None, speaker=Speaker(id="alice2"), timestamp=0,),
10+
Utterance(
11+
id="0",
12+
reply_to=None,
13+
speaker=Speaker(id="alice"),
14+
timestamp=0,
15+
),
16+
Utterance(
17+
id="1",
18+
reply_to="0",
19+
speaker=Speaker(id="bob"),
20+
timestamp=2,
21+
),
22+
Utterance(
23+
id="2",
24+
reply_to="1",
25+
speaker=Speaker(id="charlie"),
26+
timestamp=1,
27+
),
28+
Utterance(
29+
id="3",
30+
reply_to=None,
31+
speaker=Speaker(id="alice2"),
32+
timestamp=0,
33+
),
1434
]
1535
)
1636
return corpus
1737

1838

1939
def get_new_utterances_without_convo_ids() -> List[Utterance]:
2040
return [
21-
Utterance(id="a", reply_to=None, speaker=Speaker(id="alice"), timestamp=0,),
22-
Utterance(id="b", reply_to="a", speaker=Speaker(id="bob"), timestamp=0,),
23-
Utterance(id="c", reply_to=None, speaker=Speaker(id="bob"), timestamp=0,),
41+
Utterance(
42+
id="a",
43+
reply_to=None,
44+
speaker=Speaker(id="alice"),
45+
timestamp=0,
46+
),
47+
Utterance(
48+
id="b",
49+
reply_to="a",
50+
speaker=Speaker(id="bob"),
51+
timestamp=0,
52+
),
53+
Utterance(
54+
id="c",
55+
reply_to=None,
56+
speaker=Speaker(id="bob"),
57+
timestamp=0,
58+
),
2459
]
2560

2661

2762
def get_new_utterances_without_existing_convo_ids():
2863
# i.e. they belong to existing convos
2964
# one responds to root utt, the other responds to leaf utt
3065
return [
31-
Utterance(id="z", reply_to="0", speaker=Speaker(id="alice"), timestamp=0,),
32-
Utterance(id="zz", reply_to="2", speaker=Speaker(id="charlie"), timestamp=0,),
66+
Utterance(
67+
id="z",
68+
reply_to="0",
69+
speaker=Speaker(id="alice"),
70+
timestamp=0,
71+
),
72+
Utterance(
73+
id="zz",
74+
reply_to="2",
75+
speaker=Speaker(id="charlie"),
76+
timestamp=0,
77+
),
3378
]

convokit/tests/general/merge_corpus/merge_corpus_helpers.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,16 @@ def construct_non_overlapping_corpus():
4343
def construct_overlapping_corpus():
4444
return Corpus(
4545
utterances=[
46-
Utterance(id="2", text="this is a test", speaker=Speaker(id="charlie"),),
47-
Utterance(id="4", text="this is a sentence", speaker=Speaker(id="echo"),),
46+
Utterance(
47+
id="2",
48+
text="this is a test",
49+
speaker=Speaker(id="charlie"),
50+
),
51+
Utterance(
52+
id="4",
53+
text="this is a sentence",
54+
speaker=Speaker(id="echo"),
55+
),
4856
Utterance(id="5", text="goodbye", speaker=Speaker(id="foxtrot")),
4957
]
5058
)

convokit/tests/general/traverse_convo/traverse_convo_helpers.py

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -93,31 +93,67 @@ def construct_tree_corpus():
9393
timestamp=0,
9494
),
9595
Utterance(
96-
id="2", reply_to="0", conversation_id="0", speaker=Speaker(id="alice"), timestamp=2,
96+
id="2",
97+
reply_to="0",
98+
conversation_id="0",
99+
speaker=Speaker(id="alice"),
100+
timestamp=2,
97101
),
98102
Utterance(
99-
id="1", reply_to="0", conversation_id="0", speaker=Speaker(id="alice"), timestamp=1,
103+
id="1",
104+
reply_to="0",
105+
conversation_id="0",
106+
speaker=Speaker(id="alice"),
107+
timestamp=1,
100108
),
101109
Utterance(
102-
id="3", reply_to="0", conversation_id="0", speaker=Speaker(id="alice"), timestamp=3,
110+
id="3",
111+
reply_to="0",
112+
conversation_id="0",
113+
speaker=Speaker(id="alice"),
114+
timestamp=3,
103115
),
104116
Utterance(
105-
id="4", reply_to="1", conversation_id="0", speaker=Speaker(id="alice"), timestamp=4,
117+
id="4",
118+
reply_to="1",
119+
conversation_id="0",
120+
speaker=Speaker(id="alice"),
121+
timestamp=4,
106122
),
107123
Utterance(
108-
id="5", reply_to="1", conversation_id="0", speaker=Speaker(id="alice"), timestamp=5,
124+
id="5",
125+
reply_to="1",
126+
conversation_id="0",
127+
speaker=Speaker(id="alice"),
128+
timestamp=5,
109129
),
110130
Utterance(
111-
id="6", reply_to="1", conversation_id="0", speaker=Speaker(id="alice"), timestamp=6,
131+
id="6",
132+
reply_to="1",
133+
conversation_id="0",
134+
speaker=Speaker(id="alice"),
135+
timestamp=6,
112136
),
113137
Utterance(
114-
id="7", reply_to="2", conversation_id="0", speaker=Speaker(id="alice"), timestamp=4,
138+
id="7",
139+
reply_to="2",
140+
conversation_id="0",
141+
speaker=Speaker(id="alice"),
142+
timestamp=4,
115143
),
116144
Utterance(
117-
id="8", reply_to="2", conversation_id="0", speaker=Speaker(id="alice"), timestamp=5,
145+
id="8",
146+
reply_to="2",
147+
conversation_id="0",
148+
speaker=Speaker(id="alice"),
149+
timestamp=5,
118150
),
119151
Utterance(
120-
id="9", reply_to="3", conversation_id="0", speaker=Speaker(id="alice"), timestamp=4,
152+
id="9",
153+
reply_to="3",
154+
conversation_id="0",
155+
speaker=Speaker(id="alice"),
156+
timestamp=4,
121157
),
122158
Utterance(
123159
id="10",

0 commit comments

Comments
 (0)