def update_mp_params(child):
    """Shard a module's parallelism-related size attributes across ``mp_size`` ranks.

    Tensor parallelism splits attention heads and hidden widths over
    ``mp_size`` partitions, so every head-count / dimension attribute the
    module exposes must be divided by ``mp_size``.  The ``replaced`` flag
    makes the update idempotent: a module visited twice is not divided twice.

    Args:
        child: candidate sub-module; its matching attributes are mutated
            in place.

    Raises:
        AssertionError: if a present attribute is not evenly divisible by
            ``mp_size``.
    """
    # Guard against double-sharding when the traversal revisits a module.
    # (Fixed: original compared `== True` instead of testing truthiness.)
    if getattr(child, "replaced", False):
        return
    # NOTE(review): `mp_size` is a free variable captured from the enclosing
    # scope — presumed to be the tensor-parallel world size; confirm.
    for param in [
            "n_heads", "inner_dim", "num_heads", "num_kv", "num_attention_heads",
            "num_attn_heads", "all_head_size", "embed_dim", "hidden_size"
    ]:
        if hasattr(child, param):
            param_val = getattr(child, param)
            assert param_val % mp_size == 0, f"{param} ({param_val}) must be divisible by mp_size ({mp_size})"
            setattr(child, param, param_val // mp_size)
    # Tag the module so later traversals skip it.
    setattr(child, "replaced", True)
conv_linear_layer = False
# Dispatch `child` to a linear-replacement policy, or recurse.
if child.__class__ in linear_policies:
    # Exact class match: O(1) dict lookup covers the common case.
    setattr(r_module, name,
            linear_policies[child.__class__](child, prev_name + '.' + name,
                                             conv_linear_layer))
else:
    # Added for falcon model support.
    # Note: isinstance accounts for class inheritance, child.__class__ does
    # not, so subclasses of a policy's layer class are matched here.
    # Single pass (fixed: original scanned the policies twice — once via
    # any(), then again to find the same key); first match wins, as before.
    key = next((lp for lp in linear_policies if isinstance(child, lp)), None)
    if key is not None:
        setattr(r_module, name,
                linear_policies[key](child, prev_name + '.' + name,
                                     conv_linear_layer))
    else:
        # No policy applies: shard this module's size attributes and
        # keep walking its children.
        update_mp_params(child)
        _replace_module(child, name, class_name)
0 commit comments