[muP] Rework #1087
base: main
Changes from 31 commits
The first file in the diff reworks the init-method factories, dropping the optional `mup` dependency and the `use_mup`/`mup_init_scale` plumbing in favor of a single `mup_m_width` multiplier:
```diff
@@ -16,41 +16,22 @@
 import torch
 
-try:
-    import mup
-except ImportError:
-    pass
-
 
-def init_method_normal(sigma, use_mup_outer=False, mup_init_scale=1.0):
+def init_method_normal(sigma):
     """Init method based on N(0, sigma)."""
 
-    def init_(tensor, use_mup=use_mup_outer):
-        if use_mup:
-            mup.init.normal_(tensor, mean=0.0, std=sigma)
-            with torch.no_grad():
-                tensor.mul_(mup_init_scale)
-            return tensor
-        else:
-            return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
+    def init_(tensor):
+        return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
 
     return init_
 
 
-def scaled_init_method_normal(
-    sigma, num_layers, use_mup_outer=False, mup_init_scale=1.0
-):
+def scaled_init_method_normal(sigma, num_layers):
     """Init method based on N(0, sigma/sqrt(2*num_layers)."""
     std = sigma / math.sqrt(2.0 * num_layers)
 
-    def init_(tensor, use_mup=use_mup_outer):
-        if use_mup:
-            mup.init.normal_(tensor, mean=0.0, std=std)
-            with torch.no_grad():
-                tensor.mul_(mup_init_scale)
-            return tensor
-        else:
-            return torch.nn.init.normal_(tensor, mean=0.0, std=std)
+    def init_(tensor):
+        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
 
     return init_
```
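For context, a minimal runnable sketch (not part of the diff) of how the reworked factories are meant to be driven: the closures above are now mup-free, and the muP width correction for the normal inits happens once, upstream, where `get_init_methods` shrinks `sigma` (see the last hunk of this file). `m_width` and `base_std` below are hypothetical values:

```python
import math
import torch

# Sketch only: the closure is mup-free; the muP width correction is
# folded into sigma before the factory is called (as get_init_methods
# does with sigma=args.init_method_std/math.sqrt(args.mup_m_width)).
def init_method_normal(sigma):
    """Init method based on N(0, sigma)."""

    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)

    return init_


m_width = 4.0    # hypothetical width multiplier (width / base width)
base_std = 0.02  # hypothetical base init std
init_fn = init_method_normal(base_std / math.sqrt(m_width))

w = init_fn(torch.empty(2048, 2048))
print(f"{w.std().item():.4f}")  # ~0.0100, i.e. base_std / sqrt(m_width)
```

Dividing `sigma` by `sqrt(m_width)` makes the weight variance scale as `1/m_width`, the fan-in-proportional correction muP prescribes for hidden weights.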
```diff
@@ -87,12 +68,12 @@ def _orthogonal(tensor, gain=1):
     return tensor
 
 
-def orthogonal_init_method(n_layers=1, use_mup=False, mup_init_scale=1.0):
+def orthogonal_init_method(n_layers=1, mup_m_width=1.0):
     """Fills the input Tensor with a (semi) orthogonal matrix, as described in
     Exact solutions to the nonlinear dynamics of learning in deep linear neural networks - Saxe, A. et al. (2013)
     Optionally scaling by number of layers possible, as introduced in OBST - Nestler et. al. (2021, to be released)"""
 
-    if use_mup:
+    if mup_m_width != 1:
         raise ValueError(
             "Orthogonal init needs to be patched to support mup. Disable mup or use a different init method to avoid this error"
         )
```
```diff
@@ -103,105 +84,91 @@ def init_(tensor):
     return init_
 
 
-def xavier_uniform_init_method(use_mup_outer=False, mup_init_scale=1.0):
+def xavier_uniform_init_method(mup_m_width=1.0):
     """Fills the input Tensor with values according to the method described in Understanding the difficulty of
     training deep feedforward neural networks - Glorot, X. & Bengio, Y. (2010), using a uniform distribution."""
 
-    def init_(tensor, use_mup=use_mup_outer):
-        if use_mup:
-            mup.init.xavier_uniform_(tensor)
+    def init_(tensor, mup_m_width=mup_m_width):
+        init_weight = torch.nn.init.xavier_uniform_(tensor)
+        if mup_m_width != 1:
             with torch.no_grad():
-                tensor.mul_(mup_init_scale)
-            return tensor
-        else:
-            return torch.nn.init.xavier_uniform_(tensor)
+                init_weight.div_(mup_m_width)
+        return init_weight
 
     return init_
 
 
-def xavier_normal_init_method(use_mup_outer=False, mup_init_scale=1.0):
+def xavier_normal_init_method(mup_m_width=1.0):
     """Fills the input Tensor with values according to the method described in Understanding the difficulty of
     training deep feedforward neural networks - Glorot, X. & Bengio, Y. (2010), using a normal distribution."""
 
-    def init_(tensor, use_mup=use_mup_outer):
-        if use_mup:
-            mup.init.xavier_normal_(tensor)
+    def init_(tensor, mup_m_width=mup_m_width):
+        init_weight = torch.nn.init.xavier_normal_(tensor)
+        if mup_m_width != 1:
             with torch.no_grad():
-                tensor.mul_(mup_init_scale)
-            return tensor
-        else:
-            return torch.nn.init.xavier_normal_(tensor)
+                init_weight.div_(mup_m_width)
+        return init_weight
 
     return init_
 
 
-def small_init_init_method(dim, use_mup_outer=False, mup_init_scale=1.0):
+def small_init_init_method(dim, mup_m_width=1.0):
     """Fills the input Tensor with values according to the method described in Transformers without Tears: Improving
     the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2010), using a normal distribution."""
     std = math.sqrt(2 / (5 * dim))
 
-    def init_(tensor, use_mup=use_mup_outer):
-        if use_mup:
-            mup.init.normal_(tensor, mean=0.0, std=std)
+    def init_(tensor, mup_m_width=mup_m_width):
+        init_weight = torch.nn.init.normal_(tensor, mean=0.0, std=std)
+        if mup_m_width != 1:
             with torch.no_grad():
-                tensor.mul_(mup_init_scale)
-            return tensor
-        else:
-            return torch.nn.init.normal_(tensor, mean=0.0, std=std)
+                init_weight.div_(mup_m_width)
+        return init_weight
 
     return init_
 
 
-def wang_init_method(n_layers, dim, use_mup_outer=False, mup_init_scale=1.0):
+def wang_init_method(n_layers, dim, mup_m_width=1.0):
     std = 2 / n_layers / math.sqrt(dim)
 
-    def init_(tensor, use_mup=use_mup_outer):
-        if use_mup:
-            mup.init.normal_(tensor, mean=0.0, std=std)
+    def init_(tensor, mup_m_width=mup_m_width):
+        init_weight = torch.nn.init.normal_(tensor, mean=0.0, std=std)
+        if mup_m_width != 1:
             with torch.no_grad():
-                tensor.mul_(mup_init_scale)
-            return tensor
-        else:
-            return torch.nn.init.normal_(tensor, mean=0.0, std=std)
+                init_weight.div_(mup_m_width)
+        return init_weight
 
     return init_
 
 
 def get_init_methods(args):
 
-    if args.use_mup:
-        try:
-            import mup
-        except ModuleNotFoundError:
-            print("Please install mup https://github.com/microsoft/mup")
-            raise Exception
-
     def _get(name):
         if name == "normal":
             return init_method_normal(
-                args.init_method_std, args.use_mup, args.mup_init_scale
+                sigma=args.init_method_std/math.sqrt(args.mup_m_width)
             )
         elif name == "scaled_normal":
             return scaled_init_method_normal(
-                args.init_method_std, args.num_layers, args.use_mup, args.mup_init_scale
+                sigma=args.init_method_std/math.sqrt(args.mup_m_width),
+                num_layers=args.num_layers
             )
         elif name == "orthogonal":
-            return orthogonal_init_method(args.use_mup, args.mup_init_scale)
+            return orthogonal_init_method(args.mup_m_width)
         elif name == "scaled_orthogonal":
             return orthogonal_init_method(
-                args.num_layers, args.use_mup, args.mup_init_scale
+                args.num_layers, args.mup_m_width
             )
         elif name == "xavier_uniform":
-            return xavier_uniform_init_method(args.use_mup, args.mup_init_scale)
+            return xavier_uniform_init_method(args.mup_m_width)
         elif name == "xavier_normal":
-            return xavier_normal_init_method(args.use_mup, args.mup_init_scale)
+            return xavier_normal_init_method(args.mup_m_width)
         elif name == "wang_init":
             return wang_init_method(
-                args.num_layers, args.hidden_size, args.use_mup, args.mup_init_scale
+                args.num_layers, args.hidden_size, args.mup_m_width
             )
         elif name == "small_init":
             return small_init_init_method(
-                args.hidden_size, args.use_mup, args.mup_init_scale
+                args.hidden_size, args.mup_m_width
             )
         else:
             raise NotImplementedError(f"Unknown init method {name}")
```
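Note the two different corrections in this hunk: the Xavier, small-init, and wang factories divide the drawn weights in place by the full `mup_m_width`, while the normal inits (wired up in `get_init_methods`) shrink `sigma` by only `sqrt(args.mup_m_width)`. A self-contained sketch of the new closure pattern, with hypothetical `dim` and `mup_m_width` values:

```python
import math
import torch

# Standalone copy of the closure pattern from the diff above: draw from
# the base distribution, then divide in place by the width multiplier
# when muP is active (mup_m_width != 1).
def small_init_init_method(dim, mup_m_width=1.0):
    std = math.sqrt(2 / (5 * dim))

    def init_(tensor, mup_m_width=mup_m_width):
        init_weight = torch.nn.init.normal_(tensor, mean=0.0, std=std)
        if mup_m_width != 1:
            with torch.no_grad():
                init_weight.div_(mup_m_width)
        return init_weight

    return init_


init_fn = small_init_init_method(dim=1024, mup_m_width=4.0)  # hypothetical
w = init_fn(torch.empty(1024, 1024))
print(f"{w.std().item():.5f}")  # ~ sqrt(2 / (5 * 1024)) / 4, not / sqrt(4)
```

Which correction a tensor should receive under muP depends on whether it is treated as hidden-like or output-like; the sketch illustrates only the mechanics, not the derivation.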
The second file in the diff covers the attention score scaling and the LM-logit path:
```diff
@@ -306,13 +306,13 @@ def __init__(
         )
 
         coeff = None
-        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
-        if self.apply_query_key_layer_scaling:
-            coeff = max(1, self.layer_number)
-            self.norm_factor *= coeff
+        if neox_args.use_mup:
+            self.norm_factor = self.hidden_size_per_attention_head
+        else:
+            self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
+            if self.apply_query_key_layer_scaling:
+                coeff = max(1, self.layer_number)
+                self.norm_factor *= coeff
 
         self.rpe = rpe
```
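The attention class divides raw `q·k` scores by `norm_factor`, so this hunk switches muP runs from the standard `1/sqrt(d_head)` softmax scaling to muP's `1/d_head` scaling. An illustrative sketch with hypothetical shapes:

```python
import math
import torch

# Illustrative only: hypothetical single-head scores, no masking/softmax.
d_head = 64
q = torch.randn(8, d_head)
k = torch.randn(8, d_head)

scores_sp = (q @ k.T) / math.sqrt(d_head)  # standard parametrization
scores_mup = (q @ k.T) / d_head            # muP: norm_factor = d_head

print(scores_sp.std().item(), scores_mup.std().item())
```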
```diff
@@ -960,7 +960,7 @@ def forward(self, args):
     return self.norm(args)
 
 
-def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None):
+def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None, args=None):
     """LM logits using word embedding weights."""
     # Parallel logits.
     input_parallel = mpu.copy_to_model_parallel_region(input_)
@@ -971,6 +971,9 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=Non
     else:
         logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias)
 
+    if args is not None and args.use_mup:
+        logits_parallel /= args.mup_m_width
+
     # Gather if needed.
     if parallel_output:
         return logits_parallel
```
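Dividing the unembedding logits by `args.mup_m_width` at forward time mirrors muP's width-dependent output multiplier, keeping the logit scale roughly constant as the model widens. A hypothetical, dependency-free sketch of the effect (the real function runs under model parallelism via `mpu`):

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes; the PR applies this inside parallel_lm_logits.
hidden = torch.randn(4, 512)          # [tokens, hidden]
emb_weight = torch.randn(50304, 512)  # [vocab, hidden] tied embedding
mup_m_width = 4.0                     # hypothetical width multiplier
use_mup = True

logits = F.linear(hidden, emb_weight)  # [tokens, vocab]
if use_mup:
    logits /= mup_m_width  # muP output scaling, as in the hunk above
```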