|
| 1 | +import re |
1 | 2 | from rich import print |
2 | 3 | from rich.console import Console |
3 | 4 | from rich.markdown import Markdown |
| 5 | +from rich.syntax import Syntax |
4 | 6 |
|
5 | 7 |
|
6 | 8 | def prettify_llm_output(response): |
7 | 9 | """ |
8 | | - Prettifies the output from a language model response by stripping leading |
9 | | - and trailing whitespace and code block markers, then prints it as Markdown |
10 | | - to the console. |
| 10 | + Prettifies the output from a language model response by processing it |
| 11 | + as Markdown with enhanced code block detection and syntax highlighting. |
11 | 12 |
|
12 | 13 | Args: |
13 | 14 | response (str): The raw response from the language model. |
14 | 15 |
|
15 | 16 | Returns: |
16 | 17 | None |
17 | 18 | """ |
18 | | - markdown_output = response.strip().strip("```") |
19 | 19 | console = Console() |
| 20 | + |
| 21 | + # Clean up the response |
| 22 | + markdown_output = response.strip() |
| 23 | + |
| 24 | + # Check if the entire response is just code without markdown formatting |
| 25 | + if _is_likely_code(markdown_output) and not _has_markdown_elements(markdown_output): |
| 26 | + # Treat the entire response as Python code |
| 27 | + syntax = Syntax(markdown_output, "python", theme="monokai", line_numbers=False) |
| 28 | + print() |
| 29 | + console.print(syntax) |
| 30 | + print() |
| 31 | + return |
| 32 | + |
| 33 | + # Check for code blocks and ensure they have language specification |
| 34 | + markdown_output = _enhance_code_blocks(markdown_output) |
| 35 | + |
| 36 | + # Render as markdown |
20 | 37 | md = Markdown(markdown_output) |
21 | 38 | print() |
22 | 39 | console.print(md) |
23 | 40 | print() |
| 41 | + |
| 42 | + |
| 43 | +def _is_likely_code(text): |
| 44 | + """Check if text looks like Python code.""" |
| 45 | + python_keywords = [ |
| 46 | + "import", |
| 47 | + "def", |
| 48 | + "class", |
| 49 | + "if", |
| 50 | + "for", |
| 51 | + "while", |
| 52 | + "try", |
| 53 | + "except", |
| 54 | + "plt.", |
| 55 | + ] |
| 56 | + lines = text.split("\n") |
| 57 | + code_indicators = 0 |
| 58 | + |
| 59 | + for line in lines: |
| 60 | + line = line.strip() |
| 61 | + if any(keyword in line for keyword in python_keywords): |
| 62 | + code_indicators += 1 |
| 63 | + if line.startswith("#"): # Comments |
| 64 | + code_indicators += 1 |
| 65 | + |
| 66 | + return code_indicators >= 2 or any( |
| 67 | + keyword in text for keyword in ["plt.", "matplotlib", "import"] |
| 68 | + ) |
| 69 | + |
| 70 | + |
| 71 | +def _has_markdown_elements(text): |
| 72 | + """Check if text contains markdown elements.""" |
| 73 | + markdown_indicators = ["#", "**", "*", "`", ">", "-", "1.", "[", "]"] |
| 74 | + return any(indicator in text for indicator in markdown_indicators) |
| 75 | + |
| 76 | + |
| 77 | +def _enhance_code_blocks(text): |
| 78 | + """Add language specification to code blocks that don't have it.""" |
| 79 | + # Pattern to find code blocks without language specification |
| 80 | + pattern = r"```\n(.*?)\n```" |
| 81 | + |
| 82 | + def replace_code_block(match): |
| 83 | + code_content = match.group(1) |
| 84 | + if _is_likely_code(code_content): |
| 85 | + return f"```python\n{code_content}\n```" |
| 86 | + return match.group(0) |
| 87 | + |
| 88 | + # Replace code blocks without language specification |
| 89 | + enhanced = re.sub(pattern, replace_code_block, text, flags=re.DOTALL) |
| 90 | + |
| 91 | + return enhanced |
0 commit comments