diff --git a/mistral_example.py b/mistral_example.py
index ae46e75..004777f 100644
--- a/mistral_example.py
+++ b/mistral_example.py
@@ -9,6 +9,7 @@
 from transformers.models.mistral.modeling_mistral import MistralAttention
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import transformers
+import torch
 
 original_mistral_forward = MistralAttention.forward
 self_extend_forward = partial(MistralSE.self_extend_forward, group_size_1=4, group_size_2=1024)
@@ -28,7 +29,10 @@
     example = json.loads(line)
     prompt_postfix = "What is the pass key? The pass key is "
     prompt = example["input"] + prompt_postfix
-    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+    if torch.backends.mps.is_available():
+        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("mps")
+    else:
+        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
     print( "-----------------------------------" )
     print( f"#Tokens of Prompt:", input_ids.shape[1], end=" " )
     print( "Passkey target:", example["target"] )
@@ -47,6 +51,8 @@
     answer = answer.replace("\n", "\\n")
     print( answer )
     print( "-----------------------------------\n" )
+    if torch.backends.mps.is_available():
+        torch.mps.empty_cache()
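
The diff above does two things: it moves the tokenized input onto Apple's Metal (MPS) backend when one is available, and it releases the MPS allocator's cached memory after each passkey example so long prompts do not accumulate memory pressure. Below is a minimal standalone sketch of the same device-selection pattern, generalized to also cover CUDA and CPU; the model checkpoint name and the prompt string are illustrative assumptions, not taken from the diff.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Pick the best available backend once, rather than branching at each call site.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Hypothetical checkpoint for illustration; any causal LM id works the same way.
model_id = "mistralai/Mistral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)

# Inputs must live on the same device as the model's weights.
input_ids = tokenizer("What is the pass key? The pass key is ", return_tensors="pt").input_ids.to(device)
output = model.generate(input_ids, max_new_tokens=16)
print(tokenizer.decode(output[0], skip_special_tokens=True))

# On MPS, cached allocations are not automatically returned to the system;
# freeing them between long-prompt iterations keeps memory usage bounded.
if torch.backends.mps.is_available():
    torch.mps.empty_cache()

Selecting the device once and calling .to(device) on every tensor avoids the duplicated tokenizer call in the patched script, at the cost of a slightly less literal translation of the diff.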