
Commit f3c93b0

Add FALCON Auto-TP Support (#3640)
* Add FALCON auto-tp support
* Added a (skipped) unit test and refactored the code to be more readable

Co-authored-by: Michael Wyatt <[email protected]>
1 parent 385e89d commit f3c93b0

3 files changed: +63 −33 lines

deepspeed/module_inject/auto_tp.py

+4 lines

@@ -108,6 +108,10 @@ def tp_parser(model):
                 gem_list = gem_list + [layer]
             elif 'down_proj' in layer:
                 gem_list = gem_list + [layer]
+            elif 'self_attention.dense' in layer and 'falcon' in str(
+                    type(module)):  # this is a hack to get the right linear layer for this model!
+                gem_list = gem_list + [layer]
+
         layer_list = []
         if gem_list != []:
             gem_list = list(set(gem_list))
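As a rough illustration of how the new rule behaves, the sketch below is a standalone toy, not the DeepSpeed code path itself; the layer names and type strings are assumptions. The intent is that Falcon's `self_attention.dense` output projection ends up in `gem_list` (the layers that later get an all-reduce) while other architectures are unaffected:

```python
# Illustrative sketch of the added tp_parser condition. All names below are
# assumptions for the example, not values taken from DeepSpeed or transformers.

def wants_allreduce(layer_name: str, module_type_str: str) -> bool:
    # Mirrors the added rule: match the layer name AND require that the owning
    # module's type string mentions 'falcon', so other models keep the old behavior.
    return 'self_attention.dense' in layer_name and 'falcon' in module_type_str

# str(type(module)) for a model loaded from 'tiiuae/falcon-7b' with
# trust_remote_code=True looks roughly like this (assumed, not verified here):
falcon_type = "<class 'transformers_modules.tiiuae.falcon-7b.modelling_RW.Attention'>"
other_type = "<class 'transformers.models.opt.modeling_opt.OPTAttention'>"

print(wants_allreduce("transformer.h.0.self_attention.dense", falcon_type))  # True
print(wants_allreduce("transformer.h.0.self_attention.dense", other_type))   # False
```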

deepspeed/module_inject/replace_module.py

+18 −32 lines

@@ -426,38 +426,14 @@ def _slice_embedding(child, name, conv_linear_layer):
     def update_mp_params(child):
         if getattr(child, "replaced", False) == True:
             return
-        if hasattr(child, 'n_heads'):
-            assert child.n_heads % mp_size == 0, "n_heads ({}) must be divisible by mp_size ({})".format(
-                child.n_heads, mp_size)
-            child.n_heads = child.n_heads // mp_size
-        if hasattr(child, 'inner_dim'):
-            assert child.inner_dim % mp_size == 0, "inner_dim ({}) must be divisible by mp_size ({})".format(
-                child.inner_dim, mp_size)
-            child.inner_dim = child.inner_dim // mp_size
-        if hasattr(child, 'num_heads'):
-            assert child.num_heads % mp_size == 0, "num_heads ({}) must be divisible by mp_size ({})".format(
-                child.num_heads, mp_size)
-            child.num_heads = child.num_heads // mp_size
-        if hasattr(child, 'num_attention_heads'):
-            assert child.num_attention_heads % mp_size == 0, "num_attention_heads ({}) must be divisible by mp_size ({})".format(
-                child.num_attention_heads, mp_size)
-            child.num_attention_heads = child.num_attention_heads // mp_size
-        if hasattr(child, 'num_attn_heads'):
-            assert child.num_attn_heads % mp_size == 0, "num_attn_heads ({}) must be divisible by mp_size ({})".format(
-                child.num_attn_heads, mp_size)
-            child.num_attn_heads = child.num_attn_heads // mp_size
-        if hasattr(child, 'all_head_size'):
-            assert child.all_head_size % mp_size == 0, "all_head_size ({}) must be divisible by mp_size ({})".format(
-                child.all_head_size, mp_size)
-            child.all_head_size = child.all_head_size // mp_size
-        if hasattr(child, 'embed_dim'):
-            assert child.embed_dim % mp_size == 0, "embed_dim must ({}) be divisible by mp_size ({})".format(
-                child.embed_dim, mp_size)
-            child.embed_dim = child.embed_dim // mp_size
-        if hasattr(child, 'hidden_size'):
-            assert child.hidden_size % mp_size == 0, "hidden_size ({}) must be divisible by mp_size ({})".format(
-                child.hidden_size, mp_size)
-            child.hidden_size = child.hidden_size // mp_size
+        for param in [
+                "n_heads", "inner_dim", "num_heads", "num_kv", "num_attention_heads", "num_attn_heads",
+                "all_head_size", "embed_dim", "hidden_size"
+        ]:
+            if hasattr(child, param):
+                param_val = getattr(child, param)
+                assert param_val % mp_size == 0, f"{param} ({param_val}) must be divisible by mp_size ({mp_size})"
+                setattr(child, param, param_val // mp_size)
         setattr(child, "replaced", True)

     conv_linear_layer = False
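The refactor collapses eight near-identical `if hasattr(...)` blocks into one loop, and the attribute list now also covers `num_kv`, which the per-attribute version did not handle. A minimal standalone sketch of the same logic, assuming a toy `FakeAttention` object and `mp_size = 2` (both made up for illustration):

```python
# Standalone sketch of the refactored update_mp_params loop.
# FakeAttention and mp_size are illustrative assumptions, not DeepSpeed objects.
mp_size = 2

class FakeAttention:
    num_attention_heads = 32
    hidden_size = 4096
    num_kv = 8  # Falcon-style multi-query attribute, now covered by the loop

child = FakeAttention()

for param in ["n_heads", "inner_dim", "num_heads", "num_kv", "num_attention_heads",
              "num_attn_heads", "all_head_size", "embed_dim", "hidden_size"]:
    if hasattr(child, param):
        param_val = getattr(child, param)
        # Each per-rank shard must get a whole number of heads / hidden units.
        assert param_val % mp_size == 0, f"{param} ({param_val}) must be divisible by mp_size ({mp_size})"
        setattr(child, param, param_val // mp_size)

print(child.num_attention_heads, child.hidden_size, child.num_kv)  # 16 2048 4
```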
@@ -495,6 +471,16 @@ def _replace_module(r_module, prev_name='', prev_class_name=''):
         if child.__class__ in linear_policies:
             setattr(r_module, name, linear_policies[child.__class__](child, prev_name + '.' + name,
                                                                      conv_linear_layer))
+        elif any(isinstance(child, lp) for lp in linear_policies):
+            # Added for falcon model support
+            # Note: isinstance will account for class inheritance, child.__class__ does not
+            key = None
+            for lp in linear_policies:
+                if isinstance(child, lp):
+                    key = lp
+                    break
+            assert key is not None
+            setattr(r_module, name, linear_policies[key](child, prev_name + '.' + name, conv_linear_layer))
         else:
             update_mp_params(child)
             _replace_module(child, name, class_name)
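The extra branch matters because `child.__class__ in linear_policies` only matches exact classes, while `isinstance` also matches subclasses; per the in-diff comment, this was added for Falcon, whose remote-code modelling appears to define its own Linear subclass. A toy sketch of the difference (the class name below is made up):

```python
import torch.nn as nn

class CustomLinear(nn.Linear):
    """Stand-in for a model-specific Linear subclass; illustrative, not Falcon's actual class."""
    pass

# Minimal policy table keyed by exact class, as in _replace_module.
linear_policies = {nn.Linear: lambda child, name, conv_linear_layer: child}

child = CustomLinear(4, 4)

print(child.__class__ in linear_policies)                    # False: exact-class lookup misses subclasses
print(any(isinstance(child, lp) for lp in linear_policies))  # True: isinstance accounts for inheritance
```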

tests/unit/inference/test_inference.py

+41 −1 lines

@@ -13,7 +13,7 @@
 from unit.common import DistributedTest
 from packaging import version as pkg_version
 from deepspeed.ops.op_builder import OpBuilder
-from transformers import pipeline
+from transformers import pipeline, AutoTokenizer
 from transformers.models.t5.modeling_t5 import T5Block
 from transformers.models.roberta.modeling_roberta import RobertaLayer
 from huggingface_hub import HfApi
@@ -380,6 +380,46 @@ def test(
         assert assert_fn(bs_output, ds_output)


+@pytest.mark.seq_inference
+@pytest.mark.parametrize("model_w_task", [("tiiuae/falcon-7b", "text-generation")], ids=["falcon"])
+class TestAutoTP(DistributedTest):
+    world_size = 1
+
+    def test(
+        self,
+        model_w_task,
+        query,
+        inf_kwargs,
+        assert_fn,
+    ):
+        # TODO: enable this test for H100 tests
+        pytest.skip("Not enough GPU memory for this on V100 runners")
+        model, task = model_w_task
+        dtype = torch.bfloat16
+        local_rank = int(os.getenv("LOCAL_RANK", "0"))
+
+        # We have to load these large models on CPU with pipeline because not
+        # enough GPU memory
+        tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True)
+        pipe = pipeline(task,
+                        model=model,
+                        tokenizer=tokenizer,
+                        torch_dtype=dtype,
+                        trust_remote_code=True,
+                        device=torch.device("cpu"),
+                        framework="pt")
+        #bs_output = pipe(query, **inf_kwargs)
+
+        pipe.model = deepspeed.init_inference(pipe.model, mp_size=self.world_size, replace_with_kernel_inject=False)
+        # Switch device to GPU so that input tensors are not on CPU
+        pipe.device = torch.device(get_accelerator().device_name(local_rank))
+        ds_output = pipe(query, **inf_kwargs)
+
+        #print(local_rank, "baseline", bs_output)
+        print(local_rank, "deepspeed", ds_output)
+        #assert assert_fn(bs_output, ds_output)
+
+
 @pytest.mark.seq_inference
 @pytest.mark.parametrize(
     "model_w_task, injection_policy",
