Skip to content

Commit a8ad9bb

Browse files
committed
finished version | type errors fixed
1 parent 78e9f6d commit a8ad9bb

File tree

1 file changed

+60
-55
lines changed

1 file changed

+60
-55
lines changed

src/flare_ai_kit/consensus/aggregator/advanced_strategies.py

Lines changed: 60 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
"""Advanced consensus strategies for detecting hallucinations and improving robustness."""
22

33
import numpy as np
4-
from collections import Counter
5-
from typing import List, Tuple, Dict, Any, Optional
4+
from typing import List, Dict, Any, Union
65
from dataclasses import dataclass
76
from sklearn.cluster import DBSCAN, KMeans
8-
from sklearn.metrics.pairwise import cosine_similarity
7+
from sklearn.metrics.pairwise import cosine_similarity # type: ignore
98
from sklearn.preprocessing import StandardScaler
109
import logging
1110

@@ -22,8 +21,8 @@ class ClusterResult:
2221
dominant_cluster: List[Prediction]
2322
outlier_clusters: List[List[Prediction]]
2423
cluster_labels: List[int]
25-
similarity_matrix: np.ndarray
26-
centroid_embeddings: Dict[int, np.ndarray]
24+
similarity_matrix: np.ndarray[Any, np.dtype[np.float64]]
25+
centroid_embeddings: Dict[int, np.ndarray[Any, np.dtype[np.float64]]]
2726

2827

2928
def semantic_clustering_strategy(
@@ -54,11 +53,11 @@ def semantic_clustering_strategy(
5453
# Generate embeddings for all predictions
5554
texts = [str(p.prediction) for p in predictions]
5655
embeddings = embedding_model.embed_content(texts)
57-
embeddings_array = np.array(embeddings)
56+
embeddings_array = np.array(embeddings, dtype=np.float64)
5857

5958
# Normalize embeddings for better clustering
6059
scaler = StandardScaler()
61-
embeddings_normalized = scaler.fit_transform(embeddings_array)
60+
embeddings_normalized = scaler.fit_transform(embeddings_array).astype(np.float64) # type: ignore
6261

6362
# Perform clustering
6463
if clustering_method.lower() == "dbscan":
@@ -74,30 +73,27 @@ def semantic_clustering_strategy(
7473
else:
7574
raise ValueError(f"Unsupported clustering method: {clustering_method}")
7675

77-
cluster_labels = clustering.fit_predict(embeddings_normalized)
76+
cluster_labels: np.ndarray[Any, np.dtype[np.int64]] = clustering.fit_predict(embeddings_normalized) # type: ignore
7877

7978
# Group predictions by cluster
8079
clusters: Dict[int, List[Prediction]] = {}
8180
for i, label in enumerate(cluster_labels):
82-
if label not in clusters:
83-
clusters[label] = []
84-
clusters[label].append(predictions[i])
81+
label_int = int(label)
82+
if label_int not in clusters:
83+
clusters[label_int] = []
84+
clusters[label_int].append(predictions[i])
8585

8686
# Find dominant cluster (largest cluster)
8787
dominant_cluster_label = max(clusters.keys(), key=lambda k: len(clusters[k]))
8888
dominant_cluster = clusters[dominant_cluster_label]
89-
outlier_clusters = [clusters[k] for k in clusters.keys() if k != dominant_cluster_label]
90-
91-
# Calculate centroid for dominant cluster
92-
dominant_embeddings = embeddings_array[cluster_labels == dominant_cluster_label]
93-
centroid = np.mean(dominant_embeddings, axis=0)
9489

9590
# Select best prediction from dominant cluster (highest confidence)
9691
best_prediction = max(dominant_cluster, key=lambda p: p.confidence)
9792

9893
# Calculate consensus confidence based on cluster stability
99-
cluster_similarities = cosine_similarity(dominant_embeddings)
100-
avg_similarity = np.mean(cluster_similarities[np.triu_indices_from(cluster_similarities, k=1)])
94+
dominant_embeddings = embeddings_array[cluster_labels == dominant_cluster_label]
95+
cluster_similarities: np.ndarray[Any, np.dtype[np.float64]] = cosine_similarity(dominant_embeddings) # type: ignore
96+
avg_similarity = float(np.mean(cluster_similarities[np.triu_indices_from(cluster_similarities, k=1)]))
10197

10298
# Adjust confidence based on cluster quality
10399
adjusted_confidence = min(best_prediction.confidence * avg_similarity, 1.0)
@@ -131,22 +127,22 @@ def shapley_value_strategy(
131127
# Generate embeddings
132128
texts = [str(p.prediction) for p in predictions]
133129
embeddings = embedding_model.embed_content(texts)
134-
embeddings_array = np.array(embeddings)
130+
embeddings_array = np.array(embeddings, dtype=np.float64)
135131

136132
# Calculate pairwise similarities
137-
similarity_matrix = cosine_similarity(embeddings_array)
133+
similarity_matrix: np.ndarray[Any, np.dtype[np.float64]] = cosine_similarity(embeddings_array)
138134

139135
# Monte Carlo approximation of Shapley values
140-
shapley_values = np.zeros(len(predictions))
136+
shapley_values: np.ndarray[Any, np.dtype[np.float64]] = np.zeros(len(predictions), dtype=np.float64)
141137

142138
for _ in range(n_samples):
143139
# Random permutation of agents
144-
permutation = np.random.permutation(len(predictions))
140+
permutation: np.ndarray[Any, np.dtype[np.int64]] = np.random.permutation(len(predictions))
145141

146142
# Calculate marginal contributions
147-
current_set = set()
143+
current_set: set[int] = set()
148144
for i, agent_idx in enumerate(permutation):
149-
current_set.add(agent_idx)
145+
current_set.add(int(agent_idx))
150146

151147
# Calculate utility of current set
152148
if len(current_set) == 1:
@@ -155,14 +151,14 @@ def shapley_value_strategy(
155151
# Calculate average similarity within the set
156152
set_indices = list(current_set)
157153
set_similarities = similarity_matrix[np.ix_(set_indices, set_indices)]
158-
utility = np.mean(set_similarities[np.triu_indices_from(set_similarities, k=1)])
154+
utility = float(np.mean(set_similarities[np.triu_indices_from(set_similarities, k=1)]))
159155

160156
# Calculate marginal contribution
161157
if i == 0:
162158
marginal_contribution = utility
163159
else:
164160
# Calculate utility without this agent
165-
prev_set = current_set - {agent_idx}
161+
prev_set = current_set - {int(agent_idx)}
166162
if len(prev_set) == 0:
167163
prev_utility = 0.0
168164
else:
@@ -171,38 +167,41 @@ def shapley_value_strategy(
171167
if len(prev_indices) == 1:
172168
prev_utility = 1.0
173169
else:
174-
prev_utility = np.mean(prev_similarities[np.triu_indices_from(prev_similarities, k=1)])
170+
prev_utility = float(np.mean(prev_similarities[np.triu_indices_from(prev_similarities, k=1)]))
175171

176172
marginal_contribution = utility - prev_utility
177173

178-
shapley_values[agent_idx] += marginal_contribution
174+
shapley_values[int(agent_idx)] += marginal_contribution
179175

180176
# Normalize Shapley values
181177
shapley_values /= n_samples
182178

183179
# Weight predictions by Shapley values
184-
total_weight = np.sum(shapley_values)
180+
total_weight = float(np.sum(shapley_values))
185181
if total_weight == 0:
186182
# Fallback to equal weighting
187-
weights = np.ones(len(predictions)) / len(predictions)
183+
weights: np.ndarray[Any, np.dtype[np.float64]] = np.ones(len(predictions), dtype=np.float64) / len(predictions)
188184
else:
189185
weights = shapley_values / total_weight
190186

191187
# For string predictions, use weighted voting
192188
if isinstance(predictions[0].prediction, str):
193-
vote_counts = Counter()
189+
vote_counts: Dict[str, float] = {}
194190
for pred, weight in zip(predictions, weights):
195-
vote_counts[str(pred.prediction)] += weight
191+
pred_str = str(pred.prediction)
192+
vote_counts[pred_str] = vote_counts.get(pred_str, 0.0) + float(weight)
196193

197-
consensus_prediction = vote_counts.most_common(1)[0][0]
194+
consensus_prediction = max(vote_counts.items(), key=lambda x: x[1])[0]
195+
print(f"Consensus prediction: {consensus_prediction}")
198196
else:
199197
# For numerical predictions, use weighted average
200198
consensus_prediction = sum(
201-
float(p.prediction) * weight for p, weight in zip(predictions, weights)
199+
float(p.prediction) * float(weight) for p, weight in zip(predictions, weights)
202200
)
201+
print(f"Consensus prediction: {consensus_prediction}")
203202

204203
# Calculate weighted confidence
205-
weighted_confidence = sum(p.confidence * weight for p, weight in zip(predictions, weights))
204+
weighted_confidence = sum(p.confidence * float(weight) for p, weight in zip(predictions, weights))
206205

207206
return Prediction(
208207
agent_id="shapley_consensus",
@@ -233,24 +232,24 @@ def entropy_based_strategy(
233232
# Generate embeddings
234233
texts = [str(p.prediction) for p in predictions]
235234
embeddings = embedding_model.embed_content(texts)
236-
embeddings_array = np.array(embeddings)
235+
embeddings_array = np.array(embeddings, dtype=np.float64)
237236

238237
# Calculate pairwise similarities
239-
similarity_matrix = cosine_similarity(embeddings_array)
238+
similarity_matrix: np.ndarray[Any, np.dtype[np.float64]] = cosine_similarity(embeddings_array)
240239

241240
# Calculate entropy of the similarity distribution
242-
similarities = similarity_matrix[np.triu_indices_from(similarity_matrix, k=1)]
241+
similarities: np.ndarray[Any, np.dtype[np.float64]] = similarity_matrix[np.triu_indices_from(similarity_matrix, k=1)]
243242
if len(similarities) > 0:
244243
# Normalize similarities to probabilities
245244
similarities = np.clip(similarities, 0, 1)
246245
similarities = similarities / np.sum(similarities)
247246

248247
# Calculate entropy
249-
entropy = -np.sum(similarities * np.log(similarities + 1e-10))
250-
max_entropy = np.log(len(similarities))
251-
normalized_entropy = entropy / max_entropy if max_entropy > 0 else 0
248+
entropy = float(-np.sum(similarities * np.log(similarities + 1e-10)))
249+
max_entropy = float(np.log(len(similarities)))
250+
normalized_entropy = entropy / max_entropy if max_entropy > 0 else 0.0
252251
else:
253-
normalized_entropy = 0
252+
normalized_entropy = 0.0
254253

255254
# Select prediction based on entropy
256255
if normalized_entropy > entropy_threshold:
@@ -260,20 +259,22 @@ def entropy_based_strategy(
260259
adjusted_confidence = best_prediction.confidence * (1 - normalized_entropy)
261260
else:
262261
# Low entropy: use similarity-weighted consensus
263-
weights = np.mean(similarity_matrix, axis=1)
262+
weights: np.ndarray[Any, np.dtype[np.float64]] = np.mean(similarity_matrix, axis=1)
264263
weights = weights / np.sum(weights)
265264

266265
if isinstance(predictions[0].prediction, str):
267-
vote_counts = Counter()
266+
vote_counts: Dict[str, float] = {}
268267
for pred, weight in zip(predictions, weights):
269-
vote_counts[str(pred.prediction)] += weight
268+
pred_str = str(pred.prediction)
269+
vote_counts[pred_str] = vote_counts.get(pred_str, 0.0) + float(weight)
270270

271-
consensus_prediction = vote_counts.most_common(1)[0][0]
271+
consensus_prediction = max(vote_counts.items(), key=lambda x: x[1])[0]
272+
print(f"Consensus prediction: {consensus_prediction}")
272273
else:
273274
consensus_prediction = sum(
274-
float(p.prediction) * weight for p, weight in zip(predictions, weights)
275+
float(p.prediction) * float(weight) for p, weight in zip(predictions, weights)
275276
)
276-
277+
print(f"Consensus prediction: {consensus_prediction}")
277278
best_prediction = predictions[np.argmax(weights)]
278279
adjusted_confidence = best_prediction.confidence * (1 - normalized_entropy)
279280

@@ -287,7 +288,7 @@ def entropy_based_strategy(
287288
def robust_consensus_strategy(
288289
predictions: List[Prediction],
289290
embedding_model: BaseEmbedding,
290-
strategies: List[str] = None
291+
strategies: Union[List[str], None] = None
291292
) -> Prediction:
292293
"""
293294
Robust consensus strategy that combines multiple approaches.
@@ -303,7 +304,7 @@ def robust_consensus_strategy(
303304
if strategies is None:
304305
strategies = ["semantic", "shapley", "entropy"]
305306

306-
strategy_results = []
307+
strategy_results: List[Prediction] = []
307308

308309
for strategy in strategies:
309310
try:
@@ -325,8 +326,11 @@ def robust_consensus_strategy(
325326
if not strategy_results:
326327
# Fallback to simple majority
327328
if isinstance(predictions[0].prediction, str):
328-
vote_counts = Counter(str(p.prediction) for p in predictions)
329-
consensus_prediction = vote_counts.most_common(1)[0][0]
329+
vote_counts: Dict[str, int] = {}
330+
for p in predictions:
331+
pred_str = str(p.prediction)
332+
vote_counts[pred_str] = vote_counts.get(pred_str, 0) + 1
333+
consensus_prediction = max(vote_counts.items(), key=lambda x: x[1])[0]
330334
else:
331335
consensus_prediction = sum(float(p.prediction) for p in predictions) / len(predictions)
332336

@@ -345,11 +349,12 @@ def robust_consensus_strategy(
345349
weights = [w / total_weight for w in weights]
346350

347351
if isinstance(strategy_results[0].prediction, str):
348-
vote_counts = Counter()
352+
strategy_vote_counts: Dict[str, float] = {}
349353
for result, weight in zip(strategy_results, weights):
350-
vote_counts[str(result.prediction)] += weight
354+
pred_str = str(result.prediction)
355+
strategy_vote_counts[pred_str] = strategy_vote_counts.get(pred_str, 0.0) + weight
351356

352-
consensus_prediction = vote_counts.most_common(1)[0][0]
357+
consensus_prediction = max(strategy_vote_counts.items(), key=lambda x: x[1])[0]
353358
else:
354359
consensus_prediction = sum(
355360
float(r.prediction) * weight for r, weight in zip(strategy_results, weights)

0 commit comments

Comments (0)