13 changes: 13 additions & 0 deletions keras_hub/src/models/causal_lm.py
@@ -8,6 +8,7 @@
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.task import Task
from keras_hub.src.samplers.serialization import get as get_sampler
from keras.src.distribution import distribution_lib

try:
import tensorflow as tf
@@ -354,6 +355,16 @@ def preprocess(x):
                x, sequence_length=max_length
            )

        def shard(x):
            distribution = distribution_lib.distribution()
            if distribution is None:
                return x
            result = {}
            for key, value in x.items():
                layout = distribution.get_data_layout(value.shape)
                result[key] = (
                    distribution_lib.distribute_tensor(value, layout)
                    if layout
                    else value
                )
            return result

        def generate(x):
            return generate_function(x, stop_token_ids=stop_token_ids)

@@ -394,6 +405,8 @@ def postprocess(x):
        if self.preprocessor is not None:
            inputs = [preprocess(x) for x in inputs]

        inputs = [shard(x) for x in inputs]
Contributor review comment (severity: high):
This change introduces important logic for distributed execution, but it lacks corresponding unit tests. The contribution guidelines clearly require tests for all new logic.

Since testing distributed code can be complex, you can add a unit test that verifies the sharding logic is correctly invoked by using mocks. Here's a possible approach:

  1. In your test case, use unittest.mock.patch to patch distribution_lib.distribution where it is looked up, i.e. keras_hub.src.models.causal_lm.distribution_lib.distribution.
  2. Make the mock return a mock Distribution object.
  3. This mock distribution object should have mock methods for get_data_layout and distribute_tensor.
  4. Call the generate() method on a CausalLM instance.
  5. Assert that distribution_lib.distribution() was called, and that get_data_layout and distribute_tensor were called on your mock object for the input tensors.

This will ensure that your sharding code path is exercised and correctly integrated, even without a real distributed environment.

Here is a conceptual example of what the test could look like:

from unittest import mock

# Inside a test method for CausalLM
mock_dist_layout = mock.Mock()
mock_dist = mock.Mock()
mock_dist.get_data_layout.return_value = mock_dist_layout
mock_dist.distribute_tensor.return_value = "sharded_tensor"

with mock.patch(
    "keras_hub.src.models.causal_lm.distribution_lib.distribution",
    return_value=mock_dist,
) as mock_distribution_fn:
    causal_lm = CausalLM.from_preset(...)
    causal_lm.generate("test prompt")

    mock_distribution_fn.assert_called()
    mock_dist.get_data_layout.assert_called()
    mock_dist.distribute_tensor.assert_called_with(
        mock.ANY, mock_dist_layout
    )

Please add tests to validate this new functionality.
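
For readers who want to exercise the mocking mechanics without a Keras installation, the assertion pattern above reduces to plain unittest.mock. The shard stand-in below is invented for illustration and only mirrors the call sequence; it is not the keras-hub implementation:

```python
from unittest import mock

# Configure a mock distribution exactly as in the suggested test.
mock_dist_layout = mock.Mock()
mock_dist = mock.Mock()
mock_dist.get_data_layout.return_value = mock_dist_layout
mock_dist.distribute_tensor.return_value = "sharded_tensor"


def shard(x, distribution):
    # Illustrative stand-in that makes the same calls as the PR's shard().
    result = {}
    for key, value in x.items():
        layout = distribution.get_data_layout(getattr(value, "shape", None))
        result[key] = distribution.distribute_tensor(value, layout)
    return result


out = shard({"token_ids": object()}, mock_dist)

# The same assertions the suggested unit test would make:
mock_dist.get_data_layout.assert_called()
mock_dist.distribute_tensor.assert_called_with(mock.ANY, mock_dist_layout)
assert out["token_ids"] == "sharded_tensor"
```

Because the mock records every call, the assertions verify both that the sharding path ran and that the layout returned by get_data_layout was the one passed on to distribute_tensor.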

References
  1. The contribution guidelines state that testing is a non-negotiable part of every contribution (line 403). Every file containing logic must have a corresponding test file to ensure all core functionality is covered (line 406). (link)


        if strip_prompt:
            outputs = [strip_prompt_function(generate(x), x) for x in inputs]
        else: