Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions src/inits/inits_reservoir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2533,3 +2533,43 @@ for initializer in (
end
end
end



function wigner_init(
rng::AbstractRNG, ::Type{T}, dims::Integer...;
radius::Number = T(1.0), std::Number = T(1.0)
) where {T <: Number}
# 1. Dimension check : Reservoir has to be a square matrix
check_res_size(dims...)
res_size = dims[1]

# 2. Initialise the empty matrix using the requested type T
W = zeros(T, res_size, res_size)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

try to avoid single letter naming if possible W -> reservoir_matrix. follow the conventions of the other initializers so that it makes it easier to search for stuff in the file

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


# 3. Populating the matrix
# Diagonal elements sampled from N(0, 2*std^2)
# Off-diagonal elements sampled from N(0, std^2)
for i in 1 : res_size
for j in 1 : res_size
if i==j
# Diagonal element
W[i, j] = randn(rng, T) * T(sqrt(2)) * T(std)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

randn(rng, T) sohuld be taken from DeviceAgnostic

also I don't understand why T(sqrt(2))

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the standard Wigner matrix, T(sqrt(2)) is for the diagonal elements as they are sampled from N(0, 2*variance), which implies sqrt(2) * std. I will modify the arguments to accept two arbitrary std values for the diagonal and off-diagonal distributions, as described in the issue. Modified rand -> DeviceAgnostic.rand

else
# Off-diagonal elements (upper triangular part)
W[i, j] = randn(rng, T) * T(std)
end
end
end

# Make the matrix symmetric
W = Symmetric(W)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add this as a kwarg option that the user can choose return_symmetric::Bool = false

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also need to import this from LinearAlgebra here

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added the return_symmetric argument, but I've kept the default case as true, which the user can set to false.
using LinearAlgebra: eigvals, I, qr, Diagonal, diag, mul!
Should I paste the above line into the inits_reservoir.jl code?


# 4. Scaling the spectral radius to the user-specified value
W = scale_radius!(W, T(radius))

# 5. Check for NaN or Inf values
check_inf_nan(W)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is necessary, this issue usually arises from the sparsity

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should I remove that line ? Or add some other check?


return W
end
14 changes: 14 additions & 0 deletions test/test_inits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,17 @@ end
@test sort(unique(dl)) == Float32.([-0.1, 0.1])
end
end

@testset "Wigner_matrix_init.jl" begin
#rng = Random.default_rng()
#res_size = 10


W = wigner_init(rng, Float32, res_size; radius = 0.9, std = 0.5)

@test size(W) == (res_size, res_size)

@test eltype(W) == Float32

@test issymmetric(W)
end