Skip to content

Commit ec34d69

Browse files
Add type unification logic to handle LambdaDecl/Functor's with FuncType (#9010)
Issue: * Partially addresses #9003 * fixes the case of `foo2` Problem & Soluition: * Currently, generic inference fails if using lambda as argument to a function-type parameter since Slang does not handle this case during type unification. This missing logic was added. * `inferGenericArguments`->`tryUnifyTypes`->`TryUnifyFuncTypesByStructuralMatch` --------- Co-authored-by: James Helferty (NVIDIA) <[email protected]>
1 parent 1184252 commit ec34d69

File tree

4 files changed

+144
-15
lines changed

4 files changed

+144
-15
lines changed

source/slang/slang-check-constraint.cpp

Lines changed: 76 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,16 +1027,89 @@ bool SemanticsVisitor::TryUnifyIntParam(
10271027
}
10281028
}
10291029

1030+
bool SemanticsVisitor::TryUnifyFunctorByStructuralMatch(
1031+
ConstraintSystem& constraints,
1032+
ValUnificationContext unifyCtx,
1033+
StructDecl* fstStructDecl,
1034+
FuncType* sndFuncType)
1035+
{
1036+
// Here we just need to find an invocation method for our functor
1037+
// to perform unification with.
1038+
// We do not validate the validity of the functor at this step,
1039+
// we only need to perform a reasonable unification so that constraints
1040+
// can correctly solve.
1041+
FuncDecl* functorInvokeMethod =
1042+
as<FuncDecl>(fstStructDecl->findLastDirectMemberDeclOfName(getName("()")));
1043+
if (!functorInvokeMethod)
1044+
return false;
1045+
1046+
return TryUnifyFuncTypesByStructuralMatch(
1047+
constraints,
1048+
unifyCtx,
1049+
getFuncType(this->getASTBuilder(), functorInvokeMethod),
1050+
sndFuncType);
1051+
}
1052+
1053+
bool SemanticsVisitor::TryUnifyFuncTypesByStructuralMatch(
1054+
ConstraintSystem& constraints,
1055+
ValUnificationContext unifyCtx,
1056+
FuncType* fstFunType,
1057+
FuncType* sndFunType)
1058+
{
1059+
const Index numParams = fstFunType->getParamCount();
1060+
if (numParams != sndFunType->getParamCount())
1061+
return false;
1062+
for (Index i = 0; i < numParams; ++i)
1063+
{
1064+
if (!TryUnifyTypes(
1065+
constraints,
1066+
unifyCtx,
1067+
fstFunType->getParamTypeWithModeWrapper(i),
1068+
sndFunType->getParamTypeWithModeWrapper(i)))
1069+
return false;
1070+
}
1071+
return TryUnifyTypes(
1072+
constraints,
1073+
unifyCtx,
1074+
fstFunType->getResultType(),
1075+
sndFunType->getResultType());
1076+
}
1077+
10301078
bool SemanticsVisitor::TryUnifyTypesByStructuralMatch(
10311079
ConstraintSystem& constraints,
10321080
ValUnificationContext unifyCtx,
10331081
QualType fst,
10341082
QualType snd)
10351083
{
1084+
if (auto sndDeclRefType = as<DeclRefType>(snd))
1085+
{
1086+
auto sndDeclRef = sndDeclRefType->getDeclRef();
1087+
1088+
if (auto sndStructDecl = as<StructDecl>(sndDeclRef))
1089+
{
1090+
if (auto fstFunType = as<FuncType>(fst))
1091+
return TryUnifyFunctorByStructuralMatch(
1092+
constraints,
1093+
unifyCtx,
1094+
sndStructDecl.getDecl(),
1095+
fstFunType);
1096+
}
1097+
}
1098+
10361099
if (auto fstDeclRefType = as<DeclRefType>(fst))
10371100
{
10381101
auto fstDeclRef = fstDeclRefType->getDeclRef();
10391102

1103+
if (auto fstStructDecl = as<StructDecl>(fstDeclRef))
1104+
{
1105+
if (auto sndFunType = as<FuncType>(snd))
1106+
return TryUnifyFunctorByStructuralMatch(
1107+
constraints,
1108+
unifyCtx,
1109+
fstStructDecl.getDecl(),
1110+
sndFunType);
1111+
}
1112+
10401113
if (auto typeParamDecl = as<GenericTypeParamDecl>(fstDeclRef.getDecl()))
10411114
if (typeParamDecl->parentDecl == constraints.genericDecl)
10421115
return TryUnifyTypeParam(constraints, unifyCtx, typeParamDecl, snd);
@@ -1102,23 +1175,11 @@ bool SemanticsVisitor::TryUnifyTypesByStructuralMatch(
11021175
{
11031176
if (auto sndFunType = as<FuncType>(snd))
11041177
{
1105-
const Index numParams = fstFunType->getParamCount();
1106-
if (numParams != sndFunType->getParamCount())
1107-
return false;
1108-
for (Index i = 0; i < numParams; ++i)
1109-
{
1110-
if (!TryUnifyTypes(
1111-
constraints,
1112-
unifyCtx,
1113-
fstFunType->getParamTypeWithModeWrapper(i),
1114-
sndFunType->getParamTypeWithModeWrapper(i)))
1115-
return false;
1116-
}
1117-
return TryUnifyTypes(
1178+
return TryUnifyFuncTypesByStructuralMatch(
11181179
constraints,
11191180
unifyCtx,
1120-
fstFunType->getResultType(),
1121-
sndFunType->getResultType());
1181+
fstFunType,
1182+
sndFunType);
11221183
}
11231184
}
11241185
else if (auto expandType = as<ExpandType>(fst))

source/slang/slang-check-impl.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2772,6 +2772,18 @@ struct SemanticsVisitor : public SemanticsContext
27722772
DeclRef<VarDeclBase> const& varRef,
27732773
IntVal* val);
27742774

2775+
bool TryUnifyFunctorByStructuralMatch(
2776+
ConstraintSystem& constraints,
2777+
ValUnificationContext unifyCtx,
2778+
StructDecl* fst,
2779+
FuncType* snd);
2780+
2781+
bool TryUnifyFuncTypesByStructuralMatch(
2782+
ConstraintSystem& constraints,
2783+
ValUnificationContext unifyCtx,
2784+
FuncType* fst,
2785+
FuncType* snd);
2786+
27752787
bool TryUnifyTypesByStructuralMatch(
27762788
ConstraintSystem& constraints,
27772789
ValUnificationContext unificationContext,
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type
2+
//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-cpu -output-using-type
3+
4+
//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0], stride=4)
5+
RWStructuredBuffer<int> outputBuffer;
6+
7+
// This test ensures that lambdas correctly unify-types with function-types so that
8+
// generic-parameters can be infered.
9+
func foo<let N : int>(f: functype()->vector<float, N>)->vector<float, N> { return f(); }
10+
11+
struct Functor : IFunc<vector<float, 1>>
12+
{
13+
vector<float, 1> operator()()
14+
{
15+
return 4.0;
16+
}
17+
}
18+
19+
func foo3<let N:int>(f: IFunc<vector<float, N>>) {}
20+
21+
[numthreads(1,1,1)]
22+
func computeMain()->void
23+
{
24+
//CHECK: 2
25+
outputBuffer[0] = (int)foo(() => 2);
26+
//CHECK: 3
27+
outputBuffer[1] = (int)foo(() => vector<float, 1>(3));
28+
//CHECK: 4
29+
Functor f;
30+
outputBuffer[2] = (int)foo(f);
31+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
//TEST:SIMPLE(filecheck=CHECK): -target spirv -stage compute -entry computeMain
2+
3+
//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0], stride=4)
4+
RWStructuredBuffer<int> outputBuffer;
5+
6+
func foo<let N : int>(f: functype()->vector<float, N>)->vector<float, N> { return f(); }
7+
8+
struct Functor : IFunc<Array<int,2>>
9+
{
10+
Array<int,2> operator()()
11+
{
12+
return Array<int,2>();
13+
}
14+
}
15+
16+
[numthreads(1,1,1)]
17+
func computeMain()->void
18+
{
19+
//CHECK: ([[# @LINE+1]]): error 39999
20+
outputBuffer[0] = (int)foo(() => Array<int,2>());
21+
22+
//CHECK: ([[# @LINE+2]]): error 39999
23+
Functor f;
24+
outputBuffer[2] = (int)foo(f);
25+
}

0 commit comments

Comments
 (0)