Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 19 additions & 18 deletions include/genn/backends/cuda/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -345,25 +345,26 @@ class BACKEND_EXPORT Backend : public BackendSIMT

// Implement merged group array in previously assigned memory space
os << g.getMemorySpace() << " Merged" << T::name << "Group" << g.getIndex() << " d_merged" << T::name << "Group" << g.getIndex() << "[" << g.getGroups().size() << "];" << std::endl;

// Write function to update
os << "void pushMerged" << T::name << "Group" << g.getIndex() << "ToDevice(unsigned int idx, ";
g.generateStructFieldArgumentDefinitions(os, *this);
os << ")";
{
CodeStream::Scope b(os);

// Loop through sorted fields and build struct on the stack
os << "Merged" << T::name << "Group" << g.getIndex() << " group = {";
const auto sortedFields = g.getSortedFields(*this);
for(const auto &f : sortedFields) {
os << f.name << ", ";
if(!g.getFields().empty()) {
// Write function to update
os << "void pushMerged" << T::name << "Group" << g.getIndex() << "ToDevice(unsigned int idx, ";
g.generateStructFieldArgumentDefinitions(os, *this);
os << ")";
{
CodeStream::Scope b(os);

// Loop through sorted fields and build struct on the stack
os << "Merged" << T::name << "Group" << g.getIndex() << " group = {";
const auto sortedFields = g.getSortedFields(*this);
for(const auto &f : sortedFields) {
os << f.name << ", ";
}
os << "};" << std::endl;

// Push to device
os << "CHECK_CUDA_ERRORS(cudaMemcpyToSymbolAsync(d_merged" << T::name << "Group" << g.getIndex() << ", &group, ";
os << "sizeof(Merged" << T::name << "Group" << g.getIndex() << "), idx * sizeof(Merged" << T::name << "Group" << g.getIndex() << ")));" << std::endl;
}
os << "};" << std::endl;

// Push to device
os << "CHECK_CUDA_ERRORS(cudaMemcpyToSymbolAsync(d_merged" << T::name << "Group" << g.getIndex() << ", &group, ";
os << "sizeof(Merged" << T::name << "Group" << g.getIndex() << "), idx * sizeof(Merged" << T::name << "Group" << g.getIndex() << ")));" << std::endl;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,6 @@ class GENN_EXPORT CustomConnectivityUpdateGroupMerged : public GroupMerged<Custo
//----------------------------------------------------------------------------
// Private methods
//----------------------------------------------------------------------------
bool isParamHeterogeneous(const std::string &name) const;
bool isDerivedParamHeterogeneous(const std::string &name) const;

template<typename A>
void addPrivateVarRefAccess(EnvironmentGroupMergedField<CustomConnectivityUpdateGroupMerged> &env, unsigned int batchSize,
std::function<std::string(VarAccessMode, const typename A::RefType&)> getIndexFn)
Expand Down Expand Up @@ -143,9 +140,6 @@ class GENN_EXPORT CustomConnectivityHostUpdateGroupMerged : public GroupMerged<C
//----------------------------------------------------------------------------
// Private methods
//----------------------------------------------------------------------------
bool isParamHeterogeneous(const std::string &name) const;
bool isDerivedParamHeterogeneous(const std::string &name) const;

template<typename A>
void addVars(EnvironmentGroupMergedField<CustomConnectivityHostUpdateGroupMerged> &env, const std::string &count, const BackendBase &backend)
{
Expand Down
10 changes: 0 additions & 10 deletions include/genn/genn/code_generator/customUpdateGroupMerged.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,6 @@ class GENN_EXPORT CustomUpdateGroupMerged : public GroupMerged<CustomUpdateInter
// Static constants
//----------------------------------------------------------------------------
static const std::string name;

private:
//----------------------------------------------------------------------------
// Private methods
//----------------------------------------------------------------------------
bool isParamHeterogeneous(const std::string &paramName) const;
bool isDerivedParamHeterogeneous(const std::string &paramName) const;
};

// ----------------------------------------------------------------------------
Expand All @@ -57,9 +50,6 @@ class GENN_EXPORT CustomUpdateWUGroupMergedBase : public GroupMerged<CustomUpdat
//----------------------------------------------------------------------------
// Public API
//----------------------------------------------------------------------------
bool isParamHeterogeneous(const std::string &paramName) const;
bool isDerivedParamHeterogeneous(const std::string &paramName) const;

boost::uuids::detail::sha1::digest_type getHashDigest() const;

void generateCustomUpdate(EnvironmentExternalBase &env, unsigned int batchSize,
Expand Down
Loading