File tree Expand file tree Collapse file tree 2 files changed +7
-9
lines changed
Expand file tree Collapse file tree 2 files changed +7
-9
lines changed Original file line number Diff line number Diff line change 11import random
22from collections import deque
3- from typing import Literal , cast
4-
5- from torchtyping import TensorType
3+ from typing import Literal
64
75from ..config import ExperimentConfig
86from ..logger import logger
@@ -121,10 +119,9 @@ def train(
121119 selected_examples = []
122120 for quantile in selected_examples_quantiles :
123121 for example in quantile :
124- example .normalized_activations = cast (
125- TensorType ["seq" ],
126- (example .activations * 10 / max_activation ).floor (),
127- )
122+ example .normalized_activations = (
123+ example .activations * 10 / max_activation
124+ ).floor ()
128125 selected_examples .extend (quantile )
129126 return selected_examples
130127
Original file line number Diff line number Diff line change 11from typing import Any , Type , TypeVar , cast
22
3- from torchtyping import TensorType
3+ from jaxtyping import Float
4+ from torch import Tensor
45from transformers import AutoTokenizer , PreTrainedTokenizer , PreTrainedTokenizerFast
56
67
@@ -26,7 +27,7 @@ def load_tokenized_data(
2627 )
2728 tokens_ds = tokens_ds .shuffle (seed )
2829
29- tokens = cast (TensorType [ "batch" , "seq" ], tokens_ds ["input_ids" ])
30+ tokens = cast (Float [ Tensor , "batch seq" ], tokens_ds ["input_ids" ])
3031
3132 return tokens
3233
You can’t perform that action at this time.
0 commit comments