diff --git a/jax-inference-offloading/jax_inference_offloading/api/param_mapping.proto b/jax-inference-offloading/jax_inference_offloading/api/param_mapping.proto
index d4a70a589..12a2f342f 100644
--- a/jax-inference-offloading/jax_inference_offloading/api/param_mapping.proto
+++ b/jax-inference-offloading/jax_inference_offloading/api/param_mapping.proto
@@ -82,4 +82,4 @@ message ParamMapping {
 
 message TpModelMappingSpecs {
   repeated ParamMapping mappings = 1;
-}
\ No newline at end of file
+}
diff --git a/jax-inference-offloading/jax_inference_offloading/models/__init__.py b/jax-inference-offloading/jax_inference_offloading/models/__init__.py
index abf2b1652..b90a2dae1 100644
--- a/jax-inference-offloading/jax_inference_offloading/models/__init__.py
+++ b/jax-inference-offloading/jax_inference_offloading/models/__init__.py
@@ -32,11 +32,12 @@ def make_transform(slice=[], transpose=[], reshape=[], replication_axis=None, re
 
 
 def make_mapping(
-    jax_name, vllm_name, vllm_shape, *, transform=None, jax_prefix="model", vllm_prefix="model"
+    jax_name, vllm_name, vllm_shape, *, transform=None, jax_prefix="model", vllm_prefix="model", dtype="bfloat16"
 ):
     result = mapping.ParamMapping()
     result.vllm_param.name = f"{vllm_prefix}.{vllm_name}".lstrip(".")
     result.vllm_param.shape.extend(vllm_shape)
+    result.vllm_param.dtype = dtype
     result.jax_param.name = f"{jax_prefix}.{jax_name}".lstrip(".")
     if transform is not None:
         result.jax_param.transform.CopyFrom(transform)
@@ -62,3 +63,4 @@ def get_named_parameters(nnx_model, prefix="model", *filters):
 
     nnx_state = nnx.state(nnx_model, *filters)
     return flatten_state(nnx_state, prefix=prefix)
+
diff --git a/jax-inference-offloading/jax_inference_offloading/vllm/extension.py b/jax-inference-offloading/jax_inference_offloading/vllm/extension.py
index 61e4d1394..8085fe95e 100644
--- a/jax-inference-offloading/jax_inference_offloading/vllm/extension.py
+++ b/jax-inference-offloading/jax_inference_offloading/vllm/extension.py
@@ -187,7 +187,7 @@ def update_weights(self, mapping_specs: TpModelMappingSpecs):
 
                 logger.debug(f'vLLM TP rank {tp_rank} receiving {param.vllm_param.name} ...')
                 weight = self.transport.gather(
-                    shape, param.vllm_param.dtype or 'bfloat16',
+                    shape, param.vllm_param.dtype,
                     sharding_specs.aux_dim, sharding_specs.aux_parallelism
                 )
                 logger.debug(f'vLLM TP rank {tp_rank} received {param.vllm_param.name} shape {weight.shape}')
@@ -206,7 +206,7 @@ def update_weights(self, mapping_specs: TpModelMappingSpecs):
 
 
                 logger.debug(f"vLLM expecting: {param.vllm_param.name} shape {shape.tolist()} raw specs {param}")
-                weight = self.transport.recv(shape, param.vllm_param.dtype or 'bfloat16')
+                weight = self.transport.recv(shape, param.vllm_param.dtype)
                 self._staged_weights.append((param.vllm_param.name, weight))
 
         # TODO: make it optional
@@ -235,7 +235,7 @@ def update_weights_grouped(self, mapping_specs: TpModelMappingSpecs):
 
                 param_specs.append((
                     shape,
-                    param.vllm_param.dtype or 'bfloat16',
+                    param.vllm_param.dtype,
                     sharding_specs.aux_dim, sharding_specs.aux_parallelism
                 ))
 
@@ -264,7 +264,7 @@ def update_weights_grouped(self, mapping_specs: TpModelMappingSpecs):
                 if sharding_specs.parallelism > 0:
                     shape[sharding_specs.dim] //= sharding_specs.parallelism
 
-                param_specs.append((shape, param.vllm_param.dtype or 'bfloat16'))
+                param_specs.append((shape, param.vllm_param.dtype))
                 param_names.append(param.vllm_param.name)
 
         # Receive all weights in one grouped operation
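For context, the change above threads an explicit dtype through ParamMapping.vllm_param, so the vLLM receive paths no longer need the hard-coded "or 'bfloat16'" fallback. Below is a minimal usage sketch of the updated make_mapping signature; the JAX/vLLM parameter names, shapes, and the float16 dtype are hypothetical illustrations, not values taken from the repository.

# Usage sketch (hypothetical names and shapes) for the updated make_mapping signature.
from jax_inference_offloading.models import make_mapping, make_transform

# Default dtype: equivalent to the previous behavior, where the vLLM worker
# fell back to 'bfloat16' whenever ParamMapping.vllm_param.dtype was unset.
embed = make_mapping(
    "embed_tokens.embedding",   # hypothetical JAX parameter name
    "embed_tokens.weight",      # hypothetical vLLM parameter name
    [151936, 4096],             # hypothetical vocab_size x hidden_size
)
assert embed.vllm_param.dtype == "bfloat16"

# Explicit dtype: now carried in the mapping proto, so the vLLM-side
# transport.gather / transport.recv calls use the requested dtype directly.
q_proj = make_mapping(
    "layers.0.attn.q_proj.kernel",        # hypothetical JAX parameter name
    "layers.0.self_attn.q_proj.weight",   # hypothetical vLLM parameter name
    [4096, 4096],                         # hypothetical weight shape
    transform=make_transform(transpose=[1, 0]),
    dtype="float16",
)
assert q_proj.vllm_param.dtype == "float16"

Keeping "bfloat16" as the keyword default means existing call sites keep their previous behavior, while the receiving side can drop its per-call fallback.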