Skip to content
Merged
91 changes: 76 additions & 15 deletions source/slang/slang-check-constraint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1027,16 +1027,89 @@ bool SemanticsVisitor::TryUnifyIntParam(
}
}

bool SemanticsVisitor::TryUnifyFunctorByStructuralMatch(
ConstraintSystem& constraints,
ValUnificationContext unifyCtx,
StructDecl* fstStructDecl,
FuncType* sndFuncType)
{
// Here we just need to find an invocation method for our functor
// to perform unification with.
// We do not validate the validity of the functor at this step,
// we only need to perform a reasonable unification so that constraints
// can correctly solve.
FuncDecl* functorInvokeMethod =
as<FuncDecl>(fstStructDecl->findLastDirectMemberDeclOfName(getName("()")));
if (!functorInvokeMethod)
return false;

return TryUnifyFuncTypesByStructuralMatch(
constraints,
unifyCtx,
getFuncType(this->getASTBuilder(), functorInvokeMethod),
sndFuncType);
}

bool SemanticsVisitor::TryUnifyFuncTypesByStructuralMatch(
ConstraintSystem& constraints,
ValUnificationContext unifyCtx,
FuncType* fstFunType,
FuncType* sndFunType)
{
const Index numParams = fstFunType->getParamCount();
if (numParams != sndFunType->getParamCount())
return false;
for (Index i = 0; i < numParams; ++i)
{
if (!TryUnifyTypes(
constraints,
unifyCtx,
fstFunType->getParamTypeWithModeWrapper(i),
sndFunType->getParamTypeWithModeWrapper(i)))
return false;
}
return TryUnifyTypes(
constraints,
unifyCtx,
fstFunType->getResultType(),
sndFunType->getResultType());
}

bool SemanticsVisitor::TryUnifyTypesByStructuralMatch(
ConstraintSystem& constraints,
ValUnificationContext unifyCtx,
QualType fst,
QualType snd)
{
if (auto sndDeclRefType = as<DeclRefType>(snd))
{
auto sndDeclRef = sndDeclRefType->getDeclRef();

if (auto sndStructDecl = as<StructDecl>(sndDeclRef))
{
if (auto fstFunType = as<FuncType>(fst))
return TryUnifyFunctorByStructuralMatch(
constraints,
unifyCtx,
sndStructDecl.getDecl(),
fstFunType);
}
}

if (auto fstDeclRefType = as<DeclRefType>(fst))
{
auto fstDeclRef = fstDeclRefType->getDeclRef();

if (auto fstStructDecl = as<StructDecl>(fstDeclRef))
{
if (auto sndFunType = as<FuncType>(snd))
return TryUnifyFunctorByStructuralMatch(
constraints,
unifyCtx,
fstStructDecl.getDecl(),
sndFunType);
}

if (auto typeParamDecl = as<GenericTypeParamDecl>(fstDeclRef.getDecl()))
if (typeParamDecl->parentDecl == constraints.genericDecl)
return TryUnifyTypeParam(constraints, unifyCtx, typeParamDecl, snd);
Expand Down Expand Up @@ -1102,23 +1175,11 @@ bool SemanticsVisitor::TryUnifyTypesByStructuralMatch(
{
if (auto sndFunType = as<FuncType>(snd))
{
const Index numParams = fstFunType->getParamCount();
if (numParams != sndFunType->getParamCount())
return false;
for (Index i = 0; i < numParams; ++i)
{
if (!TryUnifyTypes(
constraints,
unifyCtx,
fstFunType->getParamTypeWithModeWrapper(i),
sndFunType->getParamTypeWithModeWrapper(i)))
return false;
}
return TryUnifyTypes(
return TryUnifyFuncTypesByStructuralMatch(
constraints,
unifyCtx,
fstFunType->getResultType(),
sndFunType->getResultType());
fstFunType,
sndFunType);
}
}
else if (auto expandType = as<ExpandType>(fst))
Expand Down
12 changes: 12 additions & 0 deletions source/slang/slang-check-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2772,6 +2772,18 @@ struct SemanticsVisitor : public SemanticsContext
DeclRef<VarDeclBase> const& varRef,
IntVal* val);

bool TryUnifyFunctorByStructuralMatch(
ConstraintSystem& constraints,
ValUnificationContext unifyCtx,
StructDecl* fst,
FuncType* snd);

bool TryUnifyFuncTypesByStructuralMatch(
ConstraintSystem& constraints,
ValUnificationContext unifyCtx,
FuncType* fst,
FuncType* snd);

bool TryUnifyTypesByStructuralMatch(
ConstraintSystem& constraints,
ValUnificationContext unificationContext,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type
//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-cpu -output-using-type

//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0], stride=4)
RWStructuredBuffer<int> outputBuffer;

// This test ensures that lambdas correctly unify-types with function-types so that
// generic-parameters can be infered.
func foo<let N : int>(f: functype()->vector<float, N>)->vector<float, N> { return f(); }

struct Functor : IFunc<vector<float, 1>>
{
vector<float, 1> operator()()
{
return 4.0;
}
}

func foo3<let N:int>(f: IFunc<vector<float, N>>) {}

[numthreads(1,1,1)]
func computeMain()->void
{
//CHECK: 2
outputBuffer[0] = (int)foo(() => 2);
//CHECK: 3
outputBuffer[1] = (int)foo(() => vector<float, 1>(3));
//CHECK: 4
Functor f;
outputBuffer[2] = (int)foo(f);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
//TEST:SIMPLE(filecheck=CHECK): -target spirv -stage compute -entry computeMain

//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0], stride=4)
RWStructuredBuffer<int> outputBuffer;

func foo<let N : int>(f: functype()->vector<float, N>)->vector<float, N> { return f(); }

struct Functor : IFunc<Array<int,2>>
{
Array<int,2> operator()()
{
return Array<int,2>();
}
}

[numthreads(1,1,1)]
func computeMain()->void
{
//CHECK: ([[# @LINE+1]]): error 39999
outputBuffer[0] = (int)foo(() => Array<int,2>());

//CHECK: ([[# @LINE+2]]): error 39999
Functor f;
outputBuffer[2] = (int)foo(f);
}