1111# that they have been altered from the originals.
1212
1313"""Base class for analog Modules."""
14+ import warnings
1415
1516from typing import (
1617 Any , Dict , List , Optional , Tuple , NamedTuple , Union ,
@@ -72,14 +73,15 @@ class AnalogModuleBase(Module):
7273 ANALOG_CTX_PREFIX : str = 'analog_ctx_'
7374 ANALOG_SHARED_WEIGHT_PREFIX : str = 'analog_shared_weights_'
7475 ANALOG_STATE_PREFIX : str = 'analog_tile_state_'
76+ ANALOG_OUT_SCALING_ALPHA_PREFIX : str = 'analog_out_scaling_alpha_'
7577
7678 def __init__ (
7779 self ,
7880 in_features : int ,
7981 out_features : int ,
8082 bias : bool ,
8183 realistic_read_write : bool = False ,
82- weight_scaling_omega : float = 0.0 ,
84+ weight_scaling_omega : Optional [ float ] = None ,
8385 mapping : Optional [MappingParameter ] = None ,
8486 ) -> None :
8587 # pylint: disable=super-init-not-called
@@ -93,9 +95,21 @@ def __init__(
9395 self .use_bias = bias
9496 self .digital_bias = bias and mapping .digital_bias
9597 self .analog_bias = bias and not mapping .digital_bias
98+ self .weight_scaling_omega = mapping .weight_scaling_omega if weight_scaling_omega is None \
99+ else weight_scaling_omega
100+ if weight_scaling_omega is not None :
101+ warnings .warn (DeprecationWarning ('\n Setting the weight_scaling_omega through the '
102+ 'layers input parameters will be deprecated in the '
103+ 'future. Please set it through the MappingParameter '
104+ 'of the rpu_config.\n ' ))
105+
106+ self .weight_scaling_omega_columnwise = mapping .weight_scaling_omega_columnwise
107+ self .learn_out_scaling_alpha = mapping .learn_out_scaling_alpha
108+
109+ if self .learn_out_scaling_alpha and self .weight_scaling_omega == 0 :
110+ raise ValueError ('out_scaling_alpha can only be learned if weight_scaling_omega > 0' )
96111
97112 self .realistic_read_write = realistic_read_write
98- self .weight_scaling_omega = weight_scaling_omega
99113 self .in_features = in_features
100114 self .out_features = out_features
101115
@@ -129,6 +143,15 @@ def register_analog_tile(self, tile: 'BaseTile', name: Optional[str] = None) ->
129143 if par_name not in self ._registered_helper_parameter :
130144 self ._registered_helper_parameter .append (par_name )
131145
146+ if self .learn_out_scaling_alpha :
147+ if not isinstance (tile .out_scaling_alpha , Parameter ):
148+ tile .out_scaling_alpha = Parameter (tile .out_scaling_alpha )
149+ par_name = self .ANALOG_OUT_SCALING_ALPHA_PREFIX + str (self ._analog_tile_counter )
150+ self .register_parameter (par_name , tile .out_scaling_alpha )
151+
152+ if par_name not in self ._registered_helper_parameter :
153+ self ._registered_helper_parameter .append (par_name )
154+
132155 self ._analog_tile_counter += 1
133156
134157 def unregister_parameter (self , param_name : str ) -> None :
@@ -235,9 +258,12 @@ def set_weights(
235258 analog_tile = analog_tiles [0 ]
236259
237260 if self .weight_scaling_omega > 0.0 :
238- analog_tile .set_weights_scaled (weight , bias if self .analog_bias else None ,
239- realistic = realistic ,
240- omega = self .weight_scaling_omega )
261+ analog_tile .set_weights_scaled (
262+ weight , bias if self .analog_bias else None ,
263+ realistic = realistic ,
264+ omega = self .weight_scaling_omega ,
265+ weight_scaling_omega_columnwise = self .weight_scaling_omega_columnwise ,
266+ learn_out_scaling_alpha = self .learn_out_scaling_alpha )
241267 else :
242268 analog_tile .set_weights (weight , bias if self .analog_bias else None ,
243269 realistic = realistic )
@@ -283,7 +309,9 @@ def get_weights(
283309
284310 realistic = self .realistic_read_write and not force_exact
285311 if self .weight_scaling_omega > 0.0 :
286- weight , bias = analog_tile .get_weights_scaled (realistic = realistic )
312+ weight , bias = analog_tile .get_weights_scaled (
313+ realistic = realistic ,
314+ weight_scaling_omega_columnwise = self .weight_scaling_omega_columnwise )
287315 else :
288316 weight , bias = analog_tile .get_weights (realistic = realistic )
289317
0 commit comments