@@ -117,33 +117,40 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for ConstantNode {
117
117
118
118
fn field_init ( & self ) -> Option < TokenStream > {
119
119
match & self . value {
120
- ConstantValue :: Tensor ( tensor_type, _ ) => {
120
+ ConstantValue :: Tensor ( tensor_type, data ) => {
121
121
let ty = tensor_type. ty ( ) ;
122
122
let name = Ident :: new ( self . name . as_ref ( ) , Span :: call_site ( ) ) ;
123
- let shape = tensor_type. clone ( ) . shape . unwrap ( ) . to_tokens ( ) ;
124
- let dim = tensor_type. rank . to_tokens ( ) ;
123
+
124
+ assert_eq ! (
125
+ data. shape. len( ) ,
126
+ tensor_type. rank,
127
+ "Tensor data shape does not match tensor type rank"
128
+ ) ;
129
+
130
+ let shape = data. shape . to_tokens ( ) ;
131
+ let rank = tensor_type. rank . to_tokens ( ) ;
125
132
126
133
match tensor_type. kind {
127
134
crate :: burn:: TensorKind :: Int => Some ( quote ! {
128
135
let #name: burn:: module:: Param <#ty> = burn:: module:: Param :: uninitialized(
129
136
burn:: module:: ParamId :: new( ) ,
130
- move |device, _require_grad| Tensor :: <B , #dim , Int >:: zeros( #shape, & device) ,
137
+ move |device, _require_grad| Tensor :: <B , #rank , Int >:: zeros( #shape, & device) ,
131
138
device. clone( ) ,
132
139
false
133
140
) ;
134
141
} ) ,
135
142
crate :: burn:: TensorKind :: Float => Some ( quote ! {
136
143
let #name: burn:: module:: Param <#ty> = burn:: module:: Param :: uninitialized(
137
144
burn:: module:: ParamId :: new( ) ,
138
- move |device, _require_grad| Tensor :: <B , #dim >:: zeros( #shape, & device) ,
145
+ move |device, _require_grad| Tensor :: <B , #rank >:: zeros( #shape, & device) ,
139
146
device. clone( ) ,
140
147
false ,
141
148
) ;
142
149
} ) ,
143
150
crate :: burn:: TensorKind :: Bool => Some ( quote ! {
144
151
let #name: burn:: module:: Param <#ty> = burn:: module:: Param :: uninitialized(
145
152
burn:: module:: ParamId :: new( ) ,
146
- move |device, _require_grad| Tensor :: <B , #dim , Bool >:: empty( #shape, & device) ,
153
+ move |device, _require_grad| Tensor :: <B , #rank , Bool >:: empty( #shape, & device) ,
147
154
device. clone( ) ,
148
155
false ,
149
156
) ;
@@ -288,23 +295,14 @@ mod tests {
288
295
289
296
let const_tensor = Ident :: new ( "const_tensor" , Span :: call_site ( ) ) ;
290
297
let dimensions = 1 ;
291
- let shape = vec ! [ 4 ] ;
292
298
let data = TensorData :: from ( [ 2f32 , 2f32 , 2f32 , 2f32 ] ) ;
293
- let tensor_type = TensorType :: new_float_with_shape (
294
- const_tensor. to_string ( ) ,
295
- dimensions,
296
- Some ( shape. clone ( ) ) ,
297
- ) ;
299
+ let tensor_type = TensorType :: new_float ( const_tensor. to_string ( ) , dimensions) ;
298
300
let constant = ConstantValue :: Tensor ( tensor_type. clone ( ) , data) ;
299
301
300
302
graph. register ( ConstantNode :: new (
301
303
const_tensor. to_string ( ) ,
302
304
constant. clone ( ) ,
303
- Type :: Tensor ( TensorType :: new_float_with_shape (
304
- "output" ,
305
- dimensions,
306
- Some ( shape. clone ( ) ) ,
307
- ) ) ,
305
+ Type :: Tensor ( TensorType :: new_float ( "output" , dimensions) ) ,
308
306
) ) ;
309
307
310
308
graph. register_input_output ( vec ! [ ] , vec ! [ "output" . to_string( ) ] ) ;
@@ -356,23 +354,14 @@ mod tests {
356
354
357
355
let const_tensor = Ident :: new ( "const_tensor_int" , Span :: call_site ( ) ) ;
358
356
let dimensions = 1 ;
359
- let shape = vec ! [ 3 ] ;
360
357
let data = TensorData :: from ( [ 1i32 , 2i32 , 3i32 ] ) ;
361
- let tensor_type = TensorType :: new_int_with_shape (
362
- const_tensor. to_string ( ) ,
363
- dimensions,
364
- Some ( shape. clone ( ) ) ,
365
- ) ;
358
+ let tensor_type = TensorType :: new_int ( const_tensor. to_string ( ) , dimensions) ;
366
359
let constant = ConstantValue :: Tensor ( tensor_type. clone ( ) , data) ;
367
360
368
361
graph. register ( ConstantNode :: new (
369
362
const_tensor. to_string ( ) ,
370
363
constant. clone ( ) ,
371
- Type :: Tensor ( TensorType :: new_int_with_shape (
372
- "output" ,
373
- dimensions,
374
- Some ( shape. clone ( ) ) ,
375
- ) ) ,
364
+ Type :: Tensor ( TensorType :: new_int ( "output" , dimensions) ) ,
376
365
) ) ;
377
366
378
367
graph. register_input_output ( vec ! [ ] , vec ! [ "output" . to_string( ) ] ) ;
@@ -425,23 +414,14 @@ mod tests {
425
414
426
415
let const_tensor = Ident :: new ( "const_tensor_3d" , Span :: call_site ( ) ) ;
427
416
let dimensions = 3 ;
428
- let shape = vec ! [ 1 , 3 , 2 ] ;
429
417
let data = TensorData :: from ( [ [ [ true , false ] , [ true , false ] , [ true , false ] ] ] ) ;
430
- let tensor_type = TensorType :: new_bool_with_shape (
431
- const_tensor. to_string ( ) ,
432
- dimensions,
433
- Some ( shape. clone ( ) ) ,
434
- ) ;
418
+ let tensor_type = TensorType :: new_bool ( const_tensor. to_string ( ) , dimensions) ;
435
419
let constant = ConstantValue :: Tensor ( tensor_type. clone ( ) , data) ;
436
420
437
421
graph. register ( ConstantNode :: new (
438
422
const_tensor. to_string ( ) ,
439
423
constant. clone ( ) ,
440
- Type :: Tensor ( TensorType :: new_bool_with_shape (
441
- "output" ,
442
- dimensions,
443
- Some ( shape. clone ( ) ) ,
444
- ) ) ,
424
+ Type :: Tensor ( TensorType :: new_bool ( "output" , dimensions) ) ,
445
425
) ) ;
446
426
447
427
graph. register_input_output ( vec ! [ ] , vec ! [ "output" . to_string( ) ] ) ;
0 commit comments