Skip to content

Commit 3b1d07b

Browse files
authored
Merge pull request #67 from a-vartenkov/mpback_stdp
STDP implementation in the multhreaded backend.
2 parents a9480a5 + 436926f commit 3b1d07b

12 files changed

Lines changed: 283 additions & 36 deletions

File tree

examples/mnist-learn/main.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ int main(int argc, char **argv)
4444

4545
// Defines path to backend, on which to run a network.
4646
std::filesystem::path path_to_backend =
47-
std::filesystem::path(argv[0]).parent_path() / "knp-cpu-single-threaded-backend";
47+
std::filesystem::path(argv[0]).parent_path() / "knp-cpu-multi-threaded-backend";
4848

4949
// Read data from corresponding files.
5050
auto spike_frames = read_spike_frames(argv[1]);

knp/backends/cpu/cpu-library/include/knp/backends/cpu-library/impl/additive_stdp_impl.h

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
#include <spdlog/spdlog.h>
3030

31+
#include <algorithm>
3132
#include <unordered_map>
3233
#include <utility>
3334
#include <vector>
@@ -152,11 +153,12 @@ constexpr bool is_additive_stdp_synapse()
152153

153154

154155
template <class DeltaLikeSynapse>
155-
void register_additive_stdp_spikes(
156+
void register_additive_stdp_spikes_part(
156157
knp::core::Projection<knp::synapse_traits::STDP<knp::synapse_traits::STDPAdditiveRule, DeltaLikeSynapse>>
157158
&projection,
158-
std::vector<SpikeMessage> &all_messages)
159+
std::vector<SpikeMessage> &all_messages, uint64_t part_start, uint64_t part_end)
159160
{
161+
if (part_start != 0) return; // Not much sense to parallelize this by projection, so it's calculated just once.
160162
SPDLOG_DEBUG("Calculating additive STDP delta synapse projection...");
161163

162164
using ProjectionType = typename std::decay_t<decltype(projection)>;
@@ -209,13 +211,15 @@ void register_additive_stdp_spikes(
209211

210212

211213
template <class DeltaLikeSynapse>
212-
void update_projection_weights_additive_stdp(
214+
void update_projection_weights_additive_stdp_part(
213215
knp::core::Projection<knp::synapse_traits::STDP<knp::synapse_traits::STDPAdditiveRule, DeltaLikeSynapse>>
214-
&projection)
216+
&projection,
217+
uint64_t part_start, uint64_t part_end)
215218
{
216219
// Update projection parameters.
217-
for (auto &proj : projection)
220+
for (uint64_t i = part_start; i < std::min(projection.size(), part_end); ++i)
218221
{
222+
auto &proj = projection[i];
219223
SPDLOG_TRACE("Applying STDP rule...");
220224
auto &rule = std::get<knp::core::synapse_data>(proj).rule_;
221225
const auto period = rule.tau_plus_ + rule.tau_minus_;
@@ -238,19 +242,32 @@ template <class DeltaLikeSynapse>
238242
struct WeightUpdateSTDP<synapse_traits::STDP<synapse_traits::STDPAdditiveRule, DeltaLikeSynapse>>
239243
{
240244
using Synapse = synapse_traits::STDP<synapse_traits::STDPAdditiveRule, DeltaLikeSynapse>;
241-
static void init_projection(
242-
knp::core::Projection<Synapse> &projection, std::vector<SpikeMessage> &all_messages, knp::core::Step step)
245+
246+
static void init_projection_part(
247+
knp::core::Projection<Synapse> &projection, std::vector<SpikeMessage> &all_messages, knp::core::Step step,
248+
uint64_t part_start, uint64_t part_end)
243249
{
244-
register_additive_stdp_spikes(projection, all_messages);
250+
register_additive_stdp_spikes_part(projection, all_messages, part_start, part_end);
245251
}
246252

247253
static void init_synapse(const knp::synapse_traits::synapse_parameters<Synapse> &projection, knp::core::Step step)
248254
{
249255
}
250256

257+
static void modify_weights_part(knp::core::Projection<Synapse> &projection, uint64_t part_start, uint64_t part_end)
258+
{
259+
update_projection_weights_additive_stdp_part(projection, part_start, part_end);
260+
}
261+
262+
static void init_projection(
263+
knp::core::Projection<Synapse> &projection, std::vector<SpikeMessage> &all_messages, knp::core::Step step)
264+
{
265+
init_projection_part(projection, all_messages, step, 0, projection.size());
266+
}
267+
251268
static void modify_weights(knp::core::Projection<Synapse> &projection)
252269
{
253-
update_projection_weights_additive_stdp(projection);
270+
modify_weights_part(projection, 0, projection.size());
254271
}
255272
};
256273

knp/backends/cpu/cpu-library/include/knp/backends/cpu-library/impl/base_stdp_impl.h

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,45 +4,56 @@
44
* @kaspersky_support A. Vartenkov
55
* @date 04.11.2023
66
* @license Apache 2.0
7-
* @copyright © 2024 AO Kaspersky Lab
8-
*
9-
* Licensed under the Apache License, Version 2.0 (the "License");
10-
* you may not use this file except in compliance with the License.
11-
* You may obtain a copy of the License at
12-
*
13-
* http://www.apache.org/licenses/LICENSE-2.0
14-
*
15-
* Unless required by applicable law or agreed to in writing, software
16-
* distributed under the License is distributed on an "AS IS" BASIS,
17-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18-
* See the License for the specific language governing permissions and
7+
* @copyright © 2024 AO Kaspersky Lab
8+
*
9+
* Licensed under the Apache License, Version 2.0 (the "License");
10+
* you may not use this file except in compliance with the License.
11+
* You may obtain a copy of the License at
12+
*
13+
* http://www.apache.org/licenses/LICENSE-2.0
14+
*
15+
* Unless required by applicable law or agreed to in writing, software
16+
* distributed under the License is distributed on an "AS IS" BASIS,
17+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18+
* See the License for the specific language governing permissions and
1919
* limitations under the License.
2020
*/
2121
#pragma once
2222
#include <knp/core/messaging/messaging.h>
2323
#include <knp/core/projection.h>
2424

25+
#include <unordered_map>
2526
#include <vector>
2627

2728
/**
2829
* @brief Namespace for CPU backends.
2930
*/
3031
namespace knp::backends::cpu
3132
{
32-
3333
template <class DeltaLikeSynapse>
3434
struct WeightUpdateSTDP
3535
{
36-
static void init_projection(
36+
static void init_projection_part(
3737
const knp::core::Projection<DeltaLikeSynapse> &projection,
38-
const std::vector<core::messaging::SpikeMessage> &messages, uint64_t step)
38+
const std::vector<core::messaging::SpikeMessage> &messages, uint64_t part_begin, uint64_t part_end,
39+
uint64_t step)
3940
{
4041
}
4142

42-
static void init_synapse(const knp::synapse_traits::synapse_parameters<DeltaLikeSynapse> &projection, uint64_t step)
43+
static void init_synapse(knp::synapse_traits::synapse_parameters<DeltaLikeSynapse> &params, uint64_t step) {}
44+
45+
static void modify_weights_part(
46+
const knp::core::Projection<DeltaLikeSynapse> &projection, uint64_t part_begin, uint64_t part_end)
4347
{
4448
}
4549

50+
static void init_projection(
51+
const knp::core::Projection<DeltaLikeSynapse> &projection,
52+
const std::vector<core::messaging::SpikeMessage> &messages, uint64_t step)
53+
{
54+
}
55+
56+
4657
static void modify_weights(const knp::core::Projection<DeltaLikeSynapse> &projection) {}
4758
};
4859

knp/backends/cpu/cpu-library/include/knp/backends/cpu-library/impl/blifat_population_impl.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,17 @@ void calculate_neurons_post_input_state_part(
314314
}
315315

316316

317+
/**
318+
* Anything that works after the population has been calculated goes here. Mostly for STDP and other learning.
319+
*/
320+
template <class BlifatLikeNeuron, class ProjectionContainer>
321+
void finalize_population(
322+
knp::core::Population<BlifatLikeNeuron> &population, const knp::core::messaging::SpikeMessage &message,
323+
ProjectionContainer &projections, knp::core::Step step)
324+
{
325+
}
326+
327+
317328
/**
318329
* @brief Process BLIFAT neuron population and return spiked neuron indexes.
319330
* @tparam BlifatLikeNeuron type of neuron which inference can be calculated the same as BLIFAT.

knp/backends/cpu/cpu-library/include/knp/backends/cpu-library/impl/delta_synapse_projection_impl.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include <knp/core/message_bus.h>
2525
#include <knp/core/projection.h>
2626
#include <knp/synapse-traits/delta.h>
27+
#include <knp/synapse-traits/stdp_synaptic_resource_rule.h>
2728

2829
#include <spdlog/spdlog.h>
2930

@@ -132,15 +133,15 @@ template <class DeltaLikeSynapse>
132133
void calculate_projection_part_impl(
133134
knp::core::Projection<DeltaLikeSynapse> &projection,
134135
const std::unordered_map<knp::core::Step, size_t> &message_in_data, MessageQueue &future_messages, uint64_t step_n,
135-
size_t part_start, size_t part_size, std::mutex &mutex)
136+
uint64_t part_start, uint64_t part_size, std::mutex &mutex)
136137
{
137138
size_t part_end = std::min(part_start + part_size, projection.size());
138139
std::vector<std::pair<uint64_t, knp::core::messaging::SynapticImpact>> container;
140+
WeightUpdateStdpMp<DeltaLikeSynapse>::init_projection_part(projection, message_in_data, step_n);
139141
for (size_t synapse_index = part_start; synapse_index < part_end; ++synapse_index)
140142
{
141143
auto &synapse = projection[synapse_index];
142144
// update_step(synapse.params_, step_n);
143-
// TODO: Move update logic here too.
144145
auto iter = message_in_data.find(std::get<core::source_neuron_id>(synapse));
145146
if (iter == message_in_data.end())
146147
{
@@ -150,6 +151,7 @@ void calculate_projection_part_impl(
150151
// Add new impact.
151152
// The message is sent on step N - 1, received on step N.
152153
uint64_t key = std::get<core::synapse_data>(synapse).delay_ + step_n - 1;
154+
WeightUpdateStdpMp<DeltaLikeSynapse>::init_synapse(std::get<core::synapse_data>(synapse), step_n);
153155

154156
knp::core::messaging::SynapticImpact impact{
155157
synapse_index, std::get<core::synapse_data>(synapse).weight_ * iter->second,
@@ -183,6 +185,7 @@ void calculate_projection_part_impl(
183185
future_messages.insert(std::make_pair(value.first, message_out));
184186
}
185187
}
188+
WeightUpdateStdpMp<DeltaLikeSynapse>::modify_weights_part(projection);
186189
}
187190

188191

knp/backends/cpu/cpu-library/include/knp/backends/cpu-library/impl/synaptic_resource_stdp_impl.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@
3232
#include <algorithm>
3333
#include <limits>
3434
#include <numeric>
35+
#include <unordered_map>
3536
#include <utility>
37+
#include <variant>
3638
#include <vector>
3739

3840
#include <boost/mp11.hpp>
@@ -381,6 +383,40 @@ struct WeightUpdateSTDP<synapse_traits::STDP<synapse_traits::STDPSynapticResourc
381383
};
382384

383385

386+
template <class DeltaLikeSynapse>
387+
struct WeightUpdateStdpMp
388+
{
389+
using Synapse = DeltaLikeSynapse;
390+
static void init_projection_part(
391+
const knp::core::Projection<Synapse> &projection,
392+
const std::unordered_map<knp::core::Step, size_t> &message_data, uint64_t step)
393+
{
394+
}
395+
396+
static void init_synapse(knp::synapse_traits::synapse_parameters<Synapse> &params, uint64_t step) {}
397+
398+
static void modify_weights_part(const knp::core::Projection<Synapse> &projection) {}
399+
};
400+
401+
402+
template <class DeltaLikeSynapse>
403+
struct WeightUpdateStdpMp<synapse_traits::STDP<synapse_traits::STDPSynapticResourceRule, DeltaLikeSynapse>>
404+
{
405+
using Synapse = synapse_traits::STDP<synapse_traits::STDPSynapticResourceRule, DeltaLikeSynapse>;
406+
static void init_projection_part(
407+
const knp::core::Projection<Synapse> &projection,
408+
const std::unordered_map<knp::core::Step, size_t> &message_data, uint64_t step)
409+
{
410+
}
411+
412+
static void init_synapse(knp::synapse_traits::synapse_parameters<Synapse> &params, uint64_t step)
413+
{
414+
params.rule_.last_spike_step_ = step;
415+
}
416+
417+
static void modify_weights_part(const knp::core::Projection<Synapse> &projection) {}
418+
};
419+
384420
template <class NeuronType, class SynapseType>
385421
void do_STDP_resource_plasticity(
386422
knp::core::Population<knp::neuron_traits::SynapticResourceSTDPNeuron<NeuronType>> &population,
@@ -402,4 +438,6 @@ void do_STDP_resource_plasticity(
402438
// 3. Renormalize resources if needed.
403439
knp::backends::cpu::renormalize_resource(working_projections, population, step);
404440
}
441+
442+
405443
} // namespace knp::backends::cpu

knp/backends/cpu/cpu-multi-threaded-backend/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ knp_add_library("${PROJECT_NAME}"
4848
BOTH
4949
impl/backend.cpp
5050
impl/get_network.cpp
51+
impl/template_specs.cpp
5152
${${PROJECT_NAME}_headers}
5253
ALIAS KNP::Backends::CPUMultiThreaded
5354
LINK_PRIVATE

knp/backends/cpu/cpu-multi-threaded-backend/impl/backend.cpp

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121

2222
#include <knp/backends/cpu-library/blifat_population.h>
2323
#include <knp/backends/cpu-library/delta_synapse_projection.h>
24+
#include <knp/backends/cpu-library/impl/altai_lif_population_impl.h>
25+
#include <knp/backends/cpu-library/impl/blifat_population_impl.h>
26+
#include <knp/backends/cpu-library/impl/synaptic_resource_stdp_impl.h>
2427
#include <knp/backends/cpu-library/init.h>
2528
#include <knp/backends/cpu-multi-threaded/backend.h>
2629
#include <knp/backends/thread_pool/thread_pool.h>
@@ -33,6 +36,7 @@
3336

3437
#include <functional>
3538
#include <optional>
39+
#include <unordered_map>
3640
#include <vector>
3741

3842
#include <boost/mp11.hpp>
@@ -153,10 +157,28 @@ std::vector<knp::core::messaging::SpikeMessage> MultiThreadedCPUBackend::calcula
153157
std::ref(pop), std::ref(message), neuron_index, population_part_size_, std::ref(ep_mutex_));
154158
},
155159
population);
160+
}
161+
}
162+
calc_pool_->join();
163+
for (size_t pop_id = 0; pop_id < populations_.size(); ++pop_id)
164+
{
165+
auto &message = spike_container[pop_id];
166+
std::visit(
167+
[this, &message](auto &pop)
168+
{
169+
using T = std::decay_t<decltype(pop)>;
170+
auto call_finalize = [](T &pop_ref, knp::core::messaging::SpikeMessage &message_ref,
171+
ProjectionContainer &proj_ref, knp::core::Step step)
172+
{
173+
knp::backends::cpu::finalize_population<typename T::PopulationNeuronType, ProjectionContainer>(
174+
pop_ref, message_ref, proj_ref, step);
175+
};
176+
calc_pool_->post(call_finalize, std::ref(pop), std::ref(message), std::ref(projections_), get_step());
177+
},
178+
populations_[pop_id]);
156179
#if defined(_MSC_VER)
157180
# pragma warning(pop)
158181
#endif
159-
}
160182
}
161183
calc_pool_->join();
162184
return spike_container;
@@ -394,5 +416,4 @@ MultiThreadedCPUBackend::ProjectionConstIterator MultiThreadedCPUBackend::end_pr
394416

395417

396418
BOOST_DLL_ALIAS(knp::backends::multi_threaded_cpu::MultiThreadedCPUBackend::create, create_knp_backend)
397-
398419
} // namespace knp::backends::multi_threaded_cpu

0 commit comments

Comments
 (0)