@@ -59,7 +59,7 @@ def size(self, dim: int) -> int:
59
59
return self .num_rows
60
60
elif dim == 1 :
61
61
return self .num_cols
62
- assert False , "Should not reach here."
62
+ raise AssertionError ( "Should not reach here." )
63
63
64
64
def dim (self ) -> int :
65
65
return self .ndim
@@ -243,7 +243,7 @@ def index_select(self, index: Tensor, dim: int) -> _MultiTensor:
243
243
return self ._row_index_select (idx )
244
244
elif dim == 1 :
245
245
return self ._col_index_select (idx )
246
- assert False , "Should not reach here."
246
+ raise AssertionError ( "Should not reach here." )
247
247
248
248
def _row_index_select (self , index : Tensor ) -> _MultiTensor :
249
249
raise NotImplementedError
@@ -300,7 +300,7 @@ def narrow(self, dim: int, start: int, length: int) -> _MultiTensor:
300
300
return self ._row_narrow (start , length )
301
301
elif dim == 1 :
302
302
return self ._col_narrow (start , length )
303
- assert False , "Should not reach here."
303
+ raise AssertionError ( "Should not reach here." )
304
304
305
305
def _row_narrow (self , start : int , length : int ) -> _MultiTensor :
306
306
raise NotImplementedError
@@ -339,7 +339,7 @@ def select(
339
339
torch .tensor (index , dtype = torch .long , device = self .device ),
340
340
dim = dim ,
341
341
)
342
- assert False , "Should not reach here."
342
+ raise AssertionError ( "Should not reach here." )
343
343
344
344
def _single_index_select (self , index : int , dim : int ) -> _MultiTensor :
345
345
raise NotImplementedError
0 commit comments