@@ -1575,107 +1575,6 @@ void pad_grad(const Tensor& input,
   }
 }

-template <typename T>
-void max_grad(const Tensor& x,
-              const Tensor& out,
-              const Tensor& out_grad,
-              const IntArray& axis,
-              bool keepdim,
-              bool reduce_all,
-              Tensor* x_grad) {
-  if (!x_grad) {
-    return;
-  }
-
-  Tensor x_grad_tmp;
-  if (has_dynamic_shape(x.shape())) {
-    const Tensor x_shape = shape64<T>(x);
-    const Tensor zero_tensor =
-        backend::full_with_tensor<T>(x_shape, 0.0, x.dtype(), x.place());
-    const int64_t axis_size = axis.size();
-    const int64_t x_dim_size = x.dims().size();
-
-    reduce_all = false;
-    if (reduce_all || axis_size == 0 || axis_size == x_dim_size) {
-      reduce_all = true;
-    } else {
-      reduce_all = false;
-    }
-
-    if (x_dim_size == 0 || x_dim_size == 1 || keepdim) {
-      auto out_grad_tmp = backend::expand<T>(out_grad, x_shape);
-      auto out_tmp = backend::expand<T>(out, x_shape);
-      auto mask = equal<T>(x, out_tmp);
-      x_grad_tmp = where<T>(mask, out_grad_tmp, zero_tensor);
-    } else {
-      const Tensor out_grad_shape = shape64<T>(out_grad);
-      auto axis_ = std::vector<int64_t>();
-
-      if (reduce_all) {
-        for (int64_t i = 0; i < x_dim_size; i++) {
-          axis_.push_back(i);
-        }
-      } else {
-        axis_ = axis.GetData();
-        for (int64_t i = 0; i < axis_size; i++) {
-          if (axis[i] < 0) {
-            axis_[i] = axis[i] + x_dim_size;
-          }
-        }
-      }
-      const Tensor out_grad_shape_extend =
-          get_unsqueeze_dims<T>(out_grad_shape, axis_);
-      auto out_grad_ = backend::reshape<T>(out_grad, out_grad_shape_extend);
-      auto out_ = backend::reshape<T>(out, out_grad_shape_extend);
-      auto out_grad_tmp = backend::expand<T>(out_grad_, x_shape);
-      auto out_tmp = backend::expand<T>(out_, x_shape);
-      auto mask = equal<T>(x, out_tmp);
-      x_grad_tmp = where<T>(mask, out_grad_tmp, zero_tensor);
-    }
-  } else {
-    auto zero_tensor =
-        full<T>(common::vectorize(x.dims()), 0.0, x.dtype(), x.place());
-    std::vector<int64_t> x_dim = common::vectorize<int64_t>(x.dims());
-    int64_t axis_size = axis.size();
-    int64_t x_dim_size = x_dim.size();
-    reduce_all = false;
-    if (reduce_all || axis_size == 0 || axis_size == x_dim_size) {
-      reduce_all = true;
-    } else {
-      reduce_all = false;
-    }
-
-    if (x_dim_size == 0 || x_dim_size == 1 || keepdim) {
-      auto out_grad_tmp = out_grad.expand(IntArray(x_dim));
-      auto out_tmp = out.expand(IntArray(x_dim));
-      auto mask = equal<T>(x, out_tmp);
-      x_grad_tmp = where<T>(mask, out_grad_tmp, zero_tensor);
-    } else {
-      auto axis_ = std::vector<int64_t>();
-      if (reduce_all) {
-        for (int64_t i = 0; i < x_dim_size; i++) {
-          axis_.push_back(i);
-        }
-      } else {
-        axis_ = axis.GetData();
-        for (int64_t i = 0; i < axis_size; i++) {
-          if (axis[i] < 0) {
-            axis_[i] = axis[i] + x_dim_size;
-          }
-        }
-      }
-      auto out_grad_shape = get_unsqueeze_dims(out_grad, axis_);
-      auto out_grad_ = reshape<T>(out_grad, out_grad_shape);
-      auto out_ = reshape<T>(out, out_grad_shape);
-      auto out_grad_tmp = out_grad_.expand(IntArray(x_dim));
-      auto out_tmp = out_.expand(IntArray(x_dim));
-      auto mask = equal<T>(x, out_tmp);
-      x_grad_tmp = where<T>(mask, out_grad_tmp, zero_tensor);
-    }
-  }
-  set_output<T>(x_grad_tmp, x_grad);
-}
-
 template <typename T>
 void slice_grad(const Tensor& input,
                 const Tensor& out_grad,
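Note: this hunk only deletes `max_grad` from its old position; the hunk below re-adds it later in the file with one behavioral change. The core gradient rule is untouched: broadcast the forward maximum back to `x`'s shape, compare elementwise with `equal`, and route `out_grad` through the mask with `where`, so every element tied for the maximum receives the full incoming gradient. A minimal standalone sketch of that rule for the 1-D reduce-all case, in plain C++ rather than the Paddle API:

```cpp
// Standalone illustration (not Paddle code) of the equal/where masking rule
// implemented by max_grad: out_grad flows to every position that attained
// the maximum; every other position gets the zero_tensor value.
#include <algorithm>
#include <iostream>
#include <vector>

std::vector<double> max_grad_1d(const std::vector<double>& x,
                                double out,         // forward result: max(x)
                                double out_grad) {  // incoming gradient
  std::vector<double> x_grad(x.size(), 0.0);  // plays the role of zero_tensor
  for (std::size_t i = 0; i < x.size(); ++i) {
    // mask = equal(x, out); x_grad = where(mask, out_grad, 0)
    if (x[i] == out) x_grad[i] = out_grad;
  }
  return x_grad;
}

int main() {
  const std::vector<double> x = {1.0, 3.0, 2.0, 3.0};
  const double out = *std::max_element(x.begin(), x.end());
  for (double g : max_grad_1d(x, out, 1.0)) std::cout << g << ' ';
  std::cout << '\n';  // prints: 0 1 0 1  (both ties receive the gradient)
}
```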
@@ -3498,6 +3397,114 @@ void amin_grad(const Tensor& x,
   }
 }

+template <typename T>
+void max_grad(const Tensor& x,
+              const Tensor& out,
+              const Tensor& out_grad,
+              const IntArray& axis,
+              bool keepdim,
+              bool reduce_all,
+              Tensor* x_grad) {
+  if (!x_grad) {
+    return;
+  }
+
+  if (axis.size() == 0) {
+    Tensor x_grad_tmp;
+    amax_grad<T>(x, out, out_grad, axis, keepdim, reduce_all, &x_grad_tmp);
+    set_output<T>(x_grad_tmp, x_grad);
+    return;
+  }
+
+  Tensor x_grad_tmp;
+  if (has_dynamic_shape(x.shape())) {
+    const Tensor x_shape = shape64<T>(x);
+    const Tensor zero_tensor =
+        backend::full_with_tensor<T>(x_shape, 0.0, x.dtype(), x.place());
+    const int64_t axis_size = axis.size();
+    const int64_t x_dim_size = x.dims().size();
+
+    reduce_all = false;
+    if (reduce_all || axis_size == 0 || axis_size == x_dim_size) {
+      reduce_all = true;
+    } else {
+      reduce_all = false;
+    }
+
+    if (x_dim_size == 0 || x_dim_size == 1 || keepdim) {
+      auto out_grad_tmp = backend::expand<T>(out_grad, x_shape);
+      auto out_tmp = backend::expand<T>(out, x_shape);
+      auto mask = equal<T>(x, out_tmp);
+      x_grad_tmp = where<T>(mask, out_grad_tmp, zero_tensor);
+    } else {
+      const Tensor out_grad_shape = shape64<T>(out_grad);
+      auto axis_ = std::vector<int64_t>();
+
+      if (reduce_all) {
+        for (int64_t i = 0; i < x_dim_size; i++) {
+          axis_.push_back(i);
+        }
+      } else {
+        axis_ = axis.GetData();
+        for (int64_t i = 0; i < axis_size; i++) {
+          if (axis[i] < 0) {
+            axis_[i] = axis[i] + x_dim_size;
+          }
+        }
+      }
+      const Tensor out_grad_shape_extend =
+          get_unsqueeze_dims<T>(out_grad_shape, axis_);
+      auto out_grad_ = backend::reshape<T>(out_grad, out_grad_shape_extend);
+      auto out_ = backend::reshape<T>(out, out_grad_shape_extend);
+      auto out_grad_tmp = backend::expand<T>(out_grad_, x_shape);
+      auto out_tmp = backend::expand<T>(out_, x_shape);
+      auto mask = equal<T>(x, out_tmp);
+      x_grad_tmp = where<T>(mask, out_grad_tmp, zero_tensor);
+    }
+  } else {
+    auto zero_tensor =
+        full<T>(common::vectorize(x.dims()), 0.0, x.dtype(), x.place());
+    std::vector<int64_t> x_dim = common::vectorize<int64_t>(x.dims());
+    int64_t axis_size = axis.size();
+    int64_t x_dim_size = x_dim.size();
+    reduce_all = false;
+    if (reduce_all || axis_size == 0 || axis_size == x_dim_size) {
+      reduce_all = true;
+    } else {
+      reduce_all = false;
+    }
+
+    if (x_dim_size == 0 || x_dim_size == 1 || keepdim) {
+      auto out_grad_tmp = out_grad.expand(IntArray(x_dim));
+      auto out_tmp = out.expand(IntArray(x_dim));
+      auto mask = equal<T>(x, out_tmp);
+      x_grad_tmp = where<T>(mask, out_grad_tmp, zero_tensor);
+    } else {
+      auto axis_ = std::vector<int64_t>();
+      if (reduce_all) {
+        for (int64_t i = 0; i < x_dim_size; i++) {
+          axis_.push_back(i);
+        }
+      } else {
+        axis_ = axis.GetData();
+        for (int64_t i = 0; i < axis_size; i++) {
+          if (axis[i] < 0) {
+            axis_[i] = axis[i] + x_dim_size;
+          }
+        }
+      }
+      auto out_grad_shape = get_unsqueeze_dims(out_grad, axis_);
+      auto out_grad_ = reshape<T>(out_grad, out_grad_shape);
+      auto out_ = reshape<T>(out, out_grad_shape);
+      auto out_grad_tmp = out_grad_.expand(IntArray(x_dim));
+      auto out_tmp = out_.expand(IntArray(x_dim));
+      auto mask = equal<T>(x, out_tmp);
+      x_grad_tmp = where<T>(mask, out_grad_tmp, zero_tensor);
+    }
+  }
+  set_output<T>(x_grad_tmp, x_grad);
+}
+
 template <typename T>
 void p_norm_grad(const Tensor& x,
                  /* output of forward was reserved for efficient backward*/
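The re-added function matches the removed one except for the new early branch: when `axis.size() == 0` (a full reduction), `max_grad` now delegates to `amax_grad<T>` and returns. In the remaining non-keepdim branches, the key step is rebuilding a broadcast-compatible shape by reinserting each reduced axis as a size-1 dimension before `expand`. Below is a standalone sketch of that shape bookkeeping; the behavior attributed to `get_unsqueeze_dims` is inferred from how the diff uses it, not from its Paddle definition:

```cpp
// Sketch of the non-keepdim shape bookkeeping: normalize negative axes
// (axis[i] + x_dim_size), then rebuild out_grad's shape with a 1 at every
// reduced axis, which is the role get_unsqueeze_dims appears to play.
#include <cstdint>
#include <iostream>
#include <set>
#include <vector>

std::vector<int64_t> unsqueeze_dims(const std::vector<int64_t>& out_shape,
                                    std::vector<int64_t> axes,
                                    int64_t x_rank) {
  for (auto& a : axes) {
    if (a < 0) a += x_rank;  // mirrors axis_[i] = axis[i] + x_dim_size
  }
  const std::set<int64_t> reduced(axes.begin(), axes.end());
  std::vector<int64_t> result;
  std::size_t j = 0;  // cursor into the keepdim=false output shape
  for (int64_t i = 0; i < x_rank; ++i) {
    if (reduced.count(i)) {
      result.push_back(1);               // reinserted reduced axis
    } else {
      result.push_back(out_shape[j++]);  // surviving axis
    }
  }
  return result;
}

int main() {
  // x: [2, 3, 4] reduced over axis -2 with keepdim=false gives out: [2, 4];
  // reshaping out_grad to [2, 1, 4] lets expand broadcast it back to x.
  for (int64_t d : unsqueeze_dims({2, 4}, {-2}, 3)) std::cout << d << ' ';
  std::cout << '\n';  // prints: 2 1 4
}
```

After this reshape, `backend::expand` (dynamic-shape branch) or `Tensor::expand` (static branch) broadcasts `out` and `out_grad` back to `x`'s full shape, and the same `equal`/`where` masking applies.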