From 4de8cf3302a0d1c5f6a1bf230ca79838c72834ea Mon Sep 17 00:00:00 2001 From: Bing Wen Tan Date: Sun, 16 Feb 2025 21:36:26 +0800 Subject: [PATCH 1/2] bring isVideo and isaccessblocked upfront --- agents/gemini_agent.py | 43 +++++-- agents/openai_agent.py | 32 ++++- handlers/agent_generation.py | 4 +- prompts/agent.py | 10 +- prompts/preprocess_inputs.py | 71 +++++++++++ requirements.txt | 3 +- tests/tools/test_preprocess_inputs.py | 113 +++++++++++++++++ tools/__init__.py | 10 ++ tools/preprocess_inputs.py | 171 ++++++++++++++++++++++++++ tools/review_report.py | 20 +-- tools/summarise_report.py | 1 - utils/gemini_utils.py | 25 ++++ 12 files changed, 468 insertions(+), 35 deletions(-) create mode 100644 prompts/preprocess_inputs.py create mode 100644 tests/tools/test_preprocess_inputs.py create mode 100644 tools/preprocess_inputs.py diff --git a/agents/gemini_agent.py b/agents/gemini_agent.py index a22bdd0..3a245f5 100644 --- a/agents/gemini_agent.py +++ b/agents/gemini_agent.py @@ -3,10 +3,16 @@ from .abstract import FactCheckingAgentBaseClass from typing import Union, List from google.genai import types -from utils.gemini_utils import get_image_part, generate_image_parts, generate_text_parts +from utils.gemini_utils import ( + get_image_part, + generate_image_parts, + generate_text_parts, + generate_screenshot_parts, +) import asyncio import time -from tools import summarise_report_factory +from tools import summarise_report_factory, preprocess_inputs +from tools.preprocess_inputs import get_gemini_content import json from logger import StructuredLogger from langfuse.decorators import observe, langfuse_context @@ -249,10 +255,7 @@ async def generate_report(self, starting_parts): remaining_searches=self.remaining_searches, remaining_screenshots=self.remaining_screenshots, ) - if first_step: - available_functions = ["infer_intent"] - think = False - elif think and self.include_planning_step: + if think and self.include_planning_step: available_functions = ["plan_next_step"] else: banned_functions = ["plan_next_step", "infer_intent"] @@ -308,7 +311,6 @@ async def generate_report(self, starting_parts): ) if len(function_call_promises) == 0: think = not think - first_step = False continue function_results = await asyncio.gather(*function_call_promises) response_parts = GeminiAgent.flatten_and_organise(function_results) @@ -329,7 +331,6 @@ async def generate_report(self, starting_parts): return return_dict messages.append(types.Content(parts=response_parts, role="user")) think = not think - first_step = False logger.error("Report couldn't be generated after 50 turns") return { "error": "Report couldn't be generated after 50 turns", @@ -380,6 +381,25 @@ async def generate_note( } start_time = time.time() # Start the timer cost_tracker = {"total_cost": 0, "cost_trace": []} # To store the cost details + + preprocessed_response = await preprocess_inputs( + image_url=image_url, caption=caption, text=text + ) + if not preprocessed_response.get("success"): + child_logger.error("Error in preprocessing inputs") + return { + "success": False, + "error": "Error in preprocessing inputs", + } + else: + child_logger.info("Preprocessing inputs successful") + screenshots_results = preprocessed_response.get("screenshots", []) + screenshots_content = get_gemini_content(screenshots_results) + results = preprocessed_response.get("result", {}) + is_access_blocked = results.get("isAccessBlocked", False) + is_video = results.get("isVideo", False) + intent = results.get("intent", "An error occurred, figure it out yourself") + if text is not None: child_logger.info(f"Generating text parts for text: {text}") parts = generate_text_parts(text) @@ -387,6 +407,11 @@ async def generate_note( elif image_url is not None: parts = generate_image_parts(image_url, caption) + parts.append(types.Part.from_text(f"User's likely intent: {intent}")) + + if screenshots_content: + parts.extend(screenshots_content) + report_dict = await self.generate_report(parts.copy()) duration = time.time() - start_time # Calculate duration @@ -399,6 +424,8 @@ async def generate_note( ) if summary_results.get("success"): report_dict["community_note"] = summary_results["community_note"] + report_dict["is_access_blocked"] = is_access_blocked + report_dict["is_video"] = is_video child_logger.info("Community note generated successfully") else: report_dict["success"] = False diff --git a/agents/openai_agent.py b/agents/openai_agent.py index 9482524..ceb1c30 100644 --- a/agents/openai_agent.py +++ b/agents/openai_agent.py @@ -4,7 +4,8 @@ import json from logger import StructuredLogger import time -from tools import summarise_report_factory +from tools import summarise_report_factory, preprocess_inputs +from tools.preprocess_inputs import get_openai_content import asyncio from openai.types.chat import ChatCompletionMessageToolCall from langfuse.decorators import observe @@ -389,6 +390,23 @@ async def generate_note( } start_time = time.time() # Start the timer cost_tracker = {"total_cost": 0, "cost_trace": []} # To store the cost details + preprocessed_response = await preprocess_inputs( + image_url=image_url, caption=caption, text=text + ) + if not preprocessed_response.get("success"): + child_logger.error("Error in preprocessing inputs") + return { + "success": False, + "error": "Error in preprocessing inputs", + } + else: + child_logger.info("Preprocessing inputs successful") + screenshots_results = preprocessed_response.get("screenshots", []) + screenshots_content = get_openai_content(screenshots_results) + results = preprocessed_response.get("result", {}) + is_access_blocked = results.get("isAccessBlocked", False) + is_video = results.get("isVideo", False) + intent = results.get("intent", "An error occurred, figure it out yourself") if text is not None: content = [ @@ -407,6 +425,16 @@ async def generate_note( {"type": "image_url", "image_url": {"url": image_url}}, ] + if screenshots_content: + content.extend(screenshots_content) + + content.append( + { + "type": "text", + "text": f"User's likely intent: {intent}", + } + ) + report_dict = await self.generate_report(content.copy()) duration = time.time() - start_time # Calculate duration @@ -417,6 +445,8 @@ async def generate_note( ) if summary_results.get("success"): report_dict["community_note"] = summary_results["community_note"] + report_dict["is_access_blocked"] = is_access_blocked + report_dict["is_video"] = is_video child_logger.info("Community note generated successfully") else: report_dict["success"] = False diff --git a/handlers/agent_generation.py b/handlers/agent_generation.py index 1fe7bcf..8e1f684 100644 --- a/handlers/agent_generation.py +++ b/handlers/agent_generation.py @@ -169,8 +169,8 @@ async def get_outputs( cn=chinese_note, links=outputs.get("sources", None), isControversial=outputs.get("isControversial", False), - isVideo=outputs.get("isVideo", False), - isAccessBlocked=outputs.get("isAccessBlocked", False), + isVideo=outputs.get("is_video", False), + isAccessBlocked=outputs.get("is_access_blocked", False), report=outputs.get("report", None), totalTimeTaken=outputs.get("total_time_taken", None), agentTrace=outputs.get("agent_trace", None), diff --git a/prompts/agent.py b/prompts/agent.py index d065141..a31acd9 100644 --- a/prompts/agent.py +++ b/prompts/agent.py @@ -6,11 +6,15 @@ Such content can be a text message or an image message. Image messages could, among others, be screenshots of their phone, pictures from their camera, or downloaded images. They could also be accompanied by captions. +In addition to what is submitted by the user, you will receive the following: +- screenshot of any webpages whose links are within the content, if the content submitted is a text +- the intent of the user, which you should craft your response to address + # Task Your task is to: -1. Infer the intent of whoever sent the message in - what exactly about the message they want checked, and how to go about it. Note the distinction between the sender and the author. For example, if the message contains claims but no source, they are probably interested in the factuality of the claims. If the message doesn't contain verifiable claims, they are probably asking whether it's from a legitimate, trustworthy source. If it's about an offer, they are probably enquiring about the legitimacy of the offer. If it's a message claiming it's from the government, they want to know if it is really from the government. -2. Use the supplied tools to help you check the information. Focus primarily on credibility/legitimacy of the source/author and factuality of information/claims, if relevant. If not, rely on contextual clues. When searching, give more weight to reliable, well-known sources. Use searches and visit sites judiciously, you only get 5 of each. -3. Submit a report to conclude your task. Start with your findings and end with a thoughtful conclusion. Be helpful and address the intent identified in the first step. + +1. Use the supplied tools to help you check the information. Focus primarily on credibility/legitimacy of the source/author and factuality of information/claims, if relevant. If not, rely on contextual clues. When searching, give more weight to reliable, well-known sources. Use searches and visit sites judiciously, you only get 5 of each. +2. Submit a report to conclude your task. Start with your findings and end with a thoughtful conclusion. Be helpful and address the intent identified in the first step. # Guidelines for Report: - Avoid references to the user, like "the user wants to know..." or the "the user sent in...", as these are obvious. diff --git a/prompts/preprocess_inputs.py b/prompts/preprocess_inputs.py new file mode 100644 index 0000000..02a7f0a --- /dev/null +++ b/prompts/preprocess_inputs.py @@ -0,0 +1,71 @@ +from langfuse import Langfuse + +review_report_system_prompt = """# Context + +You are an agent behind CheckMate, a product that allows users based in Singapore to send in dubious content they aren't sure whether to trust, and checks such content on their behalf. + +Such content can be a text message or an image message. Image messages could, among others, be screenshots of their phone, pictures from their camera, or downloaded images. They could also be accompanied by captions. + +# Task + +Given these inputs: +- content submitted by the user, which could be an image or a text +- screenshots of any webpages whose links within the content + +Your task is to: +1. Determine if the screenshots indicate that the content is a video, and/or access to the content is blocked. +2. Infer the intent of whoever sent the message in - what exactly about the message they want checked, and how to go about it. Note the distinction between the sender and the author. For example, if the message contains claims but no source, they are probably interested in the factuality of the claims. If the message doesn't contain verifiable claims, they are probably asking whether it's from a legitimate, trustworthy source. If it's about an offer, they are probably enquiring about the legitimacy of the offer. If it's a message claiming it's from the government, they want to know if it is really from the government.""" + +config = { + "model": "gpt-4o", + "temperature": 0.0, + "seed": 11, + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "summarise_report", + "schema": { + "type": "object", + "properties": { + "reasoning": { + "type": "string", + "description": "The reasoning behind the intent you inferred from the message.", + }, + "is_access_blocked": { + "type": "boolean", + "description": "True if the content or URL sent by the user to be checked is inaccessible/removed/blocked. An example is being led to a login page instead of post content.", + }, + "is_video": { + "type": "boolean", + "description": "True if the content or URL sent by the user to be checked points to a video (e.g., YouTube, TikTok, Instagram Reels, Facebook videos).", + }, + "intent": { + "type": "string", + "description": "What the user's intent is, e.g. to check whether this is a scam, to check if this is really from the government, to check the facts in this article, etc.", + }, + }, + "required": ["is_access_blocked", "is_video", "reasoning", "intent"], + "additionalProperties": False, + }, + }, + }, +} + + +def compile_messages_array(): + prompt_messages = [{"role": "system", "content": review_report_system_prompt}] + return prompt_messages + + +if __name__ == "__main__": + langfuse = Langfuse() + prompt_messages = compile_messages_array() + langfuse.create_prompt( + name="preprocess_inputs", + type="chat", + prompt=prompt_messages, + labels=["production", "development", "uat"], # directly promote to production + config=config, # optionally, add configs (e.g. model parameters or model tools) or tags + ) + langfuse.get_prompt("preprocess_inputs", label="production") + print("Prompt created successfully.") diff --git a/requirements.txt b/requirements.txt index 94befcb..379eba8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,4 +18,5 @@ google-genai==0.3.0 pytest==8.3.4 pytest-asyncio==0.25.1 responses==0.25.3 -google-cloud-firestore==2.20.0 \ No newline at end of file +google-cloud-firestore==2.20.0 +urlextract==1.9.0 \ No newline at end of file diff --git a/tests/tools/test_preprocess_inputs.py b/tests/tools/test_preprocess_inputs.py new file mode 100644 index 0000000..c425fee --- /dev/null +++ b/tests/tools/test_preprocess_inputs.py @@ -0,0 +1,113 @@ +import pytest +from unittest.mock import patch, AsyncMock +from tools import get_screenshots_from_text, preprocess_inputs + + +@pytest.mark.asyncio +async def test_get_screenshots_from_text_no_urls(): + text = "This is a text without any URLs" + result = await get_screenshots_from_text(text) + assert result == [] + + +@pytest.mark.asyncio +async def test_get_screenshots_from_text_single_url(): + with patch( + "tools.preprocess_inputs.get_website_screenshot", new_callable=AsyncMock + ) as mock_screenshot: + mock_screenshot.return_value = { + "success": True, + "result": "https://example.com/screenshot1.png", + } + + text = "Check out https://example.com" + result = await get_screenshots_from_text(text) + + assert len(result) == 1 + assert result[0]["url"] == "https://example.com" + assert result[0]["image_url"] == "https://example.com/screenshot1.png" + + mock_screenshot.assert_called_once_with("https://example.com") + + +@pytest.mark.asyncio +async def test_get_screenshots_from_text_multiple_urls(): + with patch( + "tools.preprocess_inputs.get_website_screenshot", new_callable=AsyncMock + ) as mock_screenshot: + mock_screenshot.side_effect = [ + {"success": True, "result": "https://example.com/screenshot1.png"}, + {"success": True, "result": "https://example.com/screenshot2.png"}, + ] + + text = "Check these: https://example1.com and https://example2.com" + result = await get_screenshots_from_text(text) + + assert len(result) == 2 + assert result[0]["url"] == "https://example1.com" + assert result[0]["image_url"] == "https://example.com/screenshot1.png" + assert result[1]["url"] == "https://example2.com" + assert result[1]["image_url"] == "https://example.com/screenshot2.png" + + assert mock_screenshot.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_screenshots_from_text_failed_screenshot(): + with patch( + "tools.preprocess_inputs.get_website_screenshot", new_callable=AsyncMock + ) as mock_screenshot: + mock_screenshot.return_value = { + "success": False, + "error": "Failed to get screenshot", + } + + text = "Check out https://example.com" + result = await get_screenshots_from_text(text) + + assert len(result) == 0 + mock_screenshot.assert_called_once_with("https://example.com") + + +@pytest.mark.asyncio +async def test_multiple_screenshots(): + text = """ + Check these links: + https://example.com + https://google.com + https://github.com + """ + result = await get_screenshots_from_text(text) + + assert len(result) > 0 + for screenshot in result: + assert "url" in screenshot + assert "image_url" in screenshot + assert isinstance(screenshot["url"], str) + assert isinstance(screenshot["image_url"], str) + assert screenshot["url"] in text + + +@pytest.mark.asyncio +async def test_preprocess_inputs(): + text = """ + Check these links: + https://example.com + https://google.com + https://github.com + """ + result = await preprocess_inputs(None, None, text) + + # We don't know exactly how many will succeed, but we should get some results + assert len(result) > 0 + + # Verify structure of results + assert "result" in result + assert "screenshots" in result + + result_json = result["result"] + + assert "is_access_blocked" in result_json + assert "is_video" in result_json + assert "reasoning" in result_json + assert "intent" in result_json diff --git a/tools/__init__.py b/tools/__init__.py index 29d5a1f..581ca86 100644 --- a/tools/__init__.py +++ b/tools/__init__.py @@ -13,6 +13,12 @@ plan_next_step, infer_intent, ) +from .preprocess_inputs import ( + get_screenshots_from_text, + preprocess_inputs, + get_gemini_content, + get_openai_content, +) __all__ = [ "get_screenshot_tool", @@ -31,4 +37,8 @@ "infer_intent", "translation_tool", "translate_text", + "get_screenshots_from_text", + "preprocess_inputs", + "get_gemini_content", + "get_openai_content", ] diff --git a/tools/preprocess_inputs.py b/tools/preprocess_inputs.py new file mode 100644 index 0000000..1294cf6 --- /dev/null +++ b/tools/preprocess_inputs.py @@ -0,0 +1,171 @@ +import os +from clients.openai import create_openai_client +from typing import Union +from langfuse.decorators import observe +import json +from logger import StructuredLogger +from tools import get_website_screenshot +from langfuse import Langfuse +from urlextract import URLExtract +import asyncio +from tools.website_screenshot import get_website_screenshot +from utils.gemini_utils import generate_screenshot_parts +from google.genai import types + +extractor = URLExtract() +langfuse = Langfuse() +client = create_openai_client("openai") +logger = StructuredLogger("preprocess_inputs") + + +async def get_screenshots_from_text(text: str) -> list: + """Extract URLs from text and get screenshots in parallel.""" + results = [] + urls = extractor.find_urls(text, only_unique=True, check_dns=True) + + if urls: + screenshot_tasks = [get_website_screenshot(url) for url in urls] + screenshot_responses = await asyncio.gather(*screenshot_tasks) + + for url, response in zip(urls, screenshot_responses): + if response.get("success") and "result" in response: + results.append({"url": url, "image_url": response["result"]}) + elif response.get("success") is False: + results.append({"url": url, "error": response.get("error")}) + return results + + +def get_openai_content(screenshot_results): + content = [] + for result in screenshot_results: + if "image_url" in result: + content.append( + {"type": "text", "text": f"Screenshot of {result['url']} below:"} + ) + content.append( + { + "type": "image_url", + "image_url": {"url": result["image_url"]}, + } + ) + else: + content.append( + { + "type": "text", + "text": f"Blocked from/failed at getting screenshot of {result['url']}: {result['error']}", + } + ) + return content + + +def get_gemini_content(screenshot_results): + parts = [] + for result in screenshot_results: + if "image_url" in result: + parts.extend(generate_screenshot_parts(result["image_url"], result["url"])) + else: + parts.append( + types.Part.from_text( + f"Blocked from/failed at getting screenshot of {result['url']}: {result['error']}", + ) + ) + return parts + + +@observe() +async def preprocess_inputs( + image_url: Union[str, None], caption: Union[str, None], text: Union[str, None] +): + try: + prompt = langfuse.get_prompt( + "preprocess_inputs", label=os.getenv("ENVIRONMENT") + ) + config = prompt.config + messages = prompt.compile() + content = [] + + if text: + content.append( + { + "type": "text", + "text": f"User sent in: {text}", + } + ) + screenshot_results = await get_screenshots_from_text(text) + screenshot_content = get_openai_content(screenshot_results) + content.extend(screenshot_content) + elif image_url: + caption_suffix = ( + "no caption" if caption is None else f"this caption: {caption}" + ) + content = [ + { + "type": "text", + "text": f"User sent in the following image with {caption_suffix}", + }, + {"type": "image_url", "image_url": {"url": image_url}}, + ] + messages.append( + { + "role": "user", + "content": content, + } + ) + + response = client.chat.completions.create( + model=config.get("model", "gpt-4o"), + messages=messages, + temperature=config.get("temperature", 0), + seed=config.get("seed", 11), + response_format=config["response_format"], + langfuse_prompt=prompt, + ) + result = json.loads(response.choices[0].message.content) + return {"success": True, "result": result, "screenshots": screenshot_results} + except: + logger.error("Error in preprocess_inputs") + return {"success": False} + + +preprocess_inputs_definition = dict( + name="submit_report_for_review", + description="Submits a report, which concludes the task.", + parameters={ + "type": "OBJECT", + "properties": dict( + [ + ( + "image_url", + { + "type": ["STRING", "NULL"], + "description": "The URL of the image to be checked.", + }, + ), + ( + "caption", + { + "type": ["STRING", "NULL"], + "description": "The caption that accompanies the image to be checked", + }, + ), + ( + "text", + { + "type": ["STRING", "NULL"], + "description": "The text of the message to be checked.", + }, + ), + ] + ), + "required": [ + "image_url", + "caption", + "text", + ], + }, +) + +review_report_tool = { + "function": preprocess_inputs, + "definition": preprocess_inputs_definition, +} diff --git a/tools/review_report.py b/tools/review_report.py index 0c2843f..c82d0f8 100644 --- a/tools/review_report.py +++ b/tools/review_report.py @@ -15,9 +15,7 @@ @observe() -async def submit_report_for_review( - report, sources, isControversial, isVideo, isAccessBlocked -): +async def submit_report_for_review(report, sources, isControversial): prompt = langfuse.get_prompt("review_report", label=os.getenv("ENVIRONMENT")) config = prompt.config @@ -70,28 +68,12 @@ async def submit_report_for_review( "description": "True if the content contains political or religious viewpoints that are grounded in opinions rather than provable facts, and are likely to be divisive or polarizing.", }, ), - ( - "isVideo", - { - "type": "BOOLEAN", - "description": "True if the content or URL sent by the user to be checked points to a video (e.g., YouTube, TikTok, Instagram Reels, Facebook videos).", - }, - ), - ( - "isAccessBlocked", - { - "type": "BOOLEAN", - "description": "True if the content or URL sent by the user to be checked is inaccessible/removed/blocked. An example is being led to a login page instead of post content.", - }, - ), ] ), "required": [ "report", "sources", "isControversial", - "isVideo", - "isAccessBlocked", ], }, ) diff --git a/tools/summarise_report.py b/tools/summarise_report.py index 472ee69..989a596 100644 --- a/tools/summarise_report.py +++ b/tools/summarise_report.py @@ -1,4 +1,3 @@ -from google.genai import types import os from clients.openai import create_openai_client from typing import Union diff --git a/utils/gemini_utils.py b/utils/gemini_utils.py index c6a8c88..2d57118 100644 --- a/utils/gemini_utils.py +++ b/utils/gemini_utils.py @@ -61,6 +61,31 @@ def generate_image_parts(image_url: str, caption: str = None): return parts +def generate_screenshot_parts(image_url: str, url: str = None): + """Generates a list of parts for an image with an optional caption. + + Args: + image_url: The URL of the image. + caption: An optional caption for the image. + + Returns: + A list of parts containing the image and caption. + """ + parts = [] + if image_url is None: + raise ValueError("Image URL is required when data_type is 'image'") + if image_url.startswith("gs://"): + # parts.append(types.Part.from_uri(image_url, mime_type="image/jpeg")) #TODO: Change in future + parts.append(get_image_part(image_url)) + else: + image = httpx.get(image_url) + file_content = image.content + parts.append(types.Part.from_bytes(data=file_content, mime_type="image/jpeg")) + + parts.append(types.Part.from_text(f"Screenshot of {url} above")) + return parts + + def generate_text_parts(text: str): """Generates a list of parts for a text input. From 6d2d885a12e17d35b3d224cb9ca344c775eb7971 Mon Sep 17 00:00:00 2001 From: Bing Wen Tan Date: Sun, 16 Feb 2025 21:46:58 +0800 Subject: [PATCH 2/2] corrected errors --- agents/gemini_agent.py | 4 ++-- agents/openai_agent.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/agents/gemini_agent.py b/agents/gemini_agent.py index 3a245f5..e9158be 100644 --- a/agents/gemini_agent.py +++ b/agents/gemini_agent.py @@ -396,8 +396,8 @@ async def generate_note( screenshots_results = preprocessed_response.get("screenshots", []) screenshots_content = get_gemini_content(screenshots_results) results = preprocessed_response.get("result", {}) - is_access_blocked = results.get("isAccessBlocked", False) - is_video = results.get("isVideo", False) + is_access_blocked = results.get("is_access_blocked", False) + is_video = results.get("is_video", False) intent = results.get("intent", "An error occurred, figure it out yourself") if text is not None: diff --git a/agents/openai_agent.py b/agents/openai_agent.py index ceb1c30..f8b17fb 100644 --- a/agents/openai_agent.py +++ b/agents/openai_agent.py @@ -404,8 +404,8 @@ async def generate_note( screenshots_results = preprocessed_response.get("screenshots", []) screenshots_content = get_openai_content(screenshots_results) results = preprocessed_response.get("result", {}) - is_access_blocked = results.get("isAccessBlocked", False) - is_video = results.get("isVideo", False) + is_access_blocked = results.get("is_access_blocked", False) + is_video = results.get("is_video", False) intent = results.get("intent", "An error occurred, figure it out yourself") if text is not None: