Skip to content

Commit 9e59d13

Browse files
committed
Implemented formatting
1 parent 9af22e7 commit 9e59d13

File tree

16 files changed

+851
-671
lines changed

16 files changed

+851
-671
lines changed

.JuliaFormatter.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
style = "blue"
22

33
ignore = ["src/Wrapper.jl"]
4+
pipe_to_function_call = false
45
whitespace_in_kwargs = true
56
whitespace_typedefs = true

deps/julia_wrapper_generator/generator.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@ function rewrite!(e::Expr)
2222
end
2323

2424
function rewrite!(e::Expr, ::Val{:function})
25-
rewrite!(e.args[2], Val(e.args[2].head))
25+
return rewrite!(e.args[2], Val(e.args[2].head))
2626
end
2727

2828
function rewrite!(e::Expr, ::Val{:block})
29-
e.args[1] = Expr(:macrocall, Symbol("@runtime_error_check"), nothing, e.args[1])
29+
return e.args[1] = Expr(:macrocall, Symbol("@runtime_error_check"), nothing, e.args[1])
3030
end
3131

3232
function rewrite!(dag::ExprDAG)
Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
function get_error()
2-
err = cglobal((:myerr, libtorch_c_api), Cstring) |> unsafe_load
3-
unsafe_string(err)
2+
err = cglobal((:myerr, libtorch_c_api), Cstring) |> unsafe_load
3+
return unsafe_string(err)
44
end
55

66
macro runtime_error_check(ex)
7-
quote
8-
x = $ex
9-
if x == 1
10-
cs = get_error()
11-
flush_error()
12-
throw(cs)
13-
end
14-
end |> esc
7+
return quote
8+
x = $ex
9+
if x == 1
10+
cs = get_error()
11+
flush_error()
12+
throw(cs)
13+
end
14+
end |> esc
1515
end

src/Torch.jl

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,23 +32,25 @@ include("statistics.jl")
3232
include("grads.jl")
3333
include("utils.jl")
3434

35-
@init @require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" begin
36-
using .Flux
35+
@init @require Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" begin
36+
using .Flux
3737

38-
function (tbn::Flux.BatchNorm)(x::Tensor)
39-
tbn.λ.(Torch.batchnorm(x, tbn.γ, tbn.β, tbn.μ, tbn.σ², 0, tbn.momentum, tbn.ϵ, 1))
40-
end
38+
function (tbn::Flux.BatchNorm)(x::Tensor)
39+
return tbn.λ.(
40+
Torch.batchnorm(x, tbn.γ, tbn.β, tbn.μ, tbn.σ², 0, tbn.momentum, tbn.ϵ, 1)
41+
)
42+
end
4143

42-
function Flux.Zygote.accum(t1::Tensor, t2::Tensor{T,N}) where {T,N}
43-
ptr = Ref(Ptr{Cvoid}())
44+
function Flux.Zygote.accum(t1::Tensor, t2::Tensor{T, N}) where {T, N}
45+
ptr = Ref(Ptr{Cvoid}())
4446

45-
Torch.Wrapper.atg_add_(ptr, t1.ptr, t2.ptr)
46-
Tensor{T,N}(ptr[], Torch.on(t1))
47-
end
47+
Torch.Wrapper.atg_add_(ptr, t1.ptr, t2.ptr)
48+
return Tensor{T, N}(ptr[], Torch.on(t1))
49+
end
4850

49-
eval(:(Flux.Zygote.@nograd Torch.Wrapper.at_copy_data))
50-
eval(:(Flux.Zygote.@nograd Torch.Wrapper.at_dim))
51-
torch(x) = Flux.fmap(to_tensor, x)
51+
eval(:(Flux.Zygote.@nograd Torch.Wrapper.at_copy_data))
52+
eval(:(Flux.Zygote.@nograd Torch.Wrapper.at_dim))
53+
torch(x) = Flux.fmap(to_tensor, x)
5254
end
5355

5456
end # module

src/broadcast.jl

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,23 @@ using Base.Broadcast: broadcast_shape
99
# Base.BroadcastStyle(::Type{Tensor}) = TensorStyle()
1010

1111
for op in (:+, :-, :/)
12-
@eval function broadcasted(::typeof($op), t1::Tensor, t2::Tensor)
13-
$op(t1, t2)
14-
end
12+
@eval function broadcasted(::typeof($op), t1::Tensor, t2::Tensor)
13+
return $op(t1, t2)
14+
end
1515
end
1616

1717
for op in (:+, :-)
18-
@eval function broadcasted(::typeof($op), t1::Tensor, t2::TensorVector)
19-
t_ = reshape(t2, -1, 1)
20-
$op(t1, t_)
21-
end
18+
@eval function broadcasted(::typeof($op), t1::Tensor, t2::TensorVector)
19+
t_ = reshape(t2, -1, 1)
20+
return $op(t1, t_)
21+
end
2222
end
2323

24-
function broadcasted(::typeof(*), t1::Tensor{T,N}, t2::Tensor{T,M}) where {T,N,M}
25-
ptr = Ref(Ptr{Cvoid}())
24+
function broadcasted(::typeof(*), t1::Tensor{T, N}, t2::Tensor{T, M}) where {T, N, M}
25+
ptr = Ref(Ptr{Cvoid}())
2626

27-
atg_mul(ptr, t1.ptr, t2.ptr)
28-
Tensor{T,max(N,M)}(ptr[], on(t1))
27+
atg_mul(ptr, t1.ptr, t2.ptr)
28+
return Tensor{T, max(N, M)}(ptr[], on(t1))
2929
end
3030

3131
broadcasted(::typeof(NNlib.relu), t::Tensor) = NNlib.relu(t)
@@ -34,22 +34,21 @@ broadcasted(::typeof(identity), t::Tensor) = identity(t)
3434
broadcasted(::typeof(NNlib.sigmoid), t::Tensor) = NNlib.sigmoid(t)
3535

3636
for op in (:+, :-, :*, :/)
37-
@eval function broadcasted(::typeof($op), t::Tensor, args...)
38-
$op(t, args...)
39-
end
37+
@eval function broadcasted(::typeof($op), t::Tensor, args...)
38+
return $op(t, args...)
39+
end
4040
end
4141

4242
broadcasted(::typeof(sqrt), t::Tensor) = sqrt(t)
4343

44-
function broadcasted(::typeof(copy), t::Tensor{T,N}) where {T,N}
45-
t
44+
function broadcasted(::typeof(copy), t::Tensor{T, N}) where {T, N}
45+
return t
4646
end
4747

4848
@adjoint function broadcast(::typeof(NNlib.sigmoid), t::Tensor)
49-
50-
NNlib.sigmoid(t), Δ -> (∇sigmoid(Δ, t),)
49+
return NNlib.sigmoid(t), Δ -> (∇sigmoid(Δ, t),)
5150
end
5251

53-
@adjoint function broadcasted(::typeof(NNlib.relu), t::Tensor{T}) where T
54-
relu(t), Δ -> (nothing, ∇leaky_relu(Δ, t, zero(T)),)
52+
@adjoint function broadcasted(::typeof(NNlib.relu), t::Tensor{T}) where {T}
53+
return relu(t), Δ -> (nothing, ∇leaky_relu(Δ, t, zero(T)))
5554
end

0 commit comments

Comments
 (0)