Skip to content

Commit caa3981

Browse files
committed
Bug hunting for risk-score extraction from reasoning models
1 parent 6de2344 commit caa3981

File tree

4 files changed

+489
-115
lines changed

4 files changed

+489
-115
lines changed

folktexts/classifier/transformers_classifier.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from __future__ import annotations
55

6+
import logging
67
from pathlib import Path
78
from typing import Callable
89

@@ -138,11 +139,24 @@ def _query_prompt_risk_estimates_batch(
138139
enable_thinking=question.enable_thinking,
139140
)
140141

141-
# Extract probability from generated text
142-
risk_estimates_batch = [
143-
question.get_answer_from_model_output(generated_text)
144-
for generated_text in generated_texts
145-
]
142+
# Extract probability from generated text and log each sample
143+
risk_estimates_batch = []
144+
for idx, (prompt, generated_text) in enumerate(zip(prompts_batch, generated_texts)):
145+
risk_estimate = question.get_answer_from_model_output(generated_text)
146+
risk_estimates_batch.append(risk_estimate)
147+
148+
# Log prompt, generated answer, and extracted risk score at INFO level
149+
logging.info(
150+
f"\n{'='*60}\n"
151+
f"[ReasoningQA Sample {idx + 1}/{len(prompts_batch)}]\n"
152+
f"{'='*60}\n"
153+
f"PROMPT:\n{prompt}\n"
154+
f"{'-'*60}\n"
155+
f"GENERATED ANSWER:\n{generated_text}\n"
156+
f"{'-'*60}\n"
157+
f"EXTRACTED RISK SCORE: {risk_estimate:.4f}\n"
158+
f"{'='*60}"
159+
)
146160

147161
return risk_estimates_batch
148162

folktexts/llm_utils.py

Lines changed: 59 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -173,10 +173,11 @@ def generate_text_batch(
173173
The maximum context size for input tokens. If None, no truncation
174174
is applied to inputs.
175175
enable_thinking : bool, optional
176-
Whether to enable thinking mode for models that support it (e.g., Qwen3).
177-
When True, uses `tokenizer.apply_chat_template` with `enable_thinking=True`.
178-
When False, explicitly disables thinking mode. When None (default),
179-
does not apply chat template formatting.
176+
Controls chat template application and thinking mode:
177+
- None: Do not apply chat template (use raw prompts, for base models)
178+
- False: Apply chat template WITHOUT thinking mode (for instruction-tuned models)
179+
- True: Apply chat template WITH thinking mode, and extract response
180+
content after </think> marker (for thinking models like Qwen3)
180181
181182
Returns
182183
-------
@@ -192,7 +193,10 @@ def generate_text_batch(
192193
tokenizer.padding_side = "left"
193194

194195
try:
195-
# Apply chat template if enable_thinking is specified
196+
# Apply chat template when enable_thinking is not None
197+
# - enable_thinking=True: apply with thinking enabled
198+
# - enable_thinking=False: apply without thinking (standard chat format)
199+
# - enable_thinking=None: skip chat template (raw prompts for base models)
196200
if enable_thinking is not None:
197201
processed_inputs = []
198202
for text in text_inputs:
@@ -208,17 +212,20 @@ def generate_text_batch(
208212
processed_inputs.append(formatted_text)
209213
except TypeError:
210214
# Tokenizer doesn't support enable_thinking parameter
211-
logging.warning(
212-
"Tokenizer does not support 'enable_thinking' parameter. "
213-
"Falling back to standard chat template."
214-
)
215+
# This is expected for non-Qwen models
216+
if enable_thinking:
217+
logging.warning(
218+
"Tokenizer does not support 'enable_thinking' parameter. "
219+
"Falling back to standard chat template."
220+
)
215221
formatted_text = tokenizer.apply_chat_template(
216222
messages,
217223
tokenize=False,
218224
add_generation_prompt=True,
219225
)
220226
processed_inputs.append(formatted_text)
221227
text_inputs = processed_inputs
228+
logging.debug(f"Applied chat template (enable_thinking={enable_thinking})")
222229

223230
# Tokenize inputs with left-padding for generation
224231
tokenized = tokenizer(
@@ -249,40 +256,55 @@ def generate_text_batch(
249256
generated_texts = []
250257
for i, output in enumerate(outputs):
251258
# Extract only the newly generated tokens (after the padded input)
252-
generated_tokens = output[input_seq_length:].tolist()
259+
generated_tokens = output[input_seq_length:]
253260

254-
# If thinking mode was enabled, separate thinking content from response content
255-
# The </think> token (ID 151668) marks the end of thinking content
256-
if enable_thinking:
257-
thinking_end_token_id = 151668 # </think> token ID for Qwen models
258-
try:
259-
# Find the </think> token from the end (in case there are multiple)
260-
index = len(generated_tokens) - generated_tokens[::-1].index(thinking_end_token_id)
261-
# Only decode content after </think>
262-
content_tokens = generated_tokens[index:]
263-
thinking_tokens = generated_tokens[:index]
261+
# Decode the full generated text
262+
full_generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
264263

265-
thinking_content = tokenizer.decode(thinking_tokens, skip_special_tokens=True).strip("\n")
266-
content = tokenizer.decode(content_tokens, skip_special_tokens=True).strip("\n")
264+
# If thinking mode was enabled, separate thinking content from response
265+
if enable_thinking is True:
266+
# Use string-based detection for </think> separator
267+
# This is more robust than relying on hardcoded token IDs
268+
think_end_marker = "</think>"
267269

268-
# Log all decoded tokens at debug level
270+
if think_end_marker in full_generated_text:
271+
# Split on </think> and take only the response content
272+
# Thinking content is logged but IGNORED for probability extraction
273+
parts = full_generated_text.split(think_end_marker, 1)
274+
thinking_content = parts[0].strip()
275+
response_content = parts[1].strip() if len(parts) > 1 else ""
276+
277+
# Log thinking content for debugging (but don't use it for extraction)
269278
logging.debug(f"=== Generated output {i+1}/{len(outputs)} ===")
270-
logging.debug(f"Thinking content ({len(thinking_content)} chars):\n{thinking_content}")
271-
logging.debug(f"Response content ({len(content)} chars):\n{content}")
272-
273-
generated_texts.append(content)
274-
except ValueError:
275-
# </think> token not found - decode entire output
276-
logging.warning("</think> token not found in output. Using full generated text.")
277-
generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
278-
logging.debug(f"=== Generated output {i+1}/{len(outputs)} (no thinking separation) ===")
279-
logging.debug(f"Full content ({len(generated_text)} chars):\n{generated_text}")
280-
generated_texts.append(generated_text)
279+
logging.debug(f"Thinking content ({len(thinking_content)} chars) [IGNORED for extraction]:")
280+
logging.debug(f"{thinking_content[:500]}..." if len(thinking_content) > 500 else thinking_content)
281+
logging.debug(f"Response content ({len(response_content)} chars) [USED for extraction]:")
282+
logging.debug(response_content)
283+
284+
# Always use response content only - thinking content is ignored
285+
if response_content:
286+
generated_texts.append(response_content)
287+
else:
288+
# Response content is empty - this is a problem
289+
logging.warning(
290+
"Response content after </think> is empty. "
291+
"Model may not have generated a proper response. "
292+
"Probability extraction will likely fail."
293+
)
294+
generated_texts.append("")
295+
else:
296+
# </think> marker not found - use full text
297+
# This can happen if the model doesn't actually use thinking format
298+
logging.warning(
299+
f"</think> marker not found in output (thinking mode was enabled). "
300+
f"Using full generated text ({len(full_generated_text)} chars)."
301+
)
302+
generated_texts.append(full_generated_text.strip())
281303
else:
282-
generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
304+
# Non-thinking mode: use full generated text
283305
logging.debug(f"=== Generated output {i+1}/{len(outputs)} ===")
284-
logging.debug(f"Content ({len(generated_text)} chars):\n{generated_text}")
285-
generated_texts.append(generated_text)
306+
logging.debug(f"Content ({len(full_generated_text)} chars):\n{full_generated_text[:500]}...")
307+
generated_texts.append(full_generated_text.strip())
286308

287309
return generated_texts
288310

folktexts/qa_interface.py

Lines changed: 53 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -427,18 +427,18 @@ def get_answer_from_model_output(
427427

428428
# Regex patterns for extracting probability from generated text
429429
# Matches formats like: "Probability: 80%", "Probability: 0.80", "probability: 80 percent"
430+
# Patterns are ordered by specificity - more specific patterns first
430431
_PROBABILITY_PATTERNS = [
431-
# Match "Probability: X%" or "probability: X%"
432-
r"[Pp]robability:\s*(\d+(?:\.\d+)?)\s*%",
433-
# Match "Probability: 0.XX" or "probability: 0.XX"
434-
r"[Pp]robability:\s*(0?\.\d+)",
435-
# Match "Probability: X" where X is a whole number (interpreted as percentage)
436-
r"[Pp]robability:\s*(\d+)\s*(?:percent|%)",
437-
# Match standalone percentage at end of text: "... 80%" or "...0.80"
438-
r"(\d+(?:\.\d+)?)\s*%\s*$",
439-
r"(0?\.\d+)\s*$",
432+
# Match "Probability: X%" or "probability: X%" (with optional "is", "of", etc.)
433+
r"[Pp]robability(?:\s+(?:is|of|estimate)?)?[:\s]+(\d+(?:\.\d+)?)\s*%",
434+
# Match "Probability: 0.XX" or "probability: 0.XX" or "Probability: 1.0"
435+
r"[Pp]robability(?:\s+(?:is|of|estimate)?)?[:\s]+(\d*\.?\d+)(?![%\d])",
436+
# Match "X%" anywhere in text (prefer later matches in fallback)
437+
r"(\d+(?:\.\d+)?)\s*%",
440438
# Match "X percent" pattern
441439
r"(\d+(?:\.\d+)?)\s+percent",
440+
# Match standalone decimal (0.XX or .XX) that looks like probability
441+
r"(?<![.\d])(0?\.\d+)(?![.\d])",
442442
]
443443

444444

@@ -514,11 +514,25 @@ def extract_probability_from_text(generated_text: str) -> float | None:
514514
probability : float | None
515515
The extracted probability as a float between 0 and 1, or None if
516516
no valid probability was found.
517+
518+
Notes
519+
-----
520+
The extraction prioritizes:
521+
1. Explicit "Probability: X%" format (most reliable)
522+
2. Last percentage or probability value in text (likely the conclusion)
523+
3. Fallback to any decimal that looks like a probability
517524
"""
518-
# Try each pattern in order of specificity
519-
for pattern in _PROBABILITY_PATTERNS:
520-
match = re.search(pattern, generated_text)
521-
if match:
525+
# First, try to find explicit "Probability: X" format (most reliable)
526+
explicit_patterns = [
527+
r"[Pp]robability(?:\s+(?:is|of|estimate)?)?[:\s]+(\d+(?:\.\d+)?)\s*%",
528+
r"[Pp]robability(?:\s+(?:is|of|estimate)?)?[:\s]+(\d*\.?\d+)(?![%\d])",
529+
]
530+
531+
for pattern in explicit_patterns:
532+
# Find ALL matches and use the LAST one (likely the final answer)
533+
matches = list(re.finditer(pattern, generated_text))
534+
if matches:
535+
match = matches[-1] # Use last match
522536
value = float(match.group(1))
523537

524538
# Convert percentage to probability if > 1
@@ -532,24 +546,38 @@ def extract_probability_from_text(generated_text: str) -> float | None:
532546
else:
533547
logging.warning(f"Extracted value {value} is out of range [0, 1]")
534548

535-
# Fallback: try to find any number that could be a probability
536-
# Look for decimal numbers between 0 and 1
537-
decimal_matches = re.findall(r"0?\.\d+", generated_text)
538-
for match in reversed(decimal_matches): # Prefer later matches (likely the conclusion)
539-
value = float(match)
549+
# Second, look for percentage patterns (prefer last occurrence)
550+
percent_matches = re.findall(r"(\d+(?:\.\d+)?)\s*%", generated_text)
551+
if percent_matches:
552+
# Use the last percentage found (likely the final answer)
553+
value = float(percent_matches[-1]) / 100.0
540554
if 0 <= value <= 1:
541-
logging.warning(f"Using fallback decimal extraction: {value:.2%}")
555+
logging.debug(f"Using fallback percentage extraction: {value:.2%}")
542556
return value
543557

544-
# Look for percentages
545-
percent_matches = re.findall(r"(\d+(?:\.\d+)?)\s*%", generated_text)
546-
for match in reversed(percent_matches):
547-
value = float(match) / 100.0
558+
# Third, look for "X percent" pattern
559+
percent_word_matches = re.findall(r"(\d+(?:\.\d+)?)\s+percent", generated_text, re.IGNORECASE)
560+
if percent_word_matches:
561+
value = float(percent_word_matches[-1]) / 100.0
548562
if 0 <= value <= 1:
549-
logging.warning(f"Using fallback percentage extraction: {value:.2%}")
563+
logging.debug(f"Using fallback 'X percent' extraction: {value:.2%}")
550564
return value
551565

552-
logging.error(f"Could not extract probability from text: {generated_text[:200]}...")
566+
# Fourth, try to find decimal numbers between 0 and 1
567+
decimal_matches = re.findall(r"(?<![.\d])(0?\.\d+)(?![.\d])", generated_text)
568+
if decimal_matches:
569+
# Use the last decimal found
570+
value = float(decimal_matches[-1])
571+
if 0 <= value <= 1:
572+
logging.warning(f"Using fallback decimal extraction: {value:.2%}")
573+
return value
574+
575+
# Log a detailed error message for debugging
576+
if len(generated_text) > 500:
577+
snippet = generated_text[:250] + "..." + generated_text[-250:]
578+
else:
579+
snippet = generated_text
580+
logging.error(f"Could not extract probability from text:\n{snippet}")
553581
return None
554582

555583
def get_answer_from_model_output(

0 commit comments

Comments (0)