@@ -9,23 +9,23 @@ using Base.Broadcast: broadcast_shape
99# Base.BroadcastStyle(::Type{Tensor}) = TensorStyle()
1010
1111for op in (:+ , :- , :/ )
12- @eval function broadcasted(:: typeof ($ op), t1:: Tensor , t2:: Tensor )
13- $ op(t1, t2)
14- end
12+ @eval function broadcasted(:: typeof ($ op), t1:: Tensor , t2:: Tensor )
13+ return $ op(t1, t2)
14+ end
1515end
1616
1717for op in (:+ , :- )
18- @eval function broadcasted(:: typeof ($ op), t1:: Tensor , t2:: TensorVector )
19- t_ = reshape(t2, - 1 , 1 )
20- $ op(t1, t_)
21- end
18+ @eval function broadcasted(:: typeof ($ op), t1:: Tensor , t2:: TensorVector )
19+ t_ = reshape(t2, - 1 , 1 )
20+ return $ op(t1, t_)
21+ end
2222end
2323
24- function broadcasted(:: typeof (* ), t1:: Tensor{T,N} , t2:: Tensor{T,M} ) where {T,N, M}
25- ptr = Ref(Ptr{Cvoid}())
24+ function broadcasted(:: typeof (* ), t1:: Tensor{T, N} , t2:: Tensor{T, M} ) where {T, N, M}
25+ ptr = Ref(Ptr{Cvoid}())
2626
27- atg_mul(ptr, t1. ptr, t2. ptr)
28- Tensor{T,max(N,M)}(ptr[], on(t1))
27+ atg_mul(ptr, t1. ptr, t2. ptr)
28+ return Tensor{T, max(N, M)}(ptr[], on(t1))
2929end
3030
3131broadcasted(:: typeof (NNlib. relu), t:: Tensor ) = NNlib. relu(t)
@@ -34,22 +34,21 @@ broadcasted(::typeof(identity), t::Tensor) = identity(t)
3434broadcasted(:: typeof (NNlib. sigmoid), t:: Tensor ) = NNlib. sigmoid(t)
3535
3636for op in (:+ , :- , :* , :/ )
37- @eval function broadcasted(:: typeof ($ op), t:: Tensor , args... )
38- $ op(t, args... )
39- end
37+ @eval function broadcasted(:: typeof ($ op), t:: Tensor , args... )
38+ return $ op(t, args... )
39+ end
4040end
4141
4242broadcasted(:: typeof (sqrt), t:: Tensor ) = sqrt(t)
4343
44- function broadcasted(:: typeof (copy), t:: Tensor{T,N} ) where {T,N}
45- t
44+ function broadcasted(:: typeof (copy), t:: Tensor{T, N} ) where {T, N}
45+ return t
4646end
4747
4848@adjoint function broadcast(:: typeof (NNlib. sigmoid), t:: Tensor )
49-
50- NNlib. sigmoid(t), Δ -> (∇sigmoid(Δ, t),)
49+ return NNlib. sigmoid(t), Δ -> (∇sigmoid(Δ, t),)
5150end
5251
53- @adjoint function broadcasted(:: typeof (NNlib. relu), t:: Tensor{T} ) where T
54- relu(t), Δ -> (nothing , ∇leaky_relu(Δ, t, zero(T)), )
52+ @adjoint function broadcasted(:: typeof (NNlib. relu), t:: Tensor{T} ) where {T}
53+ return relu(t), Δ -> (nothing , ∇leaky_relu(Δ, t, zero(T)))
5554end
0 commit comments