@@ -17,9 +17,11 @@ class Normalization(nn.Module):
1717
1818 Args:
1919 method: Normalization method to use. Options are: "layer", "batch",
20- "rmsnorm", "unitball", "unitball-detach", "none". "unitball" is
21- (x / ||x||), "unitball-detach" is (x / ||x||.detach()). "none" is a
22- no-op and the rest are standard LayerNorm, BatchNorm, RMSNorm.
20+ "rmsnorm", "unitball", "unitball-detach", "hypersphere", "none".
21+ "unitball" is (x / ||x||), "unitball-detach" is (x / ||x||.detach()),
22+ "hypersphere" is (x / ||x|| * sqrt(d)), rescaling x onto the
23+ sphere of radius sqrt(d). "none" is a no-op and the rest are standard
24+ LayerNorm, BatchNorm, RMSNorm.
2325 d_model: Expected dimension of the input to normalize (scalar). Operates
2426 on the last dimensions of the input sequence.
2527 """
@@ -43,6 +45,8 @@ def __init__(self, method: Optional[str], d_model: int):
4345 torch .linalg .vector_norm (x , ord = 2 , dim = - 1 , keepdim = True ) + 1e-5
4446 ).detach ()
4547 )
48+ elif method == "hypersphere" :
49+ self .norm = lambda x : F .normalize (x , dim = - 1 ) * (x .shape [- 1 ] ** 0.5 )
4650 elif method == "rmsnorm" :
4751 self .norm = _RMSNorm (size = d_model )
4852 elif method == "simnorm" :
0 commit comments