Skip to content

Convert node_id to Node only when necessary #2606

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion libnestutil/dict_util.h
Original file line number Diff line number Diff line change
@@ -56,7 +56,7 @@ updateValueParam( DictionaryDatum const& d, Name const n, VT& value, nest::Node*
auto vp = kernel().vp_manager.node_id_to_vp( node->get_node_id() );
auto tid = kernel().vp_manager.vp_to_thread( vp );
auto rng = get_vp_specific_rng( tid );
value = pd->get()->value( rng, node );
value = pd->get()->value( rng, node->get_node_id() );
return true;
}
else
61 changes: 30 additions & 31 deletions nestkernel/conn_builder.cpp
Original file line number Diff line number Diff line change
@@ -292,7 +292,7 @@ nest::ConnBuilder::disconnect()

void
nest::ConnBuilder::update_param_dict_( index snode_id,
Node& target,
index target,
thread target_thread,
RngPtr rng,
index synapse_indx )
@@ -306,63 +306,63 @@ nest::ConnBuilder::update_param_dict_( index snode_id,
// change value of dictionary entry without allocating new datum
IntegerDatum* id = static_cast< IntegerDatum* >(
( ( *param_dicts_[ synapse_indx ][ target_thread ] )[ synapse_parameter.first ] ).datum() );
( *id ) = synapse_parameter.second->value_int( target_thread, rng, snode_id, &target );
( *id ) = synapse_parameter.second->value_int( target_thread, rng, snode_id, target );
}
else
{
// change value of dictionary entry without allocating new datum
DoubleDatum* dd = static_cast< DoubleDatum* >(
( ( *param_dicts_[ synapse_indx ][ target_thread ] )[ synapse_parameter.first ] ).datum() );
( *dd ) = synapse_parameter.second->value_double( target_thread, rng, snode_id, &target );
( *dd ) = synapse_parameter.second->value_double( target_thread, rng, snode_id, target );
}
}
}

void
nest::ConnBuilder::single_connect_( index snode_id, Node& target, thread target_thread, RngPtr rng )
nest::ConnBuilder::single_connect_( index snode_id, index target_id, thread target_thread, RngPtr rng )
{
if ( this->requires_proxies() and not target.has_proxies() )
if ( this->requires_proxies() and not kernel().node_manager.has_proxy( target_id ) )
{
throw IllegalConnection( "Cannot use this rule to connect to nodes without proxies (usually devices)." );
}

for ( size_t synapse_indx = 0; synapse_indx < synapse_params_.size(); ++synapse_indx )
{
update_param_dict_( snode_id, target, target_thread, rng, synapse_indx );
update_param_dict_( snode_id, target_id, target_thread, rng, synapse_indx );

if ( default_weight_and_delay_[ synapse_indx ] )
{
kernel().connection_manager.connect( snode_id,
&target,
target_id,
target_thread,
synapse_model_id_[ synapse_indx ],
param_dicts_[ synapse_indx ][ target_thread ] );
}
else if ( default_weight_[ synapse_indx ] )
{
kernel().connection_manager.connect( snode_id,
&target,
target_id,
target_thread,
synapse_model_id_[ synapse_indx ],
param_dicts_[ synapse_indx ][ target_thread ],
delays_[ synapse_indx ]->value_double( target_thread, rng, snode_id, &target ) );
delays_[ synapse_indx ]->value_double( target_thread, rng, snode_id, target_id ) );
}
else if ( default_delay_[ synapse_indx ] )
{
kernel().connection_manager.connect( snode_id,
&target,
target_id,
target_thread,
synapse_model_id_[ synapse_indx ],
param_dicts_[ synapse_indx ][ target_thread ],
numerics::nan,
weights_[ synapse_indx ]->value_double( target_thread, rng, snode_id, &target ) );
weights_[ synapse_indx ]->value_double( target_thread, rng, snode_id, target_id ) );
}
else
{
const double delay = delays_[ synapse_indx ]->value_double( target_thread, rng, snode_id, &target );
const double weight = weights_[ synapse_indx ]->value_double( target_thread, rng, snode_id, &target );
const double delay = delays_[ synapse_indx ]->value_double( target_thread, rng, snode_id, target_id );
const double weight = weights_[ synapse_indx ]->value_double( target_thread, rng, snode_id, target_id );
kernel().connection_manager.connect( snode_id,
&target,
target_id,
target_thread,
synapse_model_id_[ synapse_indx ],
param_dicts_[ synapse_indx ][ target_thread ],
@@ -644,7 +644,7 @@ nest::OneToOneBuilder::connect_()
continue;
}

single_connect_( snode_id, *target, tid, rng );
single_connect_( snode_id, tnode_id, tid, rng );
}
}
else
@@ -670,7 +670,7 @@ nest::OneToOneBuilder::connect_()
// as we iterate only over local nodes
continue;
}
single_connect_( snode_id, *target, tid, rng );
single_connect_( snode_id, tnode_id, tid, rng );
}
}
}
@@ -777,7 +777,7 @@ nest::OneToOneBuilder::sp_connect_()
Node* const target = kernel().node_manager.get_node_or_proxy( tnode_id, tid );
const thread target_thread = target->get_thread();

single_connect_( snode_id, *target, target_thread, rng );
single_connect_( snode_id, tnode_id, target_thread, rng );
}
}
catch ( std::exception& err )
@@ -920,7 +920,7 @@ nest::AllToAllBuilder::inner_connect_( const int tid, RngPtr rng, Node* target,
continue;
}

single_connect_( snode_id, *target, target_thread, rng );
single_connect_( snode_id, tnode_id, target_thread, rng );
}
}

@@ -963,7 +963,7 @@ nest::AllToAllBuilder::sp_connect_()
}
Node* const target = kernel().node_manager.get_node_or_proxy( tnode_id, tid );
const thread target_thread = target->get_thread();
single_connect_( snode_id, *target, target_thread, rng );
single_connect_( snode_id, tnode_id, target_thread, rng );
}
}
}
@@ -1153,7 +1153,7 @@ nest::FixedInDegreeBuilder::connect_()
const index tnode_id = ( *target_it ).node_id;
Node* const target = kernel().node_manager.get_node_or_proxy( tnode_id, tid );

const long indegree_value = std::round( indegree_->value( rng, target ) );
const long indegree_value = std::round( indegree_->value( rng, tnode_id ) );
if ( target->is_proxy() )
{
// skip array parameters handled in other virtual processes
@@ -1178,7 +1178,7 @@ nest::FixedInDegreeBuilder::connect_()
continue;
}
auto source = n->get_node();
const long indegree_value = std::round( indegree_->value( rng, source ) );
const long indegree_value = std::round( indegree_->value( rng, n->get_node_id() ) );

inner_connect_( tid, rng, source, tnode_id, false, indegree_value );
}
@@ -1237,7 +1237,7 @@ nest::FixedInDegreeBuilder::inner_connect_( const int tid,
ch_ids.insert( s_id );
}

single_connect_( snode_id, *target, target_thread, rng );
single_connect_( snode_id, tnode_id, target_thread, rng );
}
}

@@ -1313,8 +1313,7 @@ nest::FixedOutDegreeBuilder::connect_()
std::vector< index > tgt_ids_;
const long n_rnd = targets_->size();

Node* source_node = kernel().node_manager.get_node_or_proxy( snode_id );
const long outdegree_value = std::round( outdegree_->value( grng, source_node ) );
const long outdegree_value = std::round( outdegree_->value( grng, snode_id ) );
for ( long j = 0; j < outdegree_value; ++j )
{
unsigned long t_id;
@@ -1358,7 +1357,7 @@ nest::FixedOutDegreeBuilder::connect_()
continue;
}

single_connect_( snode_id, *target, tid, rng );
single_connect_( snode_id, *tnode_id_it, tid, rng );
}
}
catch ( std::exception& err )
@@ -1525,7 +1524,7 @@ nest::FixedTotalNumberBuilder::connect_()

if ( allow_autapses_ or snode_id != tnode_id )
{
single_connect_( snode_id, *target, target_thread, rng );
single_connect_( snode_id, tnode_id, target_thread, rng );
num_conns_on_vp[ vp_id ]--;
}
}
@@ -1646,12 +1645,12 @@ nest::BernoulliBuilder::inner_connect_( const int tid, RngPtr rng, Node* target,
{
continue;
}
if ( rng->drand() >= p_->value( rng, target ) )
if ( rng->drand() >= p_->value( rng, tnode_id ) )
{
continue;
}

single_connect_( snode_id, *target, target_thread, rng );
single_connect_( snode_id, tnode_id, target_thread, rng );
}
}

@@ -1759,14 +1758,14 @@ nest::SymmetricBernoulliBuilder::connect_()
if ( target_thread == tid )
{
assert( target );
single_connect_( snode_id, *target, target_thread, synced_rng );
single_connect_( snode_id, ( *tnode_id ).node_id, target_thread, synced_rng );
}

// if source is local: connect
if ( source_thread == tid )
{
assert( source );
single_connect_( ( *tnode_id ).node_id, *source, source_thread, synced_rng );
single_connect_( ( *tnode_id ).node_id, snode_id, source_thread, synced_rng );
}

++i;
@@ -1883,7 +1882,7 @@ nest::SPBuilder::connect_( const std::vector< index >& sources, const std::vecto
}
Node* const target = kernel().node_manager.get_node_or_proxy( *tnode_id_it, tid );

single_connect_( *snode_id_it, *target, tid, rng );
single_connect_( *snode_id_it, *tnode_id_it, tid, rng );
}
}
catch ( std::exception& err )
4 changes: 2 additions & 2 deletions nestkernel/conn_builder.h
Original file line number Diff line number Diff line change
@@ -145,10 +145,10 @@ class ConnBuilder
throw NotImplemented( "This connection rule is not implemented for structural plasticity." );
}

void update_param_dict_( index snode_id, Node& target, thread target_thread, RngPtr rng, index indx );
void update_param_dict_( index snode_id, index target, thread target_thread, RngPtr rng, index indx );

//! Create connection between given nodes, fill parameter values
void single_connect_( index, Node&, thread, RngPtr );
void single_connect_( index, index, thread, RngPtr );
void single_disconnect_( index, Node&, thread );

/**
24 changes: 16 additions & 8 deletions nestkernel/conn_builder_conngen.cpp
Original file line number Diff line number Diff line change
@@ -86,9 +86,14 @@ ConnectionGeneratorBuilder::connect_()
{
// No need to check for locality of the target, as the mask
// created by cg_set_masks() only contains local nodes.
Node* const target_node = kernel().node_manager.get_node_or_proxy( ( *targets_ )[ target ] );
const thread target_thread = target_node->get_thread();
single_connect_( ( *sources_ )[ source ], *target_node, target_thread, rng );
// Ayssar! TODO: then why calling `get_node_or_proxy`?
const vp = kernel().vp_manager.node_id_to_vp( target );
const thread target_thread = 0;
if ( kernel().vp_manager.is_local_vp( vp ) )
{
target_thread = kernel().vp_manager.vp_to_thread( vp );
}
single_connect_( ( *sources_ )[ source ], target, target_thread, rng );
}
}
else if ( num_parameters == 2 )
@@ -115,14 +120,17 @@ ConnectionGeneratorBuilder::connect_()
{
// No need to check for locality of the target node, as the mask
// created by cg_set_masks() only contains local nodes.
Node* target_node = kernel().node_manager.get_node_or_proxy( ( *targets_ )[ target ] );
const thread target_thread = target_node->get_thread();

update_param_dict_( ( *sources_ )[ source ], *target_node, target_thread, rng, 0 );
const vp = kernel().vp_manager.node_id_to_vp( target );
const thread target_thread = 0;
if ( kernel().vp_manager.is_local_vp( vp ) )
{
target_thread = kernel().vp_manager.vp_to_thread( vp );
}
update_param_dict_( ( *sources_ )[ source ], target, target_thread, rng, 0 );

// Use the low-level connect() here, as we need to pass a custom weight and delay
kernel().connection_manager.connect( ( *sources_ )[ source ],
target_node,
target,
target_thread,
synapse_model_id_[ 0 ],
param_dicts_[ 0 ][ target_thread ],
2 changes: 1 addition & 1 deletion nestkernel/conn_parameter.cpp
Original file line number Diff line number Diff line change
@@ -80,7 +80,7 @@ nest::ParameterConnParameterWrapper::ParameterConnParameterWrapper( const Parame
}

double
nest::ParameterConnParameterWrapper::value_double( thread, RngPtr rng, index, Node* target ) const
nest::ParameterConnParameterWrapper::value_double( thread, RngPtr rng, index, index target ) const
{
return parameter_->value( rng, target );
}
24 changes: 12 additions & 12 deletions nestkernel/conn_parameter.h
Original file line number Diff line number Diff line change
@@ -73,8 +73,8 @@ class ConnParameter
* @param rng random number generator pointer
* will be ignored except for random parameters.
*/
virtual double value_double( thread, RngPtr, index, Node* ) const = 0;
virtual long value_int( thread, RngPtr, index, Node* ) const = 0;
virtual double value_double( thread, RngPtr, index, index ) const = 0;
virtual long value_int( thread, RngPtr, index, index ) const = 0;
virtual void
skip( thread, size_t ) const
{
@@ -134,13 +134,13 @@ class ScalarDoubleParameter : public ConnParameter
}

double
value_double( thread, RngPtr, index, Node* ) const override
value_double( thread, RngPtr, index, index ) const override
{
return value_;
}

long
value_int( thread, RngPtr, index, Node* ) const override
value_int( thread, RngPtr, index, index ) const override
{
throw KernelException( "ConnParameter calls value function with false return type." );
}
@@ -180,13 +180,13 @@ class ScalarIntegerParameter : public ConnParameter
}

double
value_double( thread, RngPtr, index, Node* ) const override
value_double( thread, RngPtr, index, index ) const override
{
return static_cast< double >( value_ );
}

long
value_int( thread, RngPtr, index, Node* ) const override
value_int( thread, RngPtr, index, index ) const override
{
return value_;
}
@@ -263,7 +263,7 @@ class ArrayDoubleParameter : public ConnParameter
}

double
value_double( thread tid, RngPtr, index, Node* ) const override
value_double( thread tid, RngPtr, index, index ) const override
{
if ( next_[ tid ] != values_->end() )
{
@@ -276,7 +276,7 @@ class ArrayDoubleParameter : public ConnParameter
}

long
value_int( thread, RngPtr, index, Node* ) const override
value_int( thread, RngPtr, index, index ) const override
{
throw KernelException( "ConnParameter calls value function with false return type." );
}
@@ -345,7 +345,7 @@ class ArrayIntegerParameter : public ConnParameter
}

long
value_int( thread tid, RngPtr, index, Node* ) const override
value_int( thread tid, RngPtr, index, index ) const override
{
if ( next_[ tid ] != values_->end() )
{
@@ -358,7 +358,7 @@ class ArrayIntegerParameter : public ConnParameter
}

double
value_double( thread tid, RngPtr, index, Node* ) const override
value_double( thread tid, RngPtr, index, index ) const override
{
if ( next_[ tid ] != values_->end() )
{
@@ -401,10 +401,10 @@ class ParameterConnParameterWrapper : public ConnParameter
public:
ParameterConnParameterWrapper( const ParameterDatum&, const size_t );

double value_double( thread target_thread, RngPtr rng, index snode_id, Node* target ) const override;
double value_double( thread target_thread, RngPtr rng, index snode_id, index target ) const override;

long
value_int( thread target_thread, RngPtr rng, index snode_id, Node* target ) const override
value_int( thread target_thread, RngPtr rng, index snode_id, index target ) const override
{
return value_double( target_thread, rng, snode_id, target );
}
Loading