Skip to content

BERTCGAForcaster Class #274

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 35 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
db403b1
Add Mean H to summarize function
Feb 18, 2025
4291f99
Enable CRAFT to save individual forecasts
Feb 18, 2025
8b87534
Add BERTForCGA
Feb 18, 2025
70146e4
Add Recovery Metric
Feb 18, 2025
4698f51
Configure BERTForecaster
Feb 19, 2025
f367fa3
Initialize Notebook for example
Feb 19, 2025
a278c42
Finalize example notebook for BERTCGAModel
Feb 19, 2025
9e5b3e7
Fix typos
Feb 19, 2025
2f2b168
Add leaderboard into README.md
Feb 19, 2025
41716a9
Add leaderboard into README.md
Feb 19, 2025
344e042
Fix typos
Feb 19, 2025
187c3e0
Add formatted string for Leaderboard
Feb 19, 2025
24268b9
require additional packages
Feb 19, 2025
ba44f57
Fix leaderboard_string
Feb 20, 2025
51719fb
add no-context setting
Feb 21, 2025
93935dd
Example for No-Context Setting
Feb 21, 2025
c286a3c
Fix typos
Feb 21, 2025
906b9a9
Fix typos
Feb 21, 2025
ad16ec4
Add No-Context Description
Feb 24, 2025
ac276f4
Add notebook for reproducing results
Feb 25, 2025
af62968
Add config for downloading BERT/RoBERTa CGA models
Feb 26, 2025
d0ce53f
Innital commit for zero-shot LLM (transform function)
Feb 27, 2025
cb74766
Fix typos
Feb 27, 2025
654a1d7
Finish code notebook for BERT/RoBERTa reproduction
Feb 27, 2025
170694d
Move BERT/RoBERTa to models section
Feb 27, 2025
4dc46f8
First commit LLMCGAModel.fit()
Feb 27, 2025
8a296c8
Clean Local Paths
Feb 28, 2025
908156d
Remove LLMCGAModel for pull request
Feb 28, 2025
899fb61
Clean forecaster.init
sonqt Feb 28, 2025
4d76e31
Clean forecaster.init
sonqt Feb 28, 2025
2b54b71
Clean BERTCGAModel
sonqt Feb 28, 2025
f8e18b4
Clean trailling spaces
sonqt Feb 28, 2025
fd88977
black formatting
seanzhangkx8 Feb 28, 2025
e1b6ac2
fixing black version to be same as workflow
seanzhangkx8 Feb 28, 2025
6e9be3f
run black ./
seanzhangkx8 Feb 28, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions convokit/coordination/coordination.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,9 +486,10 @@ def _scores_over_utterances(
target = utt1.speaker
if speaker == target:
continue
speaker, target = Coordination._annot_speaker(
speaker, utt2, split_by_attribs
), Coordination._annot_speaker(target, utt1, split_by_attribs)
speaker, target = (
Coordination._annot_speaker(speaker, utt2, split_by_attribs),
Coordination._annot_speaker(target, utt1, split_by_attribs),
)

speaker_filter = speaker_utterance_selector(utt2, utt1)
target_filter = target_utterance_selector(utt2, utt1)
Expand Down
252 changes: 252 additions & 0 deletions convokit/forecaster/BERTCGAModel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
import os
import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
import json
from tqdm import tqdm
from sklearn.metrics import roc_curve
from datasets import Dataset, DatasetDict
from transformers import (
AutoConfig,
AutoModelForSequenceClassification,
AutoTokenizer,
TrainingArguments,
Trainer,
)
from .forecasterModel import ForecasterModel
import shutil


os.environ["TOKENIZERS_PARALLELISM"] = "false"
DEFAULT_CONFIG = {
"output_dir": "BERTCGAModel",
"per_device_batch_size": 4,
"num_train_epochs": 2,
"learning_rate": 6.7e-6,
"random_seed": 1,
"device": "cuda",
}


class BERTCGAModel(ForecasterModel):
"""
Wrapper for Huggingface Transformers AutoModel
"""

def __init__(self, model_name_or_path, config=DEFAULT_CONFIG):
super().__init__()
try:
self.tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path,
model_max_length=512,
truncation_side="left",
padding_side="right",
)
except:
# The checkpoint didn't save tokenizer
model_config_file = os.path.join(model_name_or_path, "config.json")
with open(model_config_file, "r") as file:
original_model = json.load(file)["_name_or_path"]
self.tokenizer = AutoTokenizer.from_pretrained(
original_model, model_max_length=512, truncation_side="left", padding_side="right"
)
self.best_threshold = None
model_config = AutoConfig.from_pretrained(
model_name_or_path, num_labels=2, problem_type="single_label_classification"
)
self.model = AutoModelForSequenceClassification.from_pretrained(
model_name_or_path, ignore_mismatched_sizes=True, config=model_config
).to(config["device"])
if not os.path.exists(config["output_dir"]):
os.makedirs(config["output_dir"])
self.config = config
return

def _tokenize(self, context):
tokenized_context = self.tokenizer.encode_plus(
text=f" {self.tokenizer.sep_token} ".join([u.text for u in context]),
add_special_tokens=True,
padding="max_length",
truncation=True,
max_length=512,
)
return tokenized_context

def _context_to_bert_data(self, contexts):
pairs = {"id": [], "input_ids": [], "attention_mask": [], "labels": []}
for context in contexts:
convo = context.current_utterance.get_conversation()
label = self.labeler(convo)

if ("context_mode" not in self.config) or self.config["context_mode"] == "normal":
context_utts = context.context
elif self.config["context_mode"] == "no-context":
context_utts = [context.current_utterance]
tokenized_context = self._tokenize(context_utts)
pairs["input_ids"].append(tokenized_context["input_ids"])
pairs["attention_mask"].append(tokenized_context["attention_mask"])
pairs["labels"].append(label)
pairs["id"].append(context.current_utterance.id)
return Dataset.from_dict(pairs)

@torch.inference_mode
@torch.no_grad
def _predict(
self,
dataset,
model=None,
threshold=0.5,
forecast_prob_attribute_name=None,
forecast_attribute_name=None,
):
"""
Return predictions in DataFrame
"""
if not forecast_prob_attribute_name:
forecast_prob_attribute_name = "score"
if not forecast_attribute_name:
forecast_attribute_name = "pred"
if not model:
model = self.model.to(self.config["device"])
utt_ids = []
preds = []
scores = []
for data in tqdm(dataset):
input_ids = (
data["input_ids"].to(self.config["device"], dtype=torch.long).reshape([1, -1])
)
attention_mask = (
data["attention_mask"].to(self.config["device"], dtype=torch.long).reshape([1, -1])
)
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
probs = F.softmax(outputs.logits, dim=-1)
utt_ids.append(data["id"])
raw_score = probs[0, 1].item()
preds.append(int(raw_score > threshold))
scores.append(raw_score)

return pd.DataFrame(
{forecast_attribute_name: preds, forecast_prob_attribute_name: scores}, index=utt_ids
)

def _tune_best_val_accuracy(self, val_dataset, val_contexts):
"""
Save the tuned model to self.best_threshold and self.model
"""
checkpoints = os.listdir(self.config["output_dir"])
best_val_accuracy = 0
val_convo_ids = set()
utt2convo = {}
val_labels_dict = {}
for context in val_contexts:
convo_id = context.conversation_id
utt_id = context.current_utterance.id
label = self.labeler(context.current_utterance.get_conversation())
utt2convo[utt_id] = convo_id
val_labels_dict[convo_id] = label
val_convo_ids.add(convo_id)
val_convo_ids = list(val_convo_ids)
for cp in checkpoints:
full_model_path = os.path.join(self.config["output_dir"], cp)
finetuned_model = AutoModelForSequenceClassification.from_pretrained(
full_model_path
).to(self.config["device"])
val_scores = self._predict(val_dataset, model=finetuned_model)
# for each CONVERSATION, whether or not it triggers will be effectively determined by what the highest score it ever got was
highest_convo_scores = {convo_id: -1 for convo_id in val_convo_ids}
count_correct = 0
for utt_id in val_scores.index:
count_correct += 1
convo_id = utt2convo[utt_id]
utt_score = val_scores.loc[utt_id].score
if utt_score > highest_convo_scores[convo_id]:
highest_convo_scores[convo_id] = utt_score

val_labels = np.asarray([int(val_labels_dict[c]) for c in val_convo_ids])
val_scores = np.asarray([highest_convo_scores[c] for c in val_convo_ids])
# use scikit learn to find candidate threshold cutoffs
_, _, thresholds = roc_curve(val_labels, val_scores)

def acc_with_threshold(y_true, y_score, thresh):
y_pred = (y_score > thresh).astype(int)
return (y_pred == y_true).mean()

accs = [acc_with_threshold(val_labels, val_scores, t) for t in thresholds]
best_acc_idx = np.argmax(accs)

print("Accuracy:", cp, accs[best_acc_idx])
if accs[best_acc_idx] > best_val_accuracy:
best_checkpoint = cp
best_val_accuracy = accs[best_acc_idx]
self.best_threshold = thresholds[best_acc_idx]
self.model = finetuned_model

eval_forecasts_df = self._predict(val_dataset, threshold=self.best_threshold)
eval_prediction_file = os.path.join(self.config["output_dir"], "val_predictions.csv")
eval_forecasts_df.to_csv(eval_prediction_file)

# Save the best config
best_config = {}
best_config["best_checkpoint"] = best_checkpoint
best_config["best_threshold"] = self.best_threshold
best_config["best_val_accuracy"] = best_val_accuracy
config_file = os.path.join(self.config["output_dir"], "dev_config.json")
with open(config_file, "w") as outfile:
json_object = json.dumps(best_config, indent=4)
outfile.write(json_object)

# Clean other checkpoints to save disk space.
for root, _, _ in os.walk(self.config["output_dir"]):
if ("checkpoint" in root) and (best_checkpoint not in root):
print("Deleting:", root)
shutil.rmtree(root)
return

def fit(self, contexts, val_contexts):
"""
Description: Train the conversational forecasting model on the given data
Parameters:
contexts: an iterator over context tuples, as defined by the above data format
val_contexts: an optional second iterator over context tuples to be used as a separate held-out validation set.
The generator for this must be the same as test generator
"""
val_contexts = list(val_contexts)
train_pairs = self._context_to_bert_data(contexts)
val_for_tuning_pairs = self._context_to_bert_data(val_contexts)
dataset = DatasetDict({"train": train_pairs, "val_for_tuning": val_for_tuning_pairs})
dataset.set_format("torch")

training_args = TrainingArguments(
output_dir=self.config["output_dir"],
per_device_train_batch_size=self.config["per_device_batch_size"],
num_train_epochs=self.config["num_train_epochs"],
learning_rate=self.config["learning_rate"],
logging_strategy="epoch",
weight_decay=0.01,
eval_strategy="no",
save_strategy="epoch",
prediction_loss_only=False,
seed=self.config["random_seed"],
)
trainer = Trainer(model=self.model, args=training_args, train_dataset=dataset["train"])
trainer.train()

self._tune_best_val_accuracy(dataset["val_for_tuning"], val_contexts)
return

def transform(self, contexts, forecast_attribute_name, forecast_prob_attribute_name):
test_pairs = self._context_to_bert_data(contexts)
dataset = DatasetDict({"test": test_pairs})
dataset.set_format("torch")
forecasts_df = self._predict(
dataset["test"],
threshold=self.best_threshold,
forecast_attribute_name=forecast_attribute_name,
forecast_prob_attribute_name=forecast_prob_attribute_name,
)

prediction_file = os.path.join(self.config["output_dir"], "test_predictions.csv")
forecasts_df.to_csv(prediction_file)

return forecasts_df
13 changes: 13 additions & 0 deletions convokit/forecaster/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
**Table 1: Forecasting derailment on \newcmv conversations.**
The performance is measured in accuracy (Acc), precision (P), recall (R), F1, false positive rate (FPR), mean horizon (Mean H), and Forecast Recovery (Recovery) along with the correct and incorrect adjustment rates. The best performance across each metric is indicated in **bold**.
| Model | Acc ↑ | P ↑ | R ↑ | F1 ↑ | FPR ↓ | Mean H ↑ | Recovery ↑ (CA/N - IA/N) |
|--------------------------------|--------|-------|-------|-------|--------|---------|-------------------------|
| Human (84 convos) round-1 | 62.2 | 67.8 | 48.9 | 54.6 | 24.4 | 3.64 | - |
| Human (84 convos) round-2 | 70.0 | 75.9 | 55.6 | 63.9 | 15.6 | 3.13 | - |
| RoBERTa-large | **68.4** | 67.5 | 71.1 | 69.2 | 34.3 | 4.14 | +1.1 (7.2 - 6.1) |
| Gemma-2 27B-IT (finetuned) | **68.4** | 66.2 | 75.2 | **70.4** | 38.5 | 4.30 | +0.0 (10.7 - 10.7) |
| GPT-4o (12/2024; zero-shot) | 66.6 | **71.0** | 56.3 | 62.8 | **23.0** | 3.78 | -1.5 (5.9 - 7.4) |
| BERT-base | 65.2 | 63.5 | 72.0 | 67.4 | 41.6 | 4.45 | +2.1 (9.8 - 7.7) |
| CRAFT | 62.8 | 59.4 | 81.1 | 68.5 | 55.5 | 4.69 | +4.9 (12.0 - 7.1) |
| Gemma-2 27B-IT (zero-shot) | 59.4 | 55.7 | **92.2** | 69.4 | 73.5 | **5.27** | **+7.1** (12.2 - 5.1) |

1 change: 1 addition & 0 deletions convokit/forecaster/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@

if "torch" in sys.modules:
from .CRAFTModel import *
from .BERTCGAModel import *
from .CRAFT import *
46 changes: 40 additions & 6 deletions convokit/forecaster/forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def summarize(
"label": [],
"score": [],
"forecast": [],
"last_utterance_forecast": [],
}
for convo in corpus.iter_conversations():
if selector(convo):
Expand All @@ -237,6 +238,7 @@ def summarize(
)
conversational_forecasts_df["score"].append(np.max(forecast_scores))
conversational_forecasts_df["forecast"].append(np.max(forecasts))
conversational_forecasts_df["last_utterance_forecast"].append(forecasts[-1])
conversational_forecasts_df = pd.DataFrame(conversational_forecasts_df).set_index(
"conversation_id"
)
Expand All @@ -260,21 +262,53 @@ def summarize(
(conversational_forecasts_df["label"] == 1)
& (conversational_forecasts_df["forecast"] == 0)
).sum()
# Correct Adjustments
ca = (
(conversational_forecasts_df["label"] == 0)
& (conversational_forecasts_df["forecast"] == 1)
& (conversational_forecasts_df["last_utterance_forecast"] == 0)
).mean()
# Incorrect Adjustments
ia = (
(conversational_forecasts_df["label"] == 1)
& (conversational_forecasts_df["forecast"] == 1)
& (conversational_forecasts_df["last_utterance_forecast"] == 0)
).mean()

p = tp / (tp + fp)
r = tp / (tp + fn)
fpr = fp / (fp + tn)
f1 = 2 / (((tp + fp) / tp) + ((tp + fn) / tp))
metrics = {"Accuracy": acc, "Precision": p, "Recall": r, "FPR": fpr, "F1": f1}

print(pd.Series(metrics))

comments_until_end = self._draw_horizon_plot(corpus, selector)
comments_until_end_vals = list(comments_until_end.values())
mean_h = np.mean(comments_until_end_vals) - 1
print(
"Horizon statistics (# of comments between first positive forecast and conversation end):"
)
print(
f"Mean = {np.mean(comments_until_end_vals)}, Median = {np.median(comments_until_end_vals)}"
)
print(f"Mean = {mean_h}, Median = {np.median(comments_until_end_vals) - 1}")

leaderboard_string = (
f"| MODEL_NAME | "
f"{acc*100:.1f} | "
f"{p*100:.1f} | "
f"{r*100:.1f} | "
f"{f1*100:.1f} | "
f"{fpr*100:.1f} | "
f"{mean_h:.2f} | "
f"{(ca-ia)*100:.1f} ({ca*100:.1f} - {ia*100:.1f}) |"
)
metrics = {
"Accuracy": acc,
"Precision": p,
"Recall": r,
"FPR": fpr,
"F1": f1,
"Mean H": mean_h,
"Correct Adjustment": ca,
"Incorrect Adjustment": ia,
"Recovery": ca - ia,
"Leaderboard String": leaderboard_string,
}
print(pd.Series(metrics))
return conversational_forecasts_df, metrics
8 changes: 7 additions & 1 deletion download_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,12 @@
"https://zissou.infosci.cornell.edu/convokit/models/craft_cmv/craft_full.tar",
"https://zissou.infosci.cornell.edu/convokit/models/craft_cmv/index2word.json",
"https://zissou.infosci.cornell.edu/convokit/models/craft_cmv/word2index.json"
]
],
"cga-cmv-large/roberta-large": ["https://zissou.infosci.cornell.edu/convokit/models/forecaster_models/cga-cmv-large/roberta-large.tar"],
"cga-cmv-large/bert-base-cased": ["https://zissou.infosci.cornell.edu/convokit/models/forecaster_models/cga-cmv-large/bert-base-cased.tar"],
"cga-cmv-legacy/roberta-large": ["https://zissou.infosci.cornell.edu/convokit/models/forecaster_models/cga-cmv-legacy/roberta-large.tar"],
"cga-cmv-legacy/bert-base-cased": ["https://zissou.infosci.cornell.edu/convokit/models/forecaster_models/cga-cmv-legacy/bert-base-cased.tar"],
"cga-wikiconv/roberta-large": ["https://zissou.infosci.cornell.edu/convokit/models/forecaster_models/cga-wikiconv/roberta-large.tar"],
"cga-wikiconv/bert-base-cased": ["https://zissou.infosci.cornell.edu/convokit/models/forecaster_models/cga-wikiconv/bert-base-cased.tar"]
}
}
Loading