21 changes: 17 additions & 4 deletions gemma/gm/text/_sampler.py
@@ -407,15 +407,28 @@ def _tokenize_prompts(
       prompt: str | Sequence[str],
       *,
       add_bos: bool,
-      pad_length: int | None = None,
+      pad_length: int | tuple[int, ...] | None = None,
   ) -> Float['B L']:
     """Encode the prompts."""
     prompt = _normalize_prompt(prompt)
     tokens = [self.tokenizer.encode(p, add_bos=add_bos) for p in prompt]
 
-    # Notice that if pad_length exceeds the maximum length of the prompts,
-    # an error will be raised by the `.pad` function below.
-    max_prompt_len = pad_length or max(len(t) for t in tokens)
+    # Calculate the target prompt length, handling `pad_length` buckets.
+    actual_max = max(len(t) for t in tokens)
+    if pad_length is None:
+      max_prompt_len = actual_max
+    elif isinstance(pad_length, tuple):
+      # Tuple of buckets: sort so we pick the smallest bucket that fits.
+      for bucket_size in sorted(pad_length):
+        if actual_max <= bucket_size:
+          max_prompt_len = bucket_size
+          break
+      else:
+        # No bucket fits, so fall back to the actual max.
+        max_prompt_len = actual_max
+    else:
+      # `pad_length` is a plain int.
+      max_prompt_len = pad_length
     # In multi-host, each host reads different data, so sync to the max
     # length across all hosts.
     max_prompt_len = _max_across_hosts(max_prompt_len)
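
To make the new behavior concrete, here is a minimal standalone sketch of the bucket-selection logic this diff introduces. The `resolve_pad_length` helper name and signature are illustrative, not part of the gemma API. Padding to a small set of fixed bucket sizes is presumably intended to bound the number of distinct input shapes, and hence JIT recompilations, across prompts of varying length.

def resolve_pad_length(
    token_lengths: list[int],
    pad_length: int | tuple[int, ...] | None,
) -> int:
  # Hypothetical helper mirroring the PR's logic, for illustration only.
  actual_max = max(token_lengths)
  if pad_length is None:
    return actual_max
  if isinstance(pad_length, tuple):
    # Pick the smallest bucket that fits the longest prompt.
    for bucket_size in sorted(pad_length):
      if actual_max <= bucket_size:
        return bucket_size
    return actual_max  # No bucket fits: fall back to the actual max.
  return pad_length  # Plain int: pad every batch to this length.

assert resolve_pad_length([5, 9], pad_length=(8, 16, 32)) == 16
assert resolve_pad_length([40], pad_length=(8, 16, 32)) == 40
assert resolve_pad_length([5, 9], pad_length=None) == 9
assert resolve_pad_length([5, 9], pad_length=128) == 128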