Skip to content

change HPU warmup logic: seq length should be with exponential growth #3217

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 22 additions & 11 deletions backends/gaudi/server/text_generation_server/models/causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1))
BATCH_SIZE_EXPONENT_BASE = int(os.environ.get("BATCH_SIZE_EXPONENT_BASE", 2))
SEQ_LEN_EXPONENT_BASE = int(os.environ.get("SEQ_LEN_EXPONENT_BASE", 2))
MAX_BATCH_SIZE = (
int(os.environ.get("MAX_BATCH_SIZE"))
if os.environ.get("MAX_BATCH_SIZE") is not None
Expand All @@ -71,8 +72,21 @@ def torch_compile_for_eager(func):
)


def round_up_seq(number, k):
return (number + k - 1) // k * k
def round_up_seq(number, k, base):
exponent = math.ceil(math.log(number / k, base))
return k * (base**exponent)


def iterate_powers_of_base(max_value, start, base):
current = start
result = []
assert (
max_value >= start
), f"max_value {max_value} must be greater than start {start}"
while current < max_value:
result.append(current)
current *= base
return result


def round_up_batch(number):
Expand Down Expand Up @@ -575,7 +589,9 @@ def from_pb(
assert (
PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length
), "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
rounded_seq_len = round_up_seq(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
rounded_seq_len = round_up_seq(
input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF, SEQ_LEN_EXPONENT_BASE
)
if rounded_seq_len <= max_input_length:
bucket_size = rounded_seq_len - 1
else:
Expand Down Expand Up @@ -1345,14 +1361,9 @@ def warmup(
max_exp + 1,
)
]
prefill_seqlen_list = [
seq
for seq in range(
PAD_SEQUENCE_TO_MULTIPLE_OF,
max_input_tokens,
PAD_SEQUENCE_TO_MULTIPLE_OF,
)
]
prefill_seqlen_list = iterate_powers_of_base(
max_input_tokens, PAD_SEQUENCE_TO_MULTIPLE_OF, SEQ_LEN_EXPONENT_BASE
)
prefill_seqlen_list.append(max_input_tokens)
prefill_batch_size_list.sort(reverse=True)
prefill_seqlen_list.sort(reverse=True)
Expand Down
Loading