Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 11 additions & 17 deletions simplexity/generative_processes/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,14 +807,9 @@ def _build_prefix_vocab_maps(n_components: int, v: int, n_shared: int, n_unique:

C0 gets [0..V-1]. Ci>0 gets shared [0..n_shared-1] + unique tokens above V.
"""
vocab_maps: list[list[int]] = []
for i in range(n_components):
if i == 0:
vocab_maps.append(list(range(v)))
else:
unique_start = v + (i - 1) * n_unique
vocab_maps.append(list(range(n_shared)) + list(range(unique_start, unique_start + n_unique)))
return vocab_maps
return [list(range(v))] + [
list(range(n_shared)) + list(range(v + i * n_unique, v + (i + 1) * n_unique)) for i in range(n_components - 1)
]


def _build_sliding_vocab_maps(n_components: int, v: int, n_unique: int) -> list[list[int]]:
Expand All @@ -826,17 +821,15 @@ def _build_sliding_vocab_maps(n_components: int, v: int, n_unique: int) -> list[
return [list(range(i * offset, i * offset + v)) for i in range(n_components)]


def _build_random_vocab_maps(n_components: int, v: int, n_shared: int, n_unique: int, seed: int) -> list[list[int]]:
def _build_random_vocab_maps(n_components: int, v: int, n_unique: int, seed: int) -> list[list[int]]:
"""Build vocab maps by having each component randomly sample V tokens from the global pool.

The global vocab size is the same as in prefix mode (V + (n_components - 1) * n_unique),
and each component independently samples V tokens without replacement.
The global vocab size is the same as in prefix mode:
V + (n_components - 1) * n_unique.
"""
prefix_maps = _build_prefix_vocab_maps(n_components, v, n_shared, n_unique)
global_vocab_size = max(max(vm) for vm in prefix_maps) + 1
global_vocab_size = v + (n_components - 1) * n_unique
rng = random.Random(seed)
global_tokens = list(range(global_vocab_size))
return [sorted(rng.sample(global_tokens, v)) for _ in range(n_components)]
return [sorted(rng.sample(range(global_vocab_size), v)) for _ in range(n_components)]


def build_nonergodic_partial_overlap(
Expand Down Expand Up @@ -884,8 +877,9 @@ def build_nonergodic_partial_overlap(
elif mode == "sliding":
vocab_maps = _build_sliding_vocab_maps(n_components, v, n_unique)
elif mode == "random":
assert seed is not None
vocab_maps = _build_random_vocab_maps(n_components, v, n_shared, n_unique, seed)
if seed is None:
raise ValueError("seed is required when mode='random'")
vocab_maps = _build_random_vocab_maps(n_components, v, n_unique, seed)
else:
raise ValueError(f"Unknown mode '{mode}'. Must be 'prefix', 'sliding', or 'random'.")

Expand Down
Loading
Loading