Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ void PreSpanProcedural::genUpdate(EnvironmentExternalBase &env, PresynapticUpdat

// Write sum of presynaptic output to global memory
if(isPresynapticOutputRequired(sg, trueSpike)) {
groupEnv.printLine(backend.getAtomic(sg.getScalarType()) + "(&$(_out_pre)[" + sg.getPreISynIndex(batchSize, "$(id_pre)") + "], lOutPre);");
synEnv.printLine(backend.getAtomic(sg.getScalarType()) + "(&$(_out_pre)[" + sg.getPreISynIndex(batchSize, "$(id_pre)") + "], lOutPre);");
}

}
Expand Down
14 changes: 7 additions & 7 deletions src/genn/genn/transpiler/prettyPrinter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,19 +249,19 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
// Check that there are call arguments on the stack
assert(!m_CallArguments.empty());

// Loop through call arguments on top of stack
size_t i = 0;
for (i = 0; i < m_CallArguments.top().second.size(); i++) {
// Loop through non-variadic function arguments
const size_t numArguments = type.getFunction().argTypes.size();
for (size_t i = 0; i < numArguments; i++) {
// If name contains a $(i) placeholder to replace with this argument, replace with pretty-printed argument
const std::string placeholder = "$(" + std::to_string(i) + ")";

// If placeholder isn't found at all, stop looking for arguments
// If placeholder isn't found at all, go onto next argument
size_t found = name.find(placeholder);
if(found == std::string::npos) {
break;
continue;
}

// Keep replacing placeholders
// Replace all instances of placeholder
do {
name.replace(found, placeholder.length(), m_CallArguments.top().second.at(i));
found = name.find(placeholder, found);
Expand All @@ -279,7 +279,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
// between required and variadic arguments e.g. "printf($(0)$(@))"
// so, arguments simply require leading printing with leading comma
std::ostringstream variadicArgumentsStream;
const auto varArgBegin = m_CallArguments.top().second.cbegin() + i;
const auto varArgBegin = m_CallArguments.top().second.cbegin() + numArguments;
const auto varArgEnd = m_CallArguments.top().second.cend();
for(auto a = varArgBegin; a != varArgEnd; a++) {
variadicArgumentsStream << ", " << *a;
Expand Down
2 changes: 1 addition & 1 deletion src/genn/genn/transpiler/typeChecker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
argumentConversionRank.reserve(m_CallArguments.top().size());

// Loop through arguments
// **NOTE** we loop through function arguments to deal with variadic
// **NOTE** we loop through function TYPE arguments to avoid variadic
bool viable = true;
auto c = m_CallArguments.top().cbegin();
auto a = argumentTypes.cbegin();
Expand Down
162 changes: 162 additions & 0 deletions tests/features/test_event_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,167 @@ def test_reverse(make_model, backend, precision):
assert np.sum(pre_bitmask_n_pop.vars["x"].view) == (model.timestep - 1)
assert np.sum(pre_event_n_pop.vars["x"].view) == (model.timestep - 1)


@pytest.mark.parametrize("precision", [types.Double, types.Float])
def test_reverse_kernel(make_model, backend, precision):
    """Check that addToPre delivers presynaptic output correctly through
    Toeplitz (Conv2D) connectivity: each presynaptic neuron should receive
    back a value equal to the number of kernel taps that touch it."""
    # Spike source that fires each neuron once, in index order, and
    # records its incoming synaptic input in "x"
    spike_source_model = create_neuron_model(
        "pre_reverse_spike_source",
        vars=[("startSpike", "unsigned int"),
              ("endSpike", "unsigned int", VarAccess.READ_ONLY_DUPLICATE),
              ("x", "scalar")],
        extra_global_params=[("spikeTimes", "scalar*")],
        sim_code=
        """
        x = Isyn;
        """,
        threshold_condition_code=
        """
        startSpike != endSpike && t >= spikeTimes[startSpike]
        """,
        reset_code=
        """
        startSpike++;
        """)

    # Weight update model which sends its weight back to the PREsynaptic neuron
    reverse_wum = create_weight_update_model(
        "static_pulse_reverse",
        sim_code=
        """
        addToPre(g);
        """,
        vars=[("g", "scalar", VarAccess.READ_ONLY)])

    model = make_model(precision, "test_reverse_kernel", backend=backend)
    model.dt = 1.0

    # Create spike source array with extra x variable
    # to record the reverse (presynaptic) input
    pre_n_pop = model.add_neuron_population(
        "Pre", 8 * 8, spike_source_model,
        {}, {"startSpike": np.arange(8 * 8), "endSpike": np.arange(1, 65), "x": 0.0})
    pre_n_pop.extra_global_params["spikeTimes"].set_init_values(np.arange(8 * 8))

    # Add postsynaptic population to connect to
    post_n_pop = model.add_neuron_population(
        "Post", 6 * 6, post_neuron_model,
        {}, {"x": 0.0})

    # Add convolutional Toeplitz connectivity (3x3 kernel, 8x8 -> 6x6, valid padding)
    toeplitz_params = {"conv_kh": 3, "conv_kw": 3,
                       "conv_ih": 8, "conv_iw": 8, "conv_ic": 1,
                       "conv_oh": 6, "conv_ow": 6, "conv_oc": 1}
    model.add_synapse_population(
        "Synapse", "TOEPLITZ",
        pre_n_pop, post_n_pop,
        init_weight_update(reverse_wum, {}, {"g": 1.0}),
        init_postsynaptic("DeltaCurr"),
        init_toeplitz_connectivity("Conv2D", toeplitz_params))

    # Build model and load
    model.build()
    model.load()

    # Expected reverse input: how many 3x3 kernel positions cover each
    # of the 8x8 input pixels (edges/corners are covered by fewer)
    counts = np.asarray(
        [[1, 2, 3, 3, 3, 3, 2, 1],
         [2, 4, 6, 6, 6, 6, 4, 2],
         [3, 6, 9, 9, 9, 9, 6, 3],
         [3, 6, 9, 9, 9, 9, 6, 3],
         [3, 6, 9, 9, 9, 9, 6, 3],
         [3, 6, 9, 9, 9, 9, 6, 3],
         [2, 4, 6, 6, 6, 6, 4, 2],
         [1, 2, 3, 3, 3, 3, 2, 1]]).flatten()

    # Simulate 64 timesteps, checking one presynaptic neuron per step
    # (neuron i spikes at t=i, so its reverse input is visible one step later)
    while model.timestep < 64:
        model.step_time()

        pre_n_pop.vars["x"].pull_from_device()

        if model.timestep > 1:
            ind = model.timestep - 2
            assert pre_n_pop.vars["x"].values[ind] == counts[ind]


@pytest.mark.parametrize("precision", [types.Double, types.Float])
def test_reverse_kernel_procedural(make_model, backend_simt, precision):
    """Check that addToPre delivers presynaptic output correctly through
    procedural Conv2D connectivity (PROCEDURAL_KERNELG, SIMT backends only):
    each presynaptic neuron should receive back a value equal to the number
    of kernel taps that touch it."""
    # Spike source that fires each neuron once, in index order, and
    # records its incoming synaptic input in "x"
    spike_source_model = create_neuron_model(
        "pre_reverse_spike_source",
        vars=[("startSpike", "unsigned int"),
              ("endSpike", "unsigned int", VarAccess.READ_ONLY_DUPLICATE),
              ("x", "scalar")],
        extra_global_params=[("spikeTimes", "scalar*")],
        sim_code=
        """
        x = Isyn;
        """,
        threshold_condition_code=
        """
        startSpike != endSpike && t >= spikeTimes[startSpike]
        """,
        reset_code=
        """
        startSpike++;
        """)

    # Weight update model which sends its weight back to the PREsynaptic neuron
    reverse_wum = create_weight_update_model(
        "static_pulse_reverse",
        sim_code=
        """
        addToPre(g);
        """,
        vars=[("g", "scalar", VarAccess.READ_ONLY)])

    model = make_model(precision, "test_reverse_kernel_procedural", backend=backend_simt)
    model.dt = 1.0

    # Create spike source array with extra x variable
    # to record the reverse (presynaptic) input
    pre_n_pop = model.add_neuron_population(
        "Pre", 8 * 8, spike_source_model,
        {}, {"startSpike": np.arange(8 * 8), "endSpike": np.arange(1, 65), "x": 0.0})
    pre_n_pop.extra_global_params["spikeTimes"].set_init_values(np.arange(8 * 8))

    # Add postsynaptic population to connect to
    post_n_pop = model.add_neuron_population(
        "Post", 6 * 6, post_neuron_model,
        {}, {"x": 0.0})

    # Add procedural Conv2D connectivity (3x3 kernel, stride 1, no padding,
    # 8x8 -> 6x6) with kernel weights generated on the fly
    conv_params = {"conv_kh": 3, "conv_kw": 3,
                   "conv_sh": 1, "conv_sw": 1,
                   "conv_padh": 0, "conv_padw": 0,
                   "conv_ih": 8, "conv_iw": 8, "conv_ic": 1,
                   "conv_oh": 6, "conv_ow": 6, "conv_oc": 1}

    model.add_synapse_population(
        "Synapse", "PROCEDURAL_KERNELG",
        pre_n_pop, post_n_pop,
        init_weight_update(reverse_wum, {}, {"g": 1.0}),
        init_postsynaptic("DeltaCurr"),
        init_sparse_connectivity("Conv2D", conv_params))

    # Build model and load
    model.build()
    model.load()

    # Expected reverse input: how many 3x3 kernel positions cover each
    # of the 8x8 input pixels (edges/corners are covered by fewer)
    counts = np.asarray(
        [[1, 2, 3, 3, 3, 3, 2, 1],
         [2, 4, 6, 6, 6, 6, 4, 2],
         [3, 6, 9, 9, 9, 9, 6, 3],
         [3, 6, 9, 9, 9, 9, 6, 3],
         [3, 6, 9, 9, 9, 9, 6, 3],
         [3, 6, 9, 9, 9, 9, 6, 3],
         [2, 4, 6, 6, 6, 6, 4, 2],
         [1, 2, 3, 3, 3, 3, 2, 1]]).flatten()

    # Simulate 64 timesteps, checking one presynaptic neuron per step
    # (neuron i spikes at t=i, so its reverse input is visible one step later)
    while model.timestep < 64:
        model.step_time()

        pre_n_pop.vars["x"].pull_from_device()

        if model.timestep > 1:
            ind = model.timestep - 2
            assert pre_n_pop.vars["x"].values[ind] == counts[ind]

@pytest.mark.parametrize("precision", [types.Double, types.Float])
def test_reverse_post(make_model, backend, precision):
pre_reverse_model = create_neuron_model(
Expand Down Expand Up @@ -792,3 +953,4 @@ def test_reverse_post(make_model, backend, precision):
output_value = np.sum(output_place_values[output_binary])
if output_value != (model.timestep - 1):
assert False, f"{pop.name} decoding incorrect ({output_value} rather than {model.timestep - 1})"