Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions finetune/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@

import safetensors.torch
import torch
from mistral_common.tokens.tokenizers.sentencepiece import (
InstructTokenizerBase,
SentencePieceTokenizer,
)
from mistral_common.tokens.tokenizers.sentencepiece import SentencePieceTokenizer
from mistral_common.tokens.tokenizers.instruct import InstructTokenizerBase
from torch.distributed import barrier
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel

Expand Down
2 changes: 1 addition & 1 deletion finetune/data/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, Iterator, List, Optional

import numpy as np
from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerBase
from mistral_common.tokens.tokenizers.instruct import InstructTokenizerBase

from .args import DataArgs
from .dataset import build_dataset
Expand Down
2 changes: 1 addition & 1 deletion finetune/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
FinetuningAssistantMessage,
SystemMessage,
)
from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerBase
from mistral_common.tokens.tokenizers.instruct import InstructTokenizerBase

from finetune.distributed import get_rank

Expand Down
6 changes: 3 additions & 3 deletions finetune/data/tokenize.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)
from mistral_common.tokens.instruct.request import InstructRequest
from mistral_common.tokens.tokenizers.base import Tokenizer
from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerBase
from mistral_common.tokens.tokenizers.instruct import InstructTokenizerBase

from .exceptions import (
ConversationFormatError,
Expand Down Expand Up @@ -177,7 +177,7 @@ def build_instruct_sample(data: Dict[str, Any]) -> TrainingInstructSample:

# validate created messages
validator = MistralRequestValidatorV3(ValidationMode.finetuning)
validator.validate_messages(messages)
validator.validate_messages(messages, False)
validator._validate_tools(available_tools or [])

# whether to train only on last assistant message
Expand Down Expand Up @@ -328,7 +328,7 @@ def tokenize_instruct(
message = maybe_remove_call_id(message, is_last_message=is_last_message)

curr_tokens = instruct_tokenizer.encode_assistant_message(
message, is_before_last_user_message=False
message, is_before_last_user_message=False, continue_message=False
)

is_weighted = message.weight is None or message.weight == 1
Expand Down