Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions examples/gde.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using Flux: onehotbatch, onecold, logitcrossentropy, throttle
using Flux: @epochs
using Statistics: mean
using LightGraphs: adjacency_matrix
using CUDA

# Load the dataset
@load "data/cora_features.jld2" features
Expand All @@ -17,29 +18,30 @@ target_catg = 7
epochs = 40

# Preprocess the data and compute adjacency matrix
train_X = Matrix{Float32}(features) # dim: num_features * num_nodes
train_y = Float32.(labels) # dim: target_catg * num_nodes
adj_mat = Matrix{Float32}(adjacency_matrix(g))
train_X = Matrix{Float32}(features) |> gpu # dim: num_features * num_nodes
train_y = Float32.(labels) |> gpu # dim: target_catg * num_nodes
adj_mat = Matrix{Float32}(adjacency_matrix(g)) |> gpu

# Define the Neural GDE
diffeqarray_to_array(x) = reshape(cpu(x), size(x)[1:2])
# diffeqarray_to_array(x) = reshape(cpu(x), size(x)[1:2])

# NeuralODE just needs first component to be in gpu()
node = NeuralODE(
GCNConv(adj_mat, hidden=>hidden),
gpu(GCNConv(adj_mat, hidden=>hidden)),
(0.f0, 1.f0), Tsit5(), save_everystep = false,
reltol = 1e-3, abstol = 1e-3, save_start = false
)

model = Chain(GCNConv(adj_mat, num_features=>hidden, relu),
Dropout(0.5),
node,
diffeqarray_to_array,
arr -> arr[1],
GCNConv(adj_mat, hidden=>target_catg),
softmax)
softmax) |> gpu

# Loss
loss(x, y) = logitcrossentropy(model(x), y)
accuracy(x, y) = mean(onecold(model(x)) .== onecold(y))
accuracy(x, y) = mean(onecold(cpu(model(x))) .== onecold(cpu(y)))

# Training
## Model Parameters
Expand Down