Skip to content

Commit d90d869

Browse files
authored
Fix generic function overloads in reflection API (#9052)
Close #8633. When calling `findFunctionByName<someGeneric<concrete_type>>` to find a generic function, it will report nullptr if there are overloads for the generic function. The reason is that we didn't handle OverloadedExpr2 node in the reflection API. The fix is to conservatively convert OverloadedExpr2 to OverloadedExpr and reuse to existing logic to keep handling the OverloadedExpr. The alternative solution will be add logic of handling OverloadedExpr2 to every location that only handles OverloadedExpr. But it might be more complicated.
1 parent 1152af7 commit d90d869

File tree

2 files changed

+172
-0
lines changed

2 files changed

+172
-0
lines changed

source/slang/slang-linkable.cpp

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -813,6 +813,74 @@ Expr* ComponentType::tryResolveOverloadedExpr(Expr* exprIn)
813813
return visitor.maybeResolveOverloadedExpr(exprIn, LookupMask::Function, nullptr);
814814
}
815815

816+
// This function tries to simplify an overloaded expr into OverloadedExpr for reflection API usage.
817+
// There are two kinds of overloaded expr in the AST: OverloadedExpr and OverloadedExpr2.
818+
//
819+
// OverloadedExpr stores candidates in LookupResult, where a list of `DeclRef<Decl>` hold
820+
// the properly-specialized reference to the declaration that was found. And all the candidates
821+
// must share a same base (if it is coming from a member-reference), and same orignalExpr.
822+
//
823+
// While OverloadedExpr2 stores candidates in a list of Expr, which is not necessary to be
824+
// DeclRefExpr.
825+
//
826+
// When the input orignalExpr is already OverloadedExpr, we can directly return it. But when
827+
// the input orignalExpr is OverloadedExpr2, we need to simplify it by converting it into
828+
// OverloadedExpr. The conversion routine conservatively performs the conversion when each Expr
829+
// candidates of OverloadedExpr2 is DeclRefExpr and all the candidates DeclRefExpr share the same
830+
// orignalExpr. If such condition is not met, it will return nullptr to indicated failed conversion.
831+
static Expr* maybeSimplifyExprForReflectionAPIUsage(Expr* originalExpr, ASTBuilder* astBuilder)
832+
{
833+
// return directly if it is already OverloadedExpr
834+
if (as<OverloadedExpr>(originalExpr))
835+
return originalExpr;
836+
837+
OverloadedExpr2* overloadedExpr2 = as<OverloadedExpr2>(originalExpr);
838+
// Don't perform any conversion if it is not OverloadedExpr2
839+
if (!overloadedExpr2)
840+
return originalExpr;
841+
842+
if (!overloadedExpr2->candidateExprs.getCount())
843+
return nullptr;
844+
845+
auto overloadedExpr = astBuilder->create<OverloadedExpr>();
846+
847+
Expr* sharedOriginalExpr = nullptr;
848+
849+
// Start the conversion
850+
for (auto candidate : overloadedExpr2->candidateExprs)
851+
{
852+
if (auto declRefExpr = as<DeclRefExpr>(candidate))
853+
{
854+
LookupResultItem item;
855+
item.declRef = declRefExpr->declRef;
856+
overloadedExpr->lookupResult2.items.add(item);
857+
858+
if (!sharedOriginalExpr)
859+
{
860+
sharedOriginalExpr = declRefExpr->originalExpr;
861+
}
862+
else if (sharedOriginalExpr != declRefExpr->originalExpr)
863+
{
864+
return nullptr;
865+
}
866+
}
867+
else
868+
{
869+
return nullptr;
870+
}
871+
}
872+
873+
if (overloadedExpr->lookupResult2.items.getCount())
874+
{
875+
overloadedExpr->lookupResult2.item = overloadedExpr->lookupResult2.items[0];
876+
overloadedExpr->base = overloadedExpr2->base;
877+
overloadedExpr->originalExpr = sharedOriginalExpr;
878+
return overloadedExpr;
879+
}
880+
881+
return nullptr;
882+
}
883+
816884
Expr* ComponentType::findDeclFromString(String const& name, DiagnosticSink* sink)
817885
{
818886
// If we've looked up this type name before,
@@ -855,6 +923,7 @@ Expr* ComponentType::findDeclFromString(String const& name, DiagnosticSink* sink
855923
{
856924
result = checkedExpr;
857925
}
926+
result = maybeSimplifyExprForReflectionAPIUsage(checkedExpr, astBuilder);
858927

859928
m_decls[name] = result;
860929
return result;
@@ -946,6 +1015,8 @@ Expr* ComponentType::findDeclFromStringInType(
9461015

9471016
auto checkedTerm = visitor.CheckTerm(expr);
9481017

1018+
checkedTerm = maybeSimplifyExprForReflectionAPIUsage(checkedTerm, astBuilder);
1019+
9491020
if (auto overloadedExpr = as<OverloadedExpr>(checkedTerm))
9501021
{
9511022
// For functions, since we don't know the argument list yet, we will have to defer

tools/slang-unit-test/unit-test-function-reflection.cpp

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,3 +326,104 @@ SLANG_UNIT_TEST(findFunctionByNameInType)
326326
SLANG_CHECK_ABORT(false && "Expected function to be overloaded with multiple signatures");
327327
}
328328
}
329+
330+
331+
SLANG_UNIT_TEST(findFunctionByNameGenericOverload)
332+
{
333+
// Test shader with extensions that have functions with same name but different signatures
334+
const char* userSourceBody = R"(
335+
336+
void myFunc<T>(T value)
337+
{}
338+
339+
void myFunc<T>(T value, T value1)
340+
{}
341+
342+
struct MyType<U>
343+
{
344+
void myFunc<T>(T value)
345+
{}
346+
347+
void myFunc<T>(T value, T value1)
348+
{}
349+
}
350+
351+
[shader("compute")]
352+
void computeMain(uint3 tid: SV_DispatchThreadID)
353+
{
354+
}
355+
)";
356+
357+
auto moduleName = "moduleH" + String(Process::getId());
358+
String userSource = "import " + moduleName + ";\n" + userSourceBody;
359+
ComPtr<slang::IGlobalSession> globalSession;
360+
SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK);
361+
slang::TargetDesc targetDesc = {};
362+
targetDesc.format = SLANG_HLSL;
363+
targetDesc.profile = globalSession->findProfile("sm_5_0");
364+
slang::SessionDesc sessionDesc = {};
365+
sessionDesc.targetCount = 1;
366+
sessionDesc.targets = &targetDesc;
367+
ComPtr<slang::ISession> session;
368+
SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK);
369+
370+
ComPtr<slang::IBlob> diagnosticBlob;
371+
auto module = session->loadModuleFromSourceString(
372+
"test_module",
373+
"test_module.slang",
374+
userSourceBody,
375+
diagnosticBlob.writeRef());
376+
SLANG_CHECK(module != nullptr);
377+
378+
int testCount = 2;
379+
for (int i = 0; i < testCount; i++)
380+
{
381+
slang::FunctionReflection* myFunctionType = nullptr;
382+
if (i == 0)
383+
{
384+
// test the global generic function overloads
385+
myFunctionType = module->getLayout()->findFunctionByName("myFunc<int>");
386+
SLANG_CHECK_ABORT(myFunctionType != nullptr);
387+
}
388+
else
389+
{
390+
391+
// test the generic function overloads inside a generic struct
392+
auto myStructType = module->getLayout()->findTypeByName("MyType<float>");
393+
SLANG_CHECK_ABORT(myStructType != nullptr);
394+
395+
myFunctionType =
396+
module->getLayout()->findFunctionByNameInType(myStructType, "myFunc<int>");
397+
SLANG_CHECK_ABORT(myFunctionType != nullptr);
398+
}
399+
400+
// The function should be overloaded since there are multiple functions with different
401+
// signatures
402+
if (myFunctionType->isOverloaded())
403+
{
404+
// If it's overloaded, verify we can access both variants
405+
SLANG_CHECK(myFunctionType->getOverloadCount() >= 2);
406+
407+
for (int i = 0; i < myFunctionType->getOverloadCount(); i++)
408+
{
409+
auto overload = myFunctionType->getOverload(i);
410+
411+
for (int j = 0; j < overload->getParameterCount(); j++)
412+
{
413+
auto paramTypeName = overload->getParameterByIndex(j)->getType()->getName();
414+
if (strcmp(paramTypeName, "int") != 0)
415+
{
416+
SLANG_CHECK(false && "Expected different parameter signatures");
417+
}
418+
}
419+
}
420+
}
421+
else
422+
{
423+
// The function should be overloaded since there are multiple functions with different
424+
// signatures. If it's not overloaded, the fix didn't work properly.
425+
SLANG_CHECK_ABORT(
426+
false && "Expected function to be overloaded with multiple signatures");
427+
}
428+
}
429+
}

0 commit comments

Comments
 (0)