@@ -176,8 +176,7 @@ def forward(
176176 return self .func .forward (dfk , ifk , self .y , self .z )
177177
178178
179- class MixedDim (KinematicElement ):
180- # TODO true mixed dim that works with any element?
179+ class MixedDimKinematic (KinematicElement ):
181180 def __init__ (self , module_2d : nn .Module , module_3d : nn .Module ):
182181 super ().__init__ ()
183182 self .module_2d = module_2d
@@ -195,7 +194,7 @@ def __init__(self, offset: Float[torch.Tensor, ""] | float | int):
195194 super ().__init__ ()
196195 translate_2d = Translate2D (x = offset )
197196 translate_3d = Translate3D (x = offset )
198- self .mixed_dim = MixedDim (translate_2d , translate_3d )
197+ self .mixed_dim = MixedDimKinematic (translate_2d , translate_3d )
199198
200199 def forward (self , dfk : HomMatrix , ifk : HomMatrix ) -> tuple [HomMatrix , HomMatrix ]:
201200 return self .mixed_dim (dfk , ifk )
@@ -206,7 +205,7 @@ def __init__(
206205 self , angles : tuple [float | int , float | int ] | Float [torch .Tensor , "2" ]
207206 ):
208207 super ().__init__ ()
209- self .mixed_dim = MixedDim (Rotate2D (angles [0 ]), Rotate3D (angles [1 ], angles [0 ]))
208+ self .mixed_dim = MixedDimKinematic (Rotate2D (angles [0 ]), Rotate3D (angles [1 ], angles [0 ]))
210209
211210 def forward (self , dfk : HomMatrix , ifk : HomMatrix ) -> tuple [HomMatrix , HomMatrix ]:
212211 return self .mixed_dim (dfk , ifk )
@@ -220,7 +219,7 @@ def __init__(
220219 z : Float [torch .Tensor , "" ] | float | int = 0.0 ,
221220 ):
222221 super ().__init__ ()
223- self .mixed_dim = MixedDim (Translate2D (x , y ), Translate3D (x , y , z ))
222+ self .mixed_dim = MixedDimKinematic (Translate2D (x , y ), Translate3D (x , y , z ))
224223
225224 def forward (self , dfk : HomMatrix , ifk : HomMatrix ) -> tuple [HomMatrix , HomMatrix ]:
226225 return self .mixed_dim (dfk , ifk )
0 commit comments