diff --git a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc index fd045c3ee..fa875f824 100644 --- a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc +++ b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc @@ -554,7 +554,7 @@ void PreSpanProcedural::genUpdate(EnvironmentExternalBase &env, PresynapticUpdat // Write sum of presynaptic output to global memory if(isPresynapticOutputRequired(sg, trueSpike)) { - groupEnv.printLine(backend.getAtomic(sg.getScalarType()) + "(&$(_out_pre)[" + sg.getPreISynIndex(batchSize, "$(id_pre)") + "], lOutPre);"); + synEnv.printLine(backend.getAtomic(sg.getScalarType()) + "(&$(_out_pre)[" + sg.getPreISynIndex(batchSize, "$(id_pre)") + "], lOutPre);"); } } diff --git a/src/genn/genn/transpiler/prettyPrinter.cc b/src/genn/genn/transpiler/prettyPrinter.cc index c2590757c..ea90417c7 100644 --- a/src/genn/genn/transpiler/prettyPrinter.cc +++ b/src/genn/genn/transpiler/prettyPrinter.cc @@ -249,19 +249,19 @@ class Visitor : public Expression::Visitor, public Statement::Visitor // Check that there are call arguments on the stack assert(!m_CallArguments.empty()); - // Loop through call arguments on top of stack - size_t i = 0; - for (i = 0; i < m_CallArguments.top().second.size(); i++) { + // Loop through non-variadic function arguments + const size_t numArguments = type.getFunction().argTypes.size(); + for (size_t i = 0; i < numArguments; i++) { // If name contains a $(i) placeholder to replace with this argument, replace with pretty-printed argument const std::string placeholder = "$(" + std::to_string(i) + ")"; - // If placeholder isn't found at all, stop looking for arguments + // If placeholder isn't found at all, go onto next argument size_t found = name.find(placeholder); if(found == std::string::npos) { - break; + continue; } - // Keep replacing placeholders + // Replace all instances of placeholder do { name.replace(found, placeholder.length(), m_CallArguments.top().second.at(i)); found = name.find(placeholder, found); @@ -279,7 +279,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor // between required and variadic arguments e.g. "printf($(0)$(@))" // so, arguments simply require leading printing with leading comma std::ostringstream variadicArgumentsStream; - const auto varArgBegin = m_CallArguments.top().second.cbegin() + i; + const auto varArgBegin = m_CallArguments.top().second.cbegin() + numArguments; const auto varArgEnd = m_CallArguments.top().second.cend(); for(auto a = varArgBegin; a != varArgEnd; a++) { variadicArgumentsStream << ", " << *a; diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index 115b60299..7e8feec3f 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -558,7 +558,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor argumentConversionRank.reserve(m_CallArguments.top().size()); // Loop through arguments - // **NOTE** we loop through function arguments to deal with variadic + // **NOTE** we loop through function TYPE arguments to avoid variadic bool viable = true; auto c = m_CallArguments.top().cbegin(); auto a = argumentTypes.cbegin(); diff --git a/tests/features/test_event_propagation.py b/tests/features/test_event_propagation.py index d24597588..bdf48ccee 100644 --- a/tests/features/test_event_propagation.py +++ b/tests/features/test_event_propagation.py @@ -684,6 +684,167 @@ def test_reverse(make_model, backend, precision): assert np.sum(pre_bitmask_n_pop.vars["x"].view) == (model.timestep - 1) assert np.sum(pre_event_n_pop.vars["x"].view) == (model.timestep - 1) + +@pytest.mark.parametrize("precision", [types.Double, types.Float]) +def test_reverse_kernel(make_model, backend, precision): + pre_reverse_spike_source_model = create_neuron_model( + "pre_reverse_spike_source", + vars=[("startSpike", "unsigned int"), + ("endSpike", "unsigned int", VarAccess.READ_ONLY_DUPLICATE), + ("x", "scalar")], + extra_global_params=[("spikeTimes", "scalar*")], + sim_code= + """ + x = Isyn; + """, + threshold_condition_code= + """ + startSpike != endSpike && t >= spikeTimes[startSpike] + """, + reset_code= + """ + startSpike++; + """) + + static_pulse_reverse_model = create_weight_update_model( + "static_pulse_reverse", + sim_code= + """ + addToPre(g); + """, + vars=[("g", "scalar", VarAccess.READ_ONLY)]) + + model = make_model(precision, "test_reverse_kernel", backend=backend) + model.dt = 1.0 + + # Create spike source arrays with extra x variable + # to generate one-hot pattern to decode + pre_n_pop = model.add_neuron_population( + "Pre", 8 * 8, pre_reverse_spike_source_model, + {}, {"startSpike": np.arange(8 * 8), "endSpike": np.arange(1, 65), "x": 0.0}) + pre_n_pop.extra_global_params["spikeTimes"].set_init_values(np.arange(8 * 8)) + + # Add postsynptic population to connect to + post_n_pop = model.add_neuron_population( + "Post", 6 * 6, post_neuron_model, + {}, {"x": 0.0}) + + # Add convolutional toeplitz connectivity + conv_toeplitz_params = {"conv_kh": 3, "conv_kw": 3, + "conv_ih": 8, "conv_iw": 8, "conv_ic": 1, + "conv_oh": 6, "conv_ow": 6, "conv_oc": 1} + model.add_synapse_population( + "Synapse", "TOEPLITZ", + pre_n_pop, post_n_pop, + init_weight_update(static_pulse_reverse_model, {}, {"g": 1.0}), + init_postsynaptic("DeltaCurr"), + init_toeplitz_connectivity("Conv2D", conv_toeplitz_params)) + + # Build model and load + model.build() + model.load() + + counts = np.asarray( + [[1, 2, 3, 3, 3, 3, 2, 1], + [2, 4, 6, 6, 6, 6, 4, 2], + [3, 6, 9, 9, 9, 9, 6, 3], + [3, 6, 9, 9, 9, 9, 6, 3], + [3, 6, 9, 9, 9, 9, 6, 3], + [3, 6, 9, 9, 9, 9, 6, 3], + [2, 4, 6, 6, 6, 6, 4, 2], + [1, 2, 3, 3, 3, 3, 2, 1]]).flatten() + # Simulate 64 timesteps + while model.timestep < 64: + model.step_time() + + pre_n_pop.vars["x"].pull_from_device() + + if model.timestep > 1: + ind = model.timestep - 2 + assert pre_n_pop.vars["x"].values[ind] == counts[ind] + + +@pytest.mark.parametrize("precision", [types.Double, types.Float]) +def test_reverse_kernel_procedural(make_model, backend_simt, precision): + pre_reverse_spike_source_model = create_neuron_model( + "pre_reverse_spike_source", + vars=[("startSpike", "unsigned int"), + ("endSpike", "unsigned int", VarAccess.READ_ONLY_DUPLICATE), + ("x", "scalar")], + extra_global_params=[("spikeTimes", "scalar*")], + sim_code= + """ + x = Isyn; + """, + threshold_condition_code= + """ + startSpike != endSpike && t >= spikeTimes[startSpike] + """, + reset_code= + """ + startSpike++; + """) + + static_pulse_reverse_model = create_weight_update_model( + "static_pulse_reverse", + sim_code= + """ + addToPre(g); + """, + vars=[("g", "scalar", VarAccess.READ_ONLY)]) + + model = make_model(precision, "test_reverse_kernel_procedural", backend=backend_simt) + model.dt = 1.0 + + # Create spike source arrays with extra x variable + # to generate one-hot pattern to decode + pre_n_pop = model.add_neuron_population( + "Pre", 8 * 8, pre_reverse_spike_source_model, + {}, {"startSpike": np.arange(8 * 8), "endSpike": np.arange(1, 65), "x": 0.0}) + pre_n_pop.extra_global_params["spikeTimes"].set_init_values(np.arange(8 * 8)) + + # Add postsynptic population to connect to + post_n_pop = model.add_neuron_population( + "Post", 6 * 6, post_neuron_model, + {}, {"x": 0.0}) + + # Add convolutional toeplitz connectivity + conv_params = {"conv_kh": 3, "conv_kw": 3, + "conv_sh": 1, "conv_sw": 1, + "conv_padh": 0, "conv_padw": 0, + "conv_ih": 8, "conv_iw": 8, "conv_ic": 1, + "conv_oh": 6, "conv_ow": 6, "conv_oc": 1} + + model.add_synapse_population( + "Synapse", "PROCEDURAL_KERNELG", + pre_n_pop, post_n_pop, + init_weight_update(static_pulse_reverse_model, {}, {"g": 1.0}), + init_postsynaptic("DeltaCurr"), + init_sparse_connectivity("Conv2D", conv_params)) + + # Build model and load + model.build() + model.load() + + counts = np.asarray( + [[1, 2, 3, 3, 3, 3, 2, 1], + [2, 4, 6, 6, 6, 6, 4, 2], + [3, 6, 9, 9, 9, 9, 6, 3], + [3, 6, 9, 9, 9, 9, 6, 3], + [3, 6, 9, 9, 9, 9, 6, 3], + [3, 6, 9, 9, 9, 9, 6, 3], + [2, 4, 6, 6, 6, 6, 4, 2], + [1, 2, 3, 3, 3, 3, 2, 1]]).flatten() + # Simulate 64 timesteps + while model.timestep < 64: + model.step_time() + + pre_n_pop.vars["x"].pull_from_device() + + if model.timestep > 1: + ind = model.timestep - 2 + assert pre_n_pop.vars["x"].values[ind] == counts[ind] + @pytest.mark.parametrize("precision", [types.Double, types.Float]) def test_reverse_post(make_model, backend, precision): pre_reverse_model = create_neuron_model( @@ -792,3 +953,4 @@ def test_reverse_post(make_model, backend, precision): output_value = np.sum(output_place_values[output_binary]) if output_value != (model.timestep - 1): assert False, f"{pop.name} decoding incorrect ({output_value} rather than {model.timestep - 1})" +