Skip to content

Commit 05a0c27

Browse files
authored
Improve Model Loading Logic for TransformerDecoderModel (#290)
1 parent cd6c14e commit 05a0c27

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

convokit/forecaster/TransformerDecoderModel.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import shutil
2020

2121

22-
def _get_templet_map(model_name_or_path):
22+
def _get_template_map(model_name_or_path):
2323
"""
2424
Map a model name or path to its corresponding prompt template family.
2525
@@ -28,12 +28,12 @@ def _get_templet_map(model_name_or_path):
2828
:raises ValueError: If the model is not recognized.
2929
"""
3030
TEMPLATE_PATTERNS = [
31-
("google/gemma-2-", "gemma2"),
32-
("google/gemma-3-", "gemma3"),
33-
("mistralai/mistral", "mistral"),
34-
("HuggingFaceH4/zephyr", "zephyr"),
35-
("microsoft/phi-4", "phi-4"),
36-
("meta-llama/Llama-3", "llama3"),
31+
("gemma-2", "gemma2"),
32+
("gemma-3", "gemma3"),
33+
("mistral", "mistral"),
34+
("zephyr", "zephyr"),
35+
("phi-4", "phi-4"),
36+
("llama-3", "llama3"),
3737
]
3838

3939
for pattern, template in TEMPLATE_PATTERNS:
@@ -84,7 +84,7 @@ def __init__(
8484

8585
self.tokenizer = get_chat_template(
8686
tokenizer,
87-
chat_template=_get_templet_map(model_name_or_path), # TO-DO: Define this
87+
chat_template=_get_template_map(self.model.config.name_or_path),
8888
mapping={"role": "from", "content": "value", "user": "human", "assistant": "model"},
8989
)
9090
# Custom prompt

0 commit comments

Comments
 (0)