mostlyai/sdk/_data/non_context.py (4 additions, 1 deletion)
@@ -85,6 +85,9 @@
TOP_K = None
TOP_P = 0.95
QUOTA_PENALTY_FACTOR = 0.05
FK_MATCHING_PARENT_BATCH_SIZE = 5_000
FK_MATCHING_CHILD_BATCH_SIZE = 5_000


# Supported Encoding Types
FK_MODEL_ENCODING_TYPES = [
@@ -1131,7 +1134,7 @@ def initialize_remaining_capacity(

# Generate children counts using engine with parent data as seed
# The engine will predict the __CHILDREN_COUNT__ column based on parent features
_LOG.info(f"Generating cardinality predictions using engine for {len(parent_data)} parents")
_LOG.info(f"Generating cardinality predictions for {len(parent_data)} parents")

engine.generate(
seed_data=parent_data,
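
The comment in the hunk above describes the cardinality step: the engine is seeded with parent rows and asked to predict a __CHILDREN_COUNT__ value per parent, which then initializes each parent's remaining capacity for FK matching. Below is a minimal, hypothetical sketch of that idea only; predict_children_counts and build_remaining_capacity are illustrative stand-ins, not the SDK's actual functions or the engine.generate API.

import pandas as pd

def predict_children_counts(parent_data: pd.DataFrame) -> pd.Series:
    # stand-in for the engine call that conditions on parent features;
    # a constant is returned here so the sketch stays runnable
    return pd.Series([2] * len(parent_data), index=parent_data.index, name="__CHILDREN_COUNT__")

def build_remaining_capacity(parent_data: pd.DataFrame, pk_col: str) -> dict:
    # each parent key starts with capacity equal to its predicted child count;
    # FK matching would decrement these as children get assigned
    counts = predict_children_counts(parent_data)
    return dict(zip(parent_data[pk_col], counts))

parents = pd.DataFrame({"id": [1, 2, 3], "segment": ["a", "b", "a"]})
print(build_remaining_capacity(parents, pk_col="id"))  # {1: 2, 2: 2, 3: 2}
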
mostlyai/sdk/_local/execution/step_finalize_generation.py (32 additions, 74 deletions)
@@ -26,6 +26,8 @@
from mostlyai.sdk._data.file.table.csv import CsvDataTable
from mostlyai.sdk._data.file.table.parquet import ParquetDataTable
from mostlyai.sdk._data.non_context import (
FK_MATCHING_CHILD_BATCH_SIZE,
FK_MATCHING_PARENT_BATCH_SIZE,
add_context_parent_data,
assign_non_context_fks_randomly,
initialize_remaining_capacity,
@@ -42,10 +44,6 @@

_LOG = logging.getLogger(__name__)

# FK processing constants
FK_MIN_CHILDREN_BATCH_SIZE = 10
FK_PARENT_BATCH_SIZE = 1_000


def execute_step_finalize_generation(
*,
@@ -312,49 +310,15 @@ def process_table_with_random_fk_assignment(
write_batch_outputs(processed_data, table_name, chunk_idx, pqt_path, csv_path)


def calculate_optimal_child_batch_size_for_relation(
parent_key_count: int,
children_row_count: int,
parent_batch_size: int,
relation_name: str,
) -> int:
"""Calculate optimal child batch size for a specific FK relationship."""
num_parent_batches = max(1, math.ceil(parent_key_count / parent_batch_size))

# ideal batch size for full parent utilization
ideal_batch_size = children_row_count // num_parent_batches

# apply minimum batch size constraint
optimal_batch_size = max(ideal_batch_size, FK_MIN_CHILDREN_BATCH_SIZE)

# log utilization metrics
num_child_batches = children_row_count // optimal_batch_size
parent_utilization = min(num_child_batches / num_parent_batches * 100, 100)

_LOG.info(
f"[{relation_name}] Batch size optimization | "
f"total_children: {children_row_count} | "
f"parent_size: {parent_key_count} | "
f"parent_batch_size: {parent_batch_size} | "
f"parent_batches: {num_parent_batches} | "
f"ideal_child_batch: {ideal_batch_size} | "
f"optimal_child_batch: {optimal_batch_size} | "
f"parent_utilization: {parent_utilization:.1f}%"
)

return optimal_batch_size


def process_table_with_fk_models(
*,
table_name: str,
schema: Schema,
pqt_path: Path,
csv_path: Path | None,
parent_batch_size: int = FK_PARENT_BATCH_SIZE,
job_workspace_dir: Path,
) -> None:
"""Process table with ML model-based FK assignment using logical child batches."""
"""Process table with ML model-based FK assignment using fixed batch sizes."""

fk_models_workspace_dir = job_workspace_dir / "FKModelsStore" / table_name
non_ctx_relations = [rel for rel in schema.non_context_relations if rel.child.table == table_name]
@@ -374,21 +338,6 @@ def process_table_with_fk_models(
do_coerce_dtypes=True,
)

# Calculate optimal batch size for each relationship
relation_batch_sizes = {}
for relation in non_ctx_relations:
parent_table_name = relation.parent.table
parent_key_count = len(parent_keys_cache[parent_table_name])
relation_name = f"{relation.child.table}.{relation.child.column}->{parent_table_name}"

optimal_batch_size = calculate_optimal_child_batch_size_for_relation(
parent_key_count=parent_key_count,
children_row_count=children_table.row_count,
parent_batch_size=parent_batch_size,
relation_name=relation_name,
)
relation_batch_sizes[relation] = optimal_batch_size

# Initialize remaining capacity for all relations
# At this point, both FK models and cardinality models are guaranteed to exist
# (checked by are_fk_models_available)
@@ -401,8 +350,7 @@
parent_keys_df = parent_keys_cache[parent_table_name]
parent_table = parent_tables[parent_table_name]

_LOG.info(f"Using Engine-based Cardinality Model for {relation.child.table}.{relation.child.column}")
# Use Engine-based Cardinality Model to predict capacities
_LOG.info(f"Using Cardinality Model for {relation.child.table}.{relation.child.column}")
parent_data = parent_table.read_data(
where={pk_col: parent_keys_df[pk_col].tolist()},
do_coerce_dtypes=True,
@@ -421,35 +369,45 @@
parent_table_name = relation.parent.table
parent_table = parent_tables[parent_table_name]
parent_pk = relation.parent.column
optimal_batch_size = relation_batch_sizes[relation]
relation_name = f"{relation.child.table}.{relation.child.column}->{parent_table_name}"
parent_keys_df = parent_keys_cache[parent_table_name]
parent_size = len(parent_keys_df)
child_size = len(chunk_data)
parent_batch_size = min(FK_MATCHING_PARENT_BATCH_SIZE, parent_size)
child_batch_size = min(FK_MATCHING_CHILD_BATCH_SIZE, child_size)
num_child_batches = math.ceil(child_size / child_batch_size)
total_parent_samples_needed = num_child_batches * parent_batch_size

_LOG.info(f" Processing relationship {relation_name} with batch size {optimal_batch_size}")
_LOG.info(
f"Processing relationship {relation.child.table}.{relation.child.column}->{parent_table_name} with "
f"parent_batch_size={parent_batch_size}, child_batch_size={child_batch_size}, "
f"num_child_batches={num_child_batches}"
)

parent_keys_df = parent_keys_cache[parent_table_name]
# sample enough parent data to cover all child batches in this chunk
sampled_parent_keys = parent_keys_df.sample(
n=total_parent_samples_needed, replace=total_parent_samples_needed > parent_size
)[parent_pk].tolist()
parent_data_for_chunk = parent_table.read_data(
where={parent_pk: sampled_parent_keys},
columns=parent_table.columns,
do_coerce_dtypes=True,
)

processed_batches = []

for batch_start in range(0, len(chunk_data), optimal_batch_size):
batch_end = min(batch_start + optimal_batch_size, len(chunk_data))
for batch_idx, batch_start in enumerate(range(0, child_size, child_batch_size)):
batch_end = min(batch_start + child_batch_size, child_size)
batch_data = chunk_data.iloc[batch_start:batch_end].copy()

sampled_parent_keys = parent_keys_df.sample(
n=parent_batch_size, replace=len(parent_keys_df) < parent_batch_size
)[parent_pk].tolist()

parent_data = parent_table.read_data(
where={parent_pk: sampled_parent_keys},
columns=parent_table.columns,
do_coerce_dtypes=True,
)

batch_data = add_context_parent_data(
tgt_data=batch_data,
tgt_table=children_table,
schema=schema,
)

# slice the appropriate parent batch from the pre-fetched data
parent_slice_start = batch_idx * parent_batch_size
parent_slice_end = min(parent_slice_start + parent_batch_size, len(parent_data_for_chunk))
parent_data = parent_data_for_chunk.iloc[parent_slice_start:parent_slice_end]

assert relation in remaining_capacity
processed_batch = match_non_context(
fk_models_workspace_dir=fk_models_workspace_dir,
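
The new code in the last hunk replaces the per-relationship batch-size optimization with fixed caps (FK_MATCHING_PARENT_BATCH_SIZE, FK_MATCHING_CHILD_BATCH_SIZE) and a single parent sampling pass per chunk, slicing one contiguous parent batch per child batch. The sketch below reproduces only that batching arithmetic in isolation; plan_batches and the toy DataFrames are hypothetical stand-ins, not part of the SDK.

import math
import pandas as pd

FK_MATCHING_PARENT_BATCH_SIZE = 5_000
FK_MATCHING_CHILD_BATCH_SIZE = 5_000

def plan_batches(parent_keys: pd.DataFrame, children: pd.DataFrame) -> list[tuple[pd.DataFrame, pd.DataFrame]]:
    parent_size, child_size = len(parent_keys), len(children)
    parent_batch_size = min(FK_MATCHING_PARENT_BATCH_SIZE, parent_size)
    child_batch_size = min(FK_MATCHING_CHILD_BATCH_SIZE, child_size)
    num_child_batches = math.ceil(child_size / child_batch_size)
    total_parent_samples_needed = num_child_batches * parent_batch_size

    # one sampling pass per chunk, oversampling with replacement only when necessary
    sampled_parents = parent_keys.sample(
        n=total_parent_samples_needed, replace=total_parent_samples_needed > parent_size
    ).reset_index(drop=True)

    batches = []
    for batch_idx, batch_start in enumerate(range(0, child_size, child_batch_size)):
        child_batch = children.iloc[batch_start : batch_start + child_batch_size]
        # pair each child batch with its own contiguous slice of the pre-fetched parent sample
        parent_slice_start = batch_idx * parent_batch_size
        parent_batch = sampled_parents.iloc[parent_slice_start : parent_slice_start + parent_batch_size]
        batches.append((parent_batch, child_batch))
    return batches

# usage: 12_000 children against 3_000 parents -> 3 child batches, each with 3_000 sampled parents
parents = pd.DataFrame({"id": range(3_000)})
children = pd.DataFrame({"fk": [None] * 12_000})
print([(len(p), len(c)) for p, c in plan_batches(parents, children)])  # [(3000, 5000), (3000, 5000), (3000, 2000)]
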