diff --git a/bugbug/tools/code_review.py b/bugbug/tools/code_review.py
index 4629e996f8..55242264f8 100644
--- a/bugbug/tools/code_review.py
+++ b/bugbug/tools/code_review.py
@@ -15,7 +15,9 @@
 from typing import Iterable, Literal, Optional
 
 import tenacity
-from langchain.chains import ConversationChain, LLMChain
+from langchain.chains import ConversationChain
 from langchain.memory import ConversationBufferMemory
 from langchain.prompts import PromptTemplate
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.runnables import RunnableSequence
 from langchain_openai import OpenAIEmbeddings
@@ -1043,44 +1045,54 @@ def __init__(
             else ""
         )
 
-        self.summarization_chain = LLMChain(
-            prompt=PromptTemplate.from_template(
-                PROMPT_TEMPLATE_SUMMARIZATION,
-                partial_variables={
-                    "experience_scope": (
-                        f"the {self.target_software} source code"
-                        if self.target_software
-                        else "a software project"
-                    )
-                },
-            ),
-            llm=self.llm,
-            verbose=verbose,
-        )
-        self.filtering_chain = LLMChain(
-            prompt=PromptTemplate.from_template(
-                PROMPT_TEMPLATE_FILTERING_ANALYSIS,
-                partial_variables={
-                    "target_code_consistency": self.target_software or "rest of the"
-                },
-            ),
-            llm=self.llm,
-            verbose=verbose,
-        )
-        self.deduplicating_chain = LLMChain(
-            prompt=PromptTemplate.from_template(PROMPT_TEMPLATE_DEDUPLICATE),
-            llm=self.llm,
-            verbose=verbose,
-        )
-        self.further_context_chain = LLMChain(
-            prompt=PromptTemplate.from_template(PROMPT_TEMPLATE_FURTHER_CONTEXT_LINES),
-            llm=self.llm,
-            verbose=verbose,
-        )
-        self.further_info_chain = LLMChain(
-            prompt=PromptTemplate.from_template(PROMPT_TEMPLATE_FURTHER_INFO),
-            llm=self.llm,
-            verbose=verbose,
-        )
+        summarization_chain_prompt = PromptTemplate.from_template(
+            PROMPT_TEMPLATE_SUMMARIZATION,
+            partial_variables={
+                "experience_scope": (
+                    f"the {self.target_software} source code"
+                    if self.target_software
+                    else "a software project"
+                )
+            },
+        )
+
+        self.summarization_chain = RunnableSequence(
+            summarization_chain_prompt, self.llm, StrOutputParser()
+        )
+
+        filtering_chain_prompt = PromptTemplate.from_template(
+            PROMPT_TEMPLATE_FILTERING_ANALYSIS,
+            partial_variables={
+                "target_code_consistency": self.target_software or "rest of the"
+            },
+        )
+
+        self.filtering_chain = RunnableSequence(
+            filtering_chain_prompt, self.llm, StrOutputParser()
+        )
+
+        deduplicating_chain_prompt = PromptTemplate.from_template(
+            PROMPT_TEMPLATE_DEDUPLICATE
+        )
+
+        self.deduplicating_chain = RunnableSequence(
+            deduplicating_chain_prompt, self.llm, StrOutputParser()
+        )
+
+        further_context_prompt = PromptTemplate.from_template(
+            PROMPT_TEMPLATE_FURTHER_CONTEXT_LINES
+        )
+
+        self.further_context_chain = RunnableSequence(
+            further_context_prompt, self.llm, StrOutputParser()
+        )
+
+        further_info_chain_prompt = PromptTemplate.from_template(
+            PROMPT_TEMPLATE_FURTHER_INFO
+        )
+
+        self.further_info_chain = RunnableSequence(
+            further_info_chain_prompt, self.llm, StrOutputParser()
+        )
 
         self.function_search = function_search
@@ -1107,17 +1119,17 @@ def run(self, patch: Patch) -> list[InlineComment] | None:
             return None
 
         output_summarization = self.summarization_chain.invoke(
-            {"patch": formatted_patch},
-            return_only_outputs=True,
-        )["text"]
+            {"patch": formatted_patch}
+        )
 
         if self.verbose:
             GenerativeModelTool._print_answer(output_summarization)
 
         if self.function_search is not None:
-            line_code_list = self.further_context_chain.run(
-                patch=formatted_patch, summarization=output_summarization
-            ).split("\n")
+            output_further_context = self.further_context_chain.invoke(
+                {"patch": formatted_patch, "summarization": output_summarization}
+            )
+            line_code_list = output_further_context.split("\n")
 
             if self.verbose:
                 GenerativeModelTool._print_answer(line_code_list)
@@ -1129,11 +1141,12 @@ def run(self, patch: Patch) -> list[InlineComment] | None:
                 formatted_patch,
             )
 
+            output_further_info = self.further_info_chain.invoke(
+                {"patch": formatted_patch, "summarization": output_summarization}
+            )
             function_list = [
                 function_name.strip()
-                for function_name in self.further_info_chain.run(
-                    patch=formatted_patch, summarization=output_summarization
-                ).split("\n")
+                for function_name in output_further_info.split("\n")
             ]
 
             if self.verbose:
@@ -1221,10 +1234,7 @@ def run(self, patch: Patch) -> list[InlineComment] | None:
             memory.clear()
 
         if len(self.comment_gen_llms) > 1:
-            output = self.deduplicating_chain.invoke(
-                {"review": output},
-                return_only_outputs=True,
-            )["text"]
+            output = self.deduplicating_chain.invoke({"review": output})
 
             if self.verbose:
                 GenerativeModelTool._print_answer(output)
@@ -1245,9 +1255,8 @@ def run(self, patch: Patch) -> list[InlineComment] | None:
                 "review": output,
                 "patch": patch.raw_diff,
                 "rejected_examples": rejected_examples,
-            },
-            return_only_outputs=True,
-        )["text"]
+            }
+        )
 
         if self.verbose:
             GenerativeModelTool._print_answer(raw_output)
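
Migration note: LLMChain is deprecated in favor of runnable composition, where a prompt, a model, and an output parser are combined into a RunnableSequence. Below is a minimal sketch of the pattern used in this diff, with stand-in prompt text, model name, and input (none of these are the bugbug ones); it assumes langchain-core and langchain-openai are installed and OPENAI_API_KEY is set.

    from langchain.prompts import PromptTemplate
    from langchain_core.output_parsers import StrOutputParser
    from langchain_openai import ChatOpenAI

    # Stand-in prompt and model, for illustration only.
    prompt = PromptTemplate.from_template("Summarize this patch:\n\n{patch}")
    llm = ChatOpenAI(model="gpt-4o-mini")

    # Deprecated style:
    #   chain = LLMChain(prompt=prompt, llm=llm, verbose=True)
    #   text = chain.invoke({"patch": raw_diff}, return_only_outputs=True)["text"]

    # Runnable style: the `|` operator builds a RunnableSequence, and
    # StrOutputParser unwraps the chat model's AIMessage into the plain string
    # that LLMChain used to return under its "text" key.
    chain = prompt | llm | StrOutputParser()
    text = chain.invoke({"patch": "diff --git a/x.py b/x.py ..."})
    print(text)

The parser step is what lets the calling code treat every chain's invoke() result as a string (e.g., splitting on "\n"), preserving the behavior the old ["text"] lookup provided.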