Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions include/genn/backends/cuda/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,6 @@ class BACKEND_EXPORT Backend : public BackendCUDAHIP
AtomicOperation op = AtomicOperation::ADD,
AtomicMemSpace memSpace = AtomicMemSpace::GLOBAL) const final;

//! Get type of population RNG
virtual Type::ResolvedType getPopulationRNGType() const final;

//--------------------------------------------------------------------------
// CodeGenerator::BackendBase virtuals
//--------------------------------------------------------------------------
Expand All @@ -191,10 +188,6 @@ class BACKEND_EXPORT Backend : public BackendCUDAHIP
virtual std::unique_ptr<Runtime::ArrayBase> createArray(const Type::ResolvedType &type, size_t count,
VarLocation location, bool uninitialized) const final;

//! Create array of backend-specific population RNGs (if they are initialised on host this will occur here)
/*! \param count number of RNGs required*/
virtual std::unique_ptr<Runtime::ArrayBase> createPopulationRNG(size_t count) const final;

//! Generate code to allocate variable with a size known at runtime
virtual void genLazyVariableDynamicAllocation(CodeStream &os,
const Type::ResolvedType &type, const std::string &name, VarLocation loc,
Expand Down Expand Up @@ -234,6 +227,9 @@ class BACKEND_EXPORT Backend : public BackendCUDAHIP
return m_ChosenDevice.totalConstMem - getPreferences<Preferences>().constantCacheOverhead;
}

//! Get internal type population RNG gets loaded into
virtual Type::ResolvedType getPopulationRNGInternalType() const final;

//! Get library of RNG functions to use
virtual const EnvironmentLibrary::Library &getRNGFunctions(const Type::ResolvedType &precision) const final;

Expand Down
10 changes: 3 additions & 7 deletions include/genn/backends/hip/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,6 @@ class BACKEND_EXPORT Backend : public BackendCUDAHIP
AtomicOperation op = AtomicOperation::ADD,
AtomicMemSpace memSpace = AtomicMemSpace::GLOBAL) const final;

//! Get type of population RNG
virtual Type::ResolvedType getPopulationRNGType() const final;

//--------------------------------------------------------------------------
// CodeGenerator::BackendBase virtuals
//--------------------------------------------------------------------------
Expand All @@ -182,10 +179,6 @@ class BACKEND_EXPORT Backend : public BackendCUDAHIP
virtual std::unique_ptr<Runtime::ArrayBase> createArray(const Type::ResolvedType &type, size_t count,
VarLocation location, bool uninitialized) const final;

//! Create array of backend-specific population RNGs (if they are initialised on host this will occur here)
/*! \param count number of RNGs required*/
virtual std::unique_ptr<Runtime::ArrayBase> createPopulationRNG(size_t count) const final;

//! Generate code to allocate variable with a size known at runtime
virtual void genLazyVariableDynamicAllocation(CodeStream &os,
const Type::ResolvedType &type, const std::string &name, VarLocation loc,
Expand Down Expand Up @@ -225,6 +218,9 @@ class BACKEND_EXPORT Backend : public BackendCUDAHIP
return m_ChosenDevice.totalConstMem - getPreferences<Preferences>().constantCacheOverhead;
}

//! Get internal type population RNG gets loaded into
virtual Type::ResolvedType getPopulationRNGInternalType() const final;

//! Get library of RNG functions to use
virtual const EnvironmentLibrary::Library &getRNGFunctions(const Type::ResolvedType &precision) const final;

Expand Down
27 changes: 16 additions & 11 deletions include/genn/genn/code_generator/backendCUDAHIP.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,6 @@ class GENN_EXPORT BackendCUDAHIP : public BackendSIMT
m_RandPrefix(randPrefix), m_CCLPrefix(cclPrefix)
{}

//--------------------------------------------------------------------------
// Declared virtuals
//--------------------------------------------------------------------------

//--------------------------------------------------------------------------
// CodeGenerator::BackendSIMT virtuals
//--------------------------------------------------------------------------
Expand All @@ -87,16 +83,18 @@ class GENN_EXPORT BackendCUDAHIP : public BackendSIMT
//! For SIMT backends which initialize RNGs on device, initialize population RNG with specified seed and sequence
virtual void genPopulationRNGInit(CodeStream &os, const std::string &globalRNG, const std::string &seed, const std::string &sequence) const final;

//! Generate a preamble to add substitution name for population RNG
virtual std::string genPopulationRNGPreamble(CodeStream &os, const std::string &globalRNG) const final;

//! If required, generate a postamble for population RNG
/*! For example, in OpenCL, this is used to write local RNG state back to global memory*/
virtual void genPopulationRNGPostamble(CodeStream &os, const std::string &globalRNG) const final;

//! Generate code to skip ahead local copy of global RNG
virtual std::string genGlobalRNGSkipAhead(CodeStream &os, const std::string &sequence) const final;

//! Get type of population RNG
virtual Type::ResolvedType getPopulationRNGType() const final;

//! Generate a preamble to add substitution name for population RNG
virtual void buildPopulationRNGEnvironment(EnvironmentGroupMergedField<NeuronUpdateGroupMerged> &env) const final;

//! Add $(_rng) to environment based on $(_rng_internal) field with any initialisers and destructors required
virtual void buildPopulationRNGEnvironment(EnvironmentGroupMergedField<CustomConnectivityUpdateGroupMerged> &env) const final;

//--------------------------------------------------------------------------
// CodeGenerator::BackendBase virtuals
//--------------------------------------------------------------------------
Expand All @@ -118,6 +116,10 @@ class GENN_EXPORT BackendCUDAHIP : public BackendSIMT
virtual void genFreeMemPreamble(CodeStream &os, const ModelSpecMerged &modelMerged) const final;
virtual void genStepTimeFinalisePreamble(CodeStream &os, const ModelSpecMerged &modelMerged) const final;

//! Create array of backend-specific population RNGs (if they are initialised on host this will occur here)
/*! \param count number of RNGs required*/
virtual std::unique_ptr<GeNN::Runtime::ArrayBase> createPopulationRNG(size_t count) const final;

//! Generate code for pushing a variable with a size known at runtime to the 'device'
virtual void genLazyVariableDynamicPush(CodeStream &os,
const Type::ResolvedType &type, const std::string &name,
Expand Down Expand Up @@ -173,6 +175,9 @@ class GENN_EXPORT BackendCUDAHIP : public BackendSIMT
//! Get the safe amount of constant cache we can use
virtual size_t getChosenDeviceSafeConstMemBytes() const = 0;

//! Get internal type population RNG gets loaded into
virtual Type::ResolvedType getPopulationRNGInternalType() const = 0;

//! Get library of RNG functions to use
virtual const EnvironmentLibrary::Library &getRNGFunctions(const Type::ResolvedType &precision) const = 0;

Expand Down
9 changes: 4 additions & 5 deletions include/genn/genn/code_generator/backendSIMT.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,11 @@ class GENN_EXPORT BackendSIMT : public BackendBase
//! For SIMT backends which initialize RNGs on device, initialize population RNG with specified seed and sequence
virtual void genPopulationRNGInit(CodeStream &os, const std::string &globalRNG, const std::string &seed, const std::string &sequence) const = 0;

//! Generate a preamble to add substitution name for population RNG
virtual std::string genPopulationRNGPreamble(CodeStream &os, const std::string &globalRNG) const = 0;
//! Add $(_rng) to environment based on $(_rng_internal) field with any initialisers and destructors required
virtual void buildPopulationRNGEnvironment(EnvironmentGroupMergedField<NeuronUpdateGroupMerged> &env) const = 0;

//! If required, generate a postamble for population RNG
/*! For example, in OpenCL, this is used to write local RNG state back to global memory*/
virtual void genPopulationRNGPostamble(CodeStream &os, const std::string &globalRNG) const = 0;
//! Add $(_rng) to environment based on $(_rng_internal) field with any initialisers and destructors required
virtual void buildPopulationRNGEnvironment(EnvironmentGroupMergedField<CustomConnectivityUpdateGroupMerged> &env) const = 0;

//! Generate code to skip ahead local copy of global RNG
virtual std::string genGlobalRNGSkipAhead(CodeStream &os, const std::string &sequence) const = 0;
Expand Down
81 changes: 62 additions & 19 deletions include/genn/genn/code_generator/environment.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,14 +258,14 @@

~EnvironmentExternalDynamicBase()
{
// Loop through initialisers
std::vector<std::string> initialiserCode(m_Initialisers.size());
std::vector<std::string> finaliserCode(m_Finalisers.size());

// Because initialisers may refer to other initialisers,
// keep evaluating initialisers until no new ones are founf
// Because initialisers and finalisers may refer to others,
// keep evaluating them until no new ones are found
bool anyReferences;
do {
// Loop through initialiser
// Loop through initialisers
anyReferences = false;
for(size_t i = 0; i < m_Initialisers.size(); i++) {
// If initialiser has been referenced
Expand All @@ -279,6 +279,20 @@
anyReferences = true;
}
}

// Loop through finalisers
for(size_t i = 0; i < m_Finalisers.size(); i++) {
// If finaliser has been referenced
auto &finaliser = m_Finalisers[i];
if (finaliser.first) {
// Evaluate lazy string into vector
finaliserCode[i] = finaliser.second.str();

// Clear referenced flag and set flag to ensure another iteration occurs
finaliser.first = false;
anyReferences = true;
}
}
} while(anyReferences);

// Write out generated initialiser code
Expand All @@ -291,6 +305,14 @@

// Write contents to context stream
getContextStream() << m_ContentsStream.str();

// Write out generated finaliser code
// **NOTE** in order
for(const auto &i : finaliserCode) {
if(!i.empty()) {
getContextStream() << i << std::endl;
}
}
}

//------------------------------------------------------------------------
Expand All @@ -310,10 +332,15 @@
m_Initialisers.at(i).first = true;
}

// If this identifier relies on any finaliser statements, mark these finalisers as required
for(size_t i : std::get<2>(env->second)) {
m_Finalisers.at(i).first = true;
}

// Perform any type-specific logic to mark this identifier as required
this->setRequired(std::get<2>(env->second));
this->setRequired(std::get<3>(env->second));

return this->getNameInternal(std::get<2>(env->second));
return this->getNameInternal(std::get<3>(env->second));
}
}

Expand All @@ -336,8 +363,13 @@
m_Initialisers.at(i).first = true;
}

// If this identifier relies on any finaliser statements, mark these finalisers as required
for(size_t i : std::get<2>(env->second)) {
m_Finalisers.at(i).first = true;

Check warning on line 368 in include/genn/genn/code_generator/environment.h

View check run for this annotation

Codecov / codecov/patch

include/genn/genn/code_generator/environment.h#L368

Added line #L368 was not covered by tests
}

// Perform any type-specific logic to mark this identifier as required
this->setRequired(std::get<2>(env->second));
this->setRequired(std::get<3>(env->second));

// Return type of variables
return {std::get<0>(env->second)};
Expand All @@ -350,15 +382,21 @@
return (m_Initialisers.size() - 1);
}

size_t addFinaliser(const std::string &format)
{
m_Finalisers.emplace_back(false, LazyString{format, *this});
return (m_Finalisers.size() - 1);
}

protected:
//------------------------------------------------------------------------
// Protected API
//------------------------------------------------------------------------
//! Map an identifier to a type (for type-checking), lists of initialisers and a payload
void addInternal(const GeNN::Type::ResolvedType &type, const std::string &name, const typename P::Payload &payload,
const std::vector<size_t> &initialisers = {})
const std::vector<size_t> &initialisers = {}, const std::vector<size_t> &finalisers = {})
{
if(!m_Environment.try_emplace(name, type, initialisers, payload).second) {
if(!m_Environment.try_emplace(name, type, initialisers, finalisers, payload).second) {
throw std::runtime_error("Redeclaration of '" + std::string{name} + "'");
}
}
Expand All @@ -370,8 +408,9 @@
std::ostringstream m_ContentsStream;
CodeStream m_Contents;

std::unordered_map<std::string, std::tuple<Type::ResolvedType, std::vector<size_t>, typename P::Payload>> m_Environment;
std::unordered_map<std::string, std::tuple<Type::ResolvedType, std::vector<size_t>, std::vector<size_t>, typename P::Payload>> m_Environment;
std::vector<std::pair<bool, LazyString>> m_Initialisers;
std::vector<std::pair<bool, LazyString>> m_Finalisers;
};

//----------------------------------------------------------------------------
Expand All @@ -389,9 +428,10 @@
//------------------------------------------------------------------------
//! Map a type (for type-checking) and a value (for pretty-printing) to an identifier
void add(const GeNN::Type::ResolvedType &type, const std::string &name, const std::string &value,
const std::vector<size_t> &initialisers = {})
const std::vector<size_t> &initialisers = {}, const std::vector<size_t> &finalisers = {})
{
addInternal(type, name, LazyString{value, *this}, initialisers);
addInternal(type, name, LazyString{value, *this},
initialisers, finalisers);
}
};

Expand Down Expand Up @@ -433,9 +473,10 @@
//------------------------------------------------------------------------
//! Map a type and a value to an identifier
void add(const GeNN::Type::ResolvedType &type, const std::string &name, const std::string &value,
const std::vector<size_t> &initialisers = {})
const std::vector<size_t> &initialisers = {}, const std::vector<size_t> &finalisers = {})
{
this->addInternal(type, name, std::make_tuple(false, LazyString{value, *this}, std::nullopt), initialisers);
this->addInternal(type, name, std::make_tuple(false, LazyString{value, *this}, std::nullopt),
initialisers, finalisers);
}

//! Map a type (for type-checking) and a group merged field to back it to an identifier
Expand Down Expand Up @@ -472,7 +513,7 @@
const GeNN::Type::ResolvedType &fieldType, const std::string &fieldName,
GetFieldNonNumericValueFunc getFieldValue, const std::string &indexSuffix = "",
GroupMergedFieldType mergedFieldType = GroupMergedFieldType::STANDARD,
const std::vector<size_t> &initialisers = {})
const std::vector<size_t> &initialisers = {}, const std::vector<size_t> &finalisers = {})
{
typename G::Field field{fieldName, fieldType, mergedFieldType,
[getFieldValue](Runtime::Runtime &r, const GroupInternal &g, size_t i)
Expand All @@ -482,7 +523,7 @@
getFieldValue(r, g, i));
}};
this->addInternal(type, name, std::make_tuple(false, LazyString{indexSuffix, *this}, std::make_optional(field)),
initialisers);
initialisers, finalisers);
}

//! Map a type (for type-checking) and a group merged field to back it to an identifier
Expand All @@ -495,9 +536,11 @@
//! Map a type (for type-checking) and a group merged field to back it to an identifier
void addField(const GeNN::Type::ResolvedType &type, const std::string &name, const std::string &fieldName,
GetFieldNonNumericValueFunc getFieldValue, const std::string &indexSuffix = "",
GroupMergedFieldType mergedFieldType = GroupMergedFieldType::STANDARD, const std::vector<size_t> &initialisers = {})
GroupMergedFieldType mergedFieldType = GroupMergedFieldType::STANDARD,
const std::vector<size_t> &initialisers = {}, const std::vector<size_t> &finalisers = {})
{
addField(type, name, type, fieldName, getFieldValue, indexSuffix, mergedFieldType, initialisers);
addField(type, name, type, fieldName, getFieldValue, indexSuffix, mergedFieldType,
initialisers, finalisers);
}

void addParams(const Snippet::Base::ParamVec &params, const std::string &fieldSuffix,
Expand Down Expand Up @@ -611,7 +654,7 @@
}

template<typename I>
void addInitialiserDerivedParams(const std::string &fieldSuffix, GetInitialiserFn<I> getInitialiser)
void addInitialiserDerivedParams(const std::string &fieldSuffix, GetInitialiserFn<I> getInitialiser)
{
// Loop through params
const auto &initialiser = std::invoke(getInitialiser, this->getGroup().getArchetype());
Expand Down
6 changes: 6 additions & 0 deletions include/genn/genn/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,12 @@ struct GENN_EXPORT ResolvedType
return ResolvedType{Value{name, sizeof(T), ffiType, device, false, std::nullopt}, isConst};
}

static ResolvedType createValue(const std::string &name, size_t size, bool isConst = false,
ffi_type *ffiType = nullptr, bool device = false)
{
return ResolvedType{Value{name, size, ffiType, device, false, std::nullopt}, isConst};
}

static ResolvedType createFunction(const ResolvedType &returnType, const std::vector<ResolvedType> &argTypes,
FunctionFlags flags=FunctionFlags{0})
{
Expand Down
Loading