Skip to content

Commit 3f3f2f4

Browse files
author
SrGonao
committed
Merge branch 'fix_types' of https://github.com/EleutherAI/delphi into neighbour_latents
2 parents f001b41 + 5b0ea4e commit 3f3f2f4

File tree

2 files changed

+7
-9
lines changed

2 files changed

+7
-9
lines changed

delphi/latents/samplers.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import random
22
from collections import deque
3-
from typing import Literal, cast
4-
5-
from torchtyping import TensorType
3+
from typing import Literal
64

75
from ..config import ExperimentConfig
86
from ..logger import logger
@@ -121,10 +119,9 @@ def train(
121119
selected_examples = []
122120
for quantile in selected_examples_quantiles:
123121
for example in quantile:
124-
example.normalized_activations = cast(
125-
TensorType["seq"],
126-
(example.activations * 10 / max_activation).floor(),
127-
)
122+
example.normalized_activations = (
123+
example.activations * 10 / max_activation
124+
).floor()
128125
selected_examples.extend(quantile)
129126
return selected_examples
130127

delphi/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Any, Type, TypeVar, cast
22

3-
from torchtyping import TensorType
3+
from jaxtyping import Float
4+
from torch import Tensor
45
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
56

67

@@ -26,7 +27,7 @@ def load_tokenized_data(
2627
)
2728
tokens_ds = tokens_ds.shuffle(seed)
2829

29-
tokens = cast(TensorType["batch", "seq"], tokens_ds["input_ids"])
30+
tokens = cast(Float[Tensor, "batch seq"], tokens_ds["input_ids"])
3031

3132
return tokens
3233

0 commit comments

Comments
 (0)