Merged
knowledge_storm/collaborative_storm/engine.py (97 changes: 56 additions & 41 deletions)
@@ -14,9 +14,10 @@
 from .modules.expert_generation import GenerateExpertModule
 from .modules.warmstart_hierarchical_chat import WarmStartModule
 from ..dataclass import ConversationTurn, KnowledgeBase
+from ..encoder import Encoder
 from ..interface import LMConfigs, Agent
 from ..logging_wrapper import LoggingWrapper
-from ..lm import OpenAIModel, AzureOpenAIModel, TogetherClient
+from ..lm import LitellmModel
 from ..rm import BingSearch

@@ -45,27 +46,26 @@ def init(
         if lm_type and lm_type == "openai":
             openai_kwargs = {
                 "api_key": os.getenv("OPENAI_API_KEY"),
-                "api_provider": "openai",
                 "temperature": temperature,
                 "top_p": top_p,
                 "api_base": None,
             }
-            self.question_answering_lm = OpenAIModel(
+            self.question_answering_lm = LitellmModel(
                 model="gpt-4o-2024-05-13", max_tokens=1000, **openai_kwargs
             )
-            self.discourse_manage_lm = OpenAIModel(
+            self.discourse_manage_lm = LitellmModel(
                 model="gpt-4o-2024-05-13", max_tokens=500, **openai_kwargs
             )
-            self.utterance_polishing_lm = OpenAIModel(
+            self.utterance_polishing_lm = LitellmModel(
                 model="gpt-4o-2024-05-13", max_tokens=2000, **openai_kwargs
             )
-            self.warmstart_outline_gen_lm = OpenAIModel(
+            self.warmstart_outline_gen_lm = LitellmModel(
                 model="gpt-4-1106-preview", max_tokens=500, **openai_kwargs
             )
-            self.question_asking_lm = OpenAIModel(
+            self.question_asking_lm = LitellmModel(
                 model="gpt-4o-2024-05-13", max_tokens=300, **openai_kwargs
             )
-            self.knowledge_base_lm = OpenAIModel(
+            self.knowledge_base_lm = LitellmModel(
                 model="gpt-4o-2024-05-13", max_tokens=1000, **openai_kwargs
             )
         elif lm_type and lm_type == "azure":
@@ -76,62 +76,62 @@ def init(
                 "api_base": os.getenv("AZURE_API_BASE"),
                 "api_version": os.getenv("AZURE_API_VERSION"),
             }
-            self.question_answering_lm = AzureOpenAIModel(
-                model="gpt-4o", max_tokens=1000, **azure_kwargs, model_type="chat"
+            self.question_answering_lm = LitellmModel(
+                model="azure/gpt-4o", max_tokens=1000, **azure_kwargs, model_type="chat"
             )
-            self.discourse_manage_lm = AzureOpenAIModel(
-                model="gpt-4o", max_tokens=500, **azure_kwargs, model_type="chat"
+            self.discourse_manage_lm = LitellmModel(
+                model="azure/gpt-4o", max_tokens=500, **azure_kwargs, model_type="chat"
             )
-            self.utterance_polishing_lm = AzureOpenAIModel(
-                model="gpt-4o", max_tokens=2000, **azure_kwargs, model_type="chat"
+            self.utterance_polishing_lm = LitellmModel(
+                model="azure/gpt-4o", max_tokens=2000, **azure_kwargs, model_type="chat"
             )
-            self.warmstart_outline_gen_lm = AzureOpenAIModel(
-                model="gpt-4o", max_tokens=300, **azure_kwargs, model_type="chat"
+            self.warmstart_outline_gen_lm = LitellmModel(
+                model="azure/gpt-4o", max_tokens=300, **azure_kwargs, model_type="chat"
             )
-            self.question_asking_lm = AzureOpenAIModel(
-                model="gpt-4o", max_tokens=300, **azure_kwargs, model_type="chat"
+            self.question_asking_lm = LitellmModel(
+                model="azure/gpt-4o", max_tokens=300, **azure_kwargs, model_type="chat"
             )
-            self.knowledge_base_lm = AzureOpenAIModel(
-                model="gpt-4o", max_tokens=1000, **azure_kwargs, model_type="chat"
+            self.knowledge_base_lm = LitellmModel(
+                model="azure/gpt-4o", max_tokens=1000, **azure_kwargs, model_type="chat"
             )
         elif lm_type and lm_type == "together":
             together_kwargs = {
                 "api_key": os.getenv("TOGETHER_API_KEY"),
                 "temperature": temperature,
                 "top_p": top_p,
             }
-            self.question_answering_lm = TogetherClient(
-                model="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
+            self.question_answering_lm = LitellmModel(
+                model="together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
                 max_tokens=1000,
                 model_type="chat",
                 **together_kwargs,
             )
-            self.discourse_manage_lm = TogetherClient(
-                model="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
+            self.discourse_manage_lm = LitellmModel(
+                model="together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
                 max_tokens=500,
                 model_type="chat",
                 **together_kwargs,
             )
-            self.utterance_polishing_lm = TogetherClient(
-                model="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
+            self.utterance_polishing_lm = LitellmModel(
+                model="together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
                 max_tokens=2000,
                 model_type="chat",
                 **together_kwargs,
             )
-            self.warmstart_outline_gen_lm = TogetherClient(
-                model="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
+            self.warmstart_outline_gen_lm = LitellmModel(
+                model="together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
                 max_tokens=500,
                 model_type="chat",
                 **together_kwargs,
             )
-            self.question_asking_lm = TogetherClient(
-                model="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
+            self.question_asking_lm = LitellmModel(
+                model="together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
                 max_tokens=300,
                 model_type="chat",
                 **together_kwargs,
             )
-            self.knowledge_base_lm = TogetherClient(
-                model="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
+            self.knowledge_base_lm = LitellmModel(
+                model="together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
                 max_tokens=1000,
                 model_type="chat",
                 **together_kwargs,
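
All six LM slots now go through the single LitellmModel wrapper, with the provider encoded in the model string: no prefix for OpenAI, "azure/" for Azure deployments, "together_ai/" for Together AI. A minimal configuration sketch, assuming init() accepts the temperature and top_p keywords implied by the kwargs dicts above:

    import os

    from knowledge_storm.collaborative_storm.engine import CollaborativeStormLMConfigs

    # lm_type selects one of the branches above: "openai", "azure", or "together".
    # Reading it from OPENAI_API_TYPE mirrors what from_dict() does further down.
    lm_config = CollaborativeStormLMConfigs()
    lm_config.init(
        lm_type=os.getenv("OPENAI_API_TYPE", "openai"),
        temperature=1.0,
        top_p=0.9,
    )

    # Every slot holds a LitellmModel; only the model string changes per provider:
    #   "gpt-4o-2024-05-13"                              -> OpenAI
    #   "azure/gpt-4o"                                   -> Azure OpenAI deployment
    #   "together_ai/meta-llama/Meta-Llama-3.1-70B-..."  -> Together AI
    print(type(lm_config.question_answering_lm).__name__)  # LitellmModel
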
@@ -323,6 +323,7 @@ def __init__(
         lm_config: CollaborativeStormLMConfigs,
         runner_argument: RunnerArgument,
         rm: dspy.Retrieve,
+        encoder: Encoder,
         callback_handler: BaseCallbackHandler,
     ):
         # parameter management
@@ -331,6 +332,7 @@ def __init__(
         self.logging_wrapper = logging_wrapper
         self.callback_handler = callback_handler
         self.rm = rm
+        self.encoder = encoder
         # role management
         self.experts: List[CoStormExpert] = []
         self.simulated_user: SimulatedUser = SimulatedUser(
@@ -360,6 +362,7 @@ def __init__(
             lm_config=self.lm_config,
             runner_argument=self.runner_argument,
             logging_wrapper=self.logging_wrapper,
+            encoder=self.encoder,
             callback_handler=self.callback_handler,
         )
         self.general_knowledge_provider = CoStormExpert(
elif self.runner_argument.rag_only_baseline_mode:
assert self.conversation_history[-1].role == "Guest"
next_turn_policy.agent = self.pure_rag_agent
elif self.next_turn_moderator_override:
next_turn_policy.agent = self.moderator
if not dry_run:
self.next_turn_moderator_override = False
elif (
not self.runner_argument.disable_moderator
and self._should_generate_question(conversation_history)
):
next_turn_policy.agent = self.moderator
next_turn_policy.should_reorganize_knowledge_base = True
elif self.next_turn_moderator_override:
next_turn_policy.agent = self.moderator
if not dry_run:
self.next_turn_moderator_override = False
# experts RAG gen
else:
next_turn_policy.agent = self.general_knowledge_provider
Expand Down Expand Up @@ -516,18 +519,21 @@ def __init__(
@@ -516,18 +519,21 @@ def __init__(
             self.rm = BingSearch(k=runner_argument.retrieve_top_k)
         else:
             self.rm = rm
+        self.encoder = Encoder()
         self.conversation_history = []
         self.warmstart_conv_archive = []
         self.knowledge_base = KnowledgeBase(
             topic=self.runner_argument.topic,
             knowledge_base_lm=self.lm_config.knowledge_base_lm,
             node_expansion_trigger_count=self.runner_argument.node_expansion_trigger_count,
+            encoder=self.encoder,
         )
         self.discourse_manager = DiscourseManager(
             lm_config=self.lm_config,
             runner_argument=self.runner_argument,
             logging_wrapper=self.logging_wrapper,
             rm=self.rm,
+            encoder=self.encoder,
             callback_handler=callback_handler,
         )
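
A construction sketch for the updated runner, assuming the import paths match the file locations in this PR and that RunnerArgument strictly needs only a topic. The point to note is that one Encoder instance is created in __init__ and shared by the knowledge base and the discourse manager, replacing the old module-level get_text_embeddings() helper:

    from knowledge_storm.collaborative_storm.engine import (
        CollaborativeStormLMConfigs,
        CoStormRunner,
        RunnerArgument,
    )
    from knowledge_storm.logging_wrapper import LoggingWrapper

    lm_config = CollaborativeStormLMConfigs()
    lm_config.init(lm_type="openai")  # needs OPENAI_API_KEY in the environment

    runner = CoStormRunner(
        lm_config=lm_config,
        runner_argument=RunnerArgument(topic="quantum error correction"),
        logging_wrapper=LoggingWrapper(lm_config),
        # rm omitted on the assumption that the constructor falls back to
        # BingSearch(k=retrieve_top_k), as the branch above suggests
    )

    # Assumption: KnowledgeBase keeps a reference to the encoder it is given.
    assert runner.knowledge_base.encoder is runner.encoder
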

@@ -546,15 +552,17 @@ def to_dict(self):
         }

     @classmethod
-    def from_dict(cls, data):
+    def from_dict(cls, data, callback_handler: BaseCallbackHandler = None):
         # FIXME: does not use the lm_config data but naively use default setting
         lm_config = CollaborativeStormLMConfigs()
         lm_config.init(lm_type=os.getenv("OPENAI_API_TYPE"))
         costorm_runner = cls(
             lm_config=lm_config,
             runner_argument=RunnerArgument.from_dict(data["runner_argument"]),
             logging_wrapper=LoggingWrapper(lm_config),
+            callback_handler=callback_handler,
         )
+        costorm_runner.encoder = Encoder()
         costorm_runner.conversation_history = [
             ConversationTurn.from_dict(turn) for turn in data["conversation_history"]
         ]
@@ -567,6 +575,7 @@
             data=data["knowledge_base"],
             knowledge_base_lm=costorm_runner.lm_config.knowledge_base_lm,
             node_expansion_trigger_count=costorm_runner.runner_argument.node_expansion_trigger_count,
+            encoder=costorm_runner.encoder,
         )
         return costorm_runner
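
Continuing the construction sketch above, a save/restore round trip with the new from_dict signature, assuming to_dict() produces JSON-serializable data. The callback handler is not part of the serialized state, so it has to be passed back in explicitly:

    import json

    snapshot = json.dumps(runner.to_dict())  # dump of the live runner
    restored = CoStormRunner.from_dict(
        json.loads(snapshot),
        callback_handler=None,  # re-supply a BaseCallbackHandler here if needed
    )
    # Caveat from the FIXME above: the restored lm_config is rebuilt from the
    # OPENAI_API_TYPE environment variable, not from the saved data, and the
    # encoder is a fresh instance rather than a deserialized one.
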

@@ -607,11 +616,15 @@ def warm_start(self):
                     warmstart_revised_conv if warmstart_revised_conv else warmstart_conv
                 )
                 self.warmstart_conv_archive = warmstart_conv
-                self.knowledge_base.reorganize()
+                self.knowledge_base.reogranize()
         else:
             if self.knowledge_base is None:
                 self.knowledge_base = KnowledgeBase(
-                    topic=self.runner_argument.topic
+                    topic=self.runner_argument.topic,
+                    knowledge_base_lm=self.lm_config.knowledge_base_lm,
+                    node_expansion_trigger_count=self.runner_argument.node_expansion_trigger_count,
+                    encoder=self.encoder,
                 )
             if self.conversation_history is None:
                 self.conversation_history = []
@@ -633,7 +646,9 @@ def generate_report(self) -> str:
         Returns:
             str: A string representing the report, with "#" "##" indicating hierarchical sections and [1][2] indicating references.
         """
-        with self.logging_wrapper.log_pipeline_stage("report generation stage"):
+        with self.logging_wrapper.log_pipeline_stage(
+            f"report generation after conv turn: {len(self.conversation_history)}"
+        ):
             with self.logging_wrapper.log_event(
                 "report generation stage: generate report"
             ):
@@ -741,5 +756,5 @@ def step(
         ):
             if self.callback_handler is not None:
                 self.callback_handler.on_mindmap_reorg_start()
-            self.knowledge_base.reorganize()
+            self.knowledge_base.reogranize()
         return conv_turn
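
Putting the pieces together, the usual driver loop, continuing the runner sketch above; the step(user_utterance=...) form follows the repository's documented usage:

    runner.warm_start()
    runner.step(user_utterance="How do surface codes differ from repetition codes?")
    conv_turn = runner.step()           # the next agent (expert or moderator) speaks
    runner.knowledge_base.reogranize()  # note: the method really is spelled this way
    article = runner.generate_report()  # "#"/"##" sections with [1][2] citations
    print(article[:500])
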
@@ -21,7 +21,7 @@ def _get_cited_information_string(
         self,
         all_citation_index: Set[int],
         knowledge_base: KnowledgeBase,
-        max_words: int = 1500,
+        max_words: int = 4000,
     ):
         information = []
         cur_word_count = 0
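
For context, the enlarged default feeds a word-budget loop: the visible information list and cur_word_count suggest cited snippets are appended until max_words is exhausted. A minimal sketch of that pattern, with a hypothetical helper name:

    def build_cited_information_string(snippets, max_words=4000):
        """Hypothetical stand-in for the truncation loop; not the repo's code."""
        information, cur_word_count = [], 0
        for index, snippet in snippets:  # e.g. (citation index, snippet text) pairs
            words = snippet.split()
            remaining = max_words - cur_word_count
            if remaining <= 0:
                break
            kept = " ".join(words[:remaining])  # clip the last snippet to the budget
            cur_word_count += min(len(words), remaining)
            information.append(f"[{index}]: {kept}")
        return "\n".join(information)

    print(build_cited_information_string([(1, "alpha beta gamma"), (2, "delta")], max_words=3))
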
knowledge_storm/collaborative_storm/modules/co_storm_agents.py (24 changes: 9 additions & 15 deletions)
@@ -13,7 +13,7 @@
 from .grounded_question_generation import GroundedQuestionGenerationModule
 from .simulate_user import GenSimulatedUserUtterance
 from ...dataclass import ConversationTurn, KnowledgeBase
-from ...encoder import get_text_embeddings
+from ...encoder import Encoder
 from ...interface import Agent, Information, LMConfigs
 from ...logging_wrapper import LoggingWrapper

@@ -174,6 +174,7 @@ def __init__(
         lm_config: LMConfigs,
         runner_argument: "RunnerArgument",
         logging_wrapper: LoggingWrapper,
+        encoder: Encoder,
         callback_handler: BaseCallbackHandler = None,
     ):
         super().__init__(topic, role_name, role_description)
@@ -184,6 +185,7 @@ def __init__(
             engine=self.lm_config.question_asking_lm
         )
         self.callback_handler = callback_handler
+        self.encoder = encoder

     def _get_conv_turn_unused_information(
         self, conv_turn: ConversationTurn, knowledge_base: KnowledgeBase
@@ -211,19 +213,12 @@ def _get_conv_turn_unused_information(
         # extract snippets to get embeddings
         unused_information_snippets = [info.snippets[0] for info in unused_information]
         # get embeddings
-        cache = knowledge_base.embedding_cache
-        unused_snippets_embeddings, _ = get_text_embeddings(
-            unused_information_snippets, embedding_cache=cache, max_workers=100
-        )
-        claim_embedding, _ = get_text_embeddings(
-            conv_turn.claim_to_make, embedding_cache=cache
-        )
-        query_embedding, _ = get_text_embeddings(
-            conv_turn.queries, embedding_cache=cache
-        )
-        cited_snippets_embedding, _ = get_text_embeddings(
-            cited_snippets, embedding_cache=cache
+        unused_snippets_embeddings = self.encoder.encode(
+            unused_information_snippets, max_workers=20
         )
+        claim_embedding = self.encoder.encode(conv_turn.claim_to_make)
+        query_embedding = self.encoder.encode(conv_turn.queries)
+        cited_snippets_embedding = self.encoder.encode(cited_snippets)
         # calculate similarity
         query_similarities = cosine_similarity(
             unused_snippets_embeddings, query_embedding
)
batch_snippets.append(conv_turn.claim_to_make)
batch_snippets.extend(conv_turn.queries)
cache = knowledge_base.embedding_cache
get_text_embeddings(batch_snippets, embedding_cache=cache, max_workers=300)
self.encoder.encode(batch_snippets, max_workers=20)

# get sorted unused snippets for each turn
sorted_snippets = []
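
With the refactor, callers hold an Encoder instance instead of passing an embedding_cache to get_text_embeddings(); batching (max_workers) and caching now live inside the encoder. A self-contained sketch of the similarity computation above, with a stand-in encoder (the real Encoder.encode is assumed to return one vector for a single string and a 2-D array for a list):

    import numpy as np
    from sklearn.metrics.pairwise import cosine_similarity

    class ToyEncoder:
        """Stand-in for knowledge_storm.encoder.Encoder; illustrative only."""

        def encode(self, texts, max_workers=1):
            rng = np.random.default_rng(0)
            if isinstance(texts, str):
                return rng.random(8)            # one vector for a single string
            return rng.random((len(texts), 8))  # one row per text for a list

    encoder = ToyEncoder()
    snippets = ["surface codes", "repetition codes", "decoder latency"]
    snippet_vecs = encoder.encode(snippets, max_workers=20)
    claim_vec = encoder.encode("How do surface codes scale?")

    # Same call shape as the diff: one similarity per (snippet, claim) pair.
    sims = cosine_similarity(snippet_vecs, claim_vec.reshape(1, -1))
    print(sims.ravel())
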
@@ -9,7 +9,7 @@

 from .collaborative_storm_utils import trim_output_after_hint
 from ...dataclass import KnowledgeNode, KnowledgeBase
-from ...encoder import get_text_embeddings
+from ...encoder import Encoder
 from ...interface import Information


@@ -51,8 +51,9 @@ class InsertInformationCandidateChoice(dspy.Signature):


 class InsertInformationModule(dspy.Module):
-    def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
+    def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], encoder: Encoder):
         self.engine = engine
+        self.encoder = encoder
         self.insert_info = dspy.ChainOfThought(InsertInformation)
         self.candidate_choosing = dspy.Predict(InsertInformationCandidateChoice)
@@ -153,7 +154,7 @@ def _get_sorted_embed_sim_section(
         query: str,
     ):
         if encoded_outline is not None and encoded_outline.size > 0:
-            encoded_query, token_usage = get_text_embeddings(f"{question}, {query}")
+            encoded_query = self.encoder.encode(f"{question}, {query}")
             sim = cosine_similarity([encoded_query], encoded_outline)[0]
             sorted_indices = np.argsort(sim)
             sorted_outlines = np.array(outlines)[sorted_indices[::-1]]
@@ -226,7 +227,6 @@ def forward(
         insert_root: Optional[KnowledgeNode] = None,
         skip_candidate_from_embedding: bool = False,
     ):
-
         if not isinstance(information, List):
             information = [information]
         intent_to_placement_dict: Dict = self._info_list_to_intent_mapping(