Skip to content

Commit

Permalink
Fix type checking around generic array types. (#3568)
Browse files Browse the repository at this point in the history
  • Loading branch information
csyonghe authored Feb 11, 2024
1 parent 03cddba commit 4f7d1f4
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 23 deletions.
4 changes: 4 additions & 0 deletions source/slang/slang-ast-builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,10 @@ ArrayExpressionType* ASTBuilder::getArrayType(Type* elementType, IntVal* element
{
elementCount = getIntVal(getIntType(), elementCountConstantInt->getValue());
}
else
{
elementCount = getTypeCastIntVal(getIntType(), elementCount);
}
}
Val* args[] = {elementType, elementCount};
return as<ArrayExpressionType>(getSpecializedBuiltinType(makeArrayView(args), "ArrayExpressionType"));
Expand Down
57 changes: 34 additions & 23 deletions source/slang/slang-ir-specialize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1323,13 +1323,24 @@ struct SpecializationContext
// their replacements.
//
IRCloneEnv cloneEnv;
cloneEnv.squashChildrenMapping = true;

// We also need some IR building state, for any
// new instructions we will emit.
//
IRBuilder builderStorage(module);
auto builder = &builderStorage;

// To get started, we will create the skeleton of the new
// specialized function, so newly created insts
// will be placed in a proper parent.
//

IRFunc* newFunc = builder->createFunc();

builder->setInsertInto(newFunc);
IRBlock* tempHeaderBlock = builder->emitBlock();

// We will start out by determining what the parameters
// of the specialized function should be, based on
// the parameters of the original, and the concrete
Expand All @@ -1343,7 +1354,6 @@ struct SpecializationContext
// block, or even a function, to insert them into.
//
List<IRParam*> newParams;
List<IRInst*> newBodyInsts;
UInt argCounter = 0;
for (auto oldParam : oldFunc->getParams())
{
Expand Down Expand Up @@ -1387,7 +1397,6 @@ struct SpecializationContext
// correct existential type, and stores the right witness table).
//
auto newMakeExistential = builder->emitMakeExistential(oldParam->getFullType(), newParam, witnessTable);
newBodyInsts.add(newMakeExistential);
replacementVal = newMakeExistential;
}
else if (auto oldWrapExistential = as<IRWrapExistential>(arg))
Expand All @@ -1412,7 +1421,6 @@ struct SpecializationContext
newParam,
oldWrapExistential->getSlotOperandCount(),
oldWrapExistential->getSlotOperands());
newBodyInsts.add(newWrapExistential);
replacementVal = newWrapExistential;
}
else
Expand All @@ -1433,8 +1441,20 @@ struct SpecializationContext
cloneEnv.mapOldValToNew.add(oldParam, replacementVal);
}

// Next we will create the skeleton of the new
// specialized function, including its type.
// The above steps have accomplished the "first phase"
// of cloning the function (since `IRFunc`s have no
// operands).
//
// We can now use the shared IR cloning infrastructure
// to perform the second phase of cloning, which will recursively
// clone any nested decorations, blocks, and instructions.
//
cloneInstDecorationsAndChildren(
&cloneEnv,
builder->getModule(),
oldFunc,
newFunc);

//
// In order to construct the type of the new function, we
// need to extract the types of all its parameters.
Expand All @@ -1448,24 +1468,9 @@ struct SpecializationContext
newParamTypes.getCount(),
newParamTypes.getBuffer(),
oldFunc->getResultType());
IRFunc* newFunc = builder->createFunc();
newFunc->setFullType(newFuncType);

// The above steps have accomplished the "first phase"
// of cloning the function (since `IRFunc`s have no
// operands).
//
// We can now use the shared IR cloning infrastructure
// to perform the second phase of cloning, which will recursively
// clone any nested decorations, blocks, and instructions.
//
cloneInstDecorationsAndChildren(
&cloneEnv,
builder->getModule(),
oldFunc,
newFunc);

// Now that the main body of existing isntructions have
// Now that the main body of existing instructions have
// been cloned into the new function, we can go ahead
// and insert all the parameters and body instructions
// we built up into the function at the right place.
Expand All @@ -1474,7 +1479,7 @@ struct SpecializationContext
// block (this was an invariant established before
// we decided to specialize).
//
auto newEntryBlock = newFunc->getFirstBlock();
auto newEntryBlock = as<IRBlock>(cloneEnv.mapOldValToNew[oldFunc->getFirstBlock()]);
SLANG_ASSERT(newEntryBlock);

// We expect every valid block to have at least one
Expand All @@ -1497,11 +1502,17 @@ struct SpecializationContext
// before the first ordinary instruction (but will come
// *after* the parameters by the order of these two loops).
//
for (auto newBodyInst : newBodyInsts)
for (auto newBodyInst = tempHeaderBlock->getFirstChild(); newBodyInst;)
{
auto next = newBodyInst->next;
newBodyInst->insertBefore(newFirstOrdinary);
newBodyInst = next;
}

// After moving all param and existential insts in tempHeaderBlock
// it should be empty now and we can remove it.
tempHeaderBlock->removeAndDeallocate();

// After all this work we have a valid `newFunc` that has been
// specialized to match the types at the call site.
//
Expand Down
42 changes: 42 additions & 0 deletions tests/language-feature/generics/generic-interface-2.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type

interface IFoo<let n: uint>
{
static int sum(int arr[n]);
}

struct MyType<let n:uint> : IFoo<n>
{
static int sum(int arr[n])
{
int rs = 0;
for (int i = 0; i < n; i++)
rs += arr[i];
return rs;
}
}

int test<let n:uint>(IFoo<n> foo)
{
int arr[n];
for (int i =0; i < n; i++)
arr[i] = i;
return foo.sum(arr);
}

//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<int> outputBuffer;

[numthreads(1, 1, 1)]
void computeMain(int3 dispatchThreadID : SV_DispatchThreadID)
{
MyType<3> t;
test(t);
// CHECK: 3
outputBuffer[0] = test(t);

int arr[3] = {1,2,3};
// CHECK: 3
outputBuffer[1] = MyType<3>.sum(arr);
}

0 comments on commit 4f7d1f4

Please sign in to comment.