From 4f7d1f44a4b2a5eab2e2dec1edf3a156da78aae3 Mon Sep 17 00:00:00 2001 From: Yong He Date: Sun, 11 Feb 2024 01:05:37 -0800 Subject: [PATCH] Fix type checking around generic array types. (#3568) --- source/slang/slang-ast-builder.cpp | 4 ++ source/slang/slang-ir-specialize.cpp | 57 +++++++++++-------- .../generics/generic-interface-2.slang | 42 ++++++++++++++ 3 files changed, 80 insertions(+), 23 deletions(-) create mode 100644 tests/language-feature/generics/generic-interface-2.slang diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index ce13ab650c..0c30366e8a 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -338,6 +338,10 @@ ArrayExpressionType* ASTBuilder::getArrayType(Type* elementType, IntVal* element { elementCount = getIntVal(getIntType(), elementCountConstantInt->getValue()); } + else + { + elementCount = getTypeCastIntVal(getIntType(), elementCount); + } } Val* args[] = {elementType, elementCount}; return as(getSpecializedBuiltinType(makeArrayView(args), "ArrayExpressionType")); diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 60001661c6..9de574a9b1 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -1323,6 +1323,7 @@ struct SpecializationContext // their replacements. // IRCloneEnv cloneEnv; + cloneEnv.squashChildrenMapping = true; // We also need some IR building state, for any // new instructions we will emit. @@ -1330,6 +1331,16 @@ struct SpecializationContext IRBuilder builderStorage(module); auto builder = &builderStorage; + // To get started, we will create the skeleton of the new + // specialized function, so newly created insts + // will be placed in a proper parent. + // + + IRFunc* newFunc = builder->createFunc(); + + builder->setInsertInto(newFunc); + IRBlock* tempHeaderBlock = builder->emitBlock(); + // We will start out by determining what the parameters // of the specialized function should be, based on // the parameters of the original, and the concrete @@ -1343,7 +1354,6 @@ struct SpecializationContext // block, or even a function, to insert them into. // List newParams; - List newBodyInsts; UInt argCounter = 0; for (auto oldParam : oldFunc->getParams()) { @@ -1387,7 +1397,6 @@ struct SpecializationContext // correct existential type, and stores the right witness table). // auto newMakeExistential = builder->emitMakeExistential(oldParam->getFullType(), newParam, witnessTable); - newBodyInsts.add(newMakeExistential); replacementVal = newMakeExistential; } else if (auto oldWrapExistential = as(arg)) @@ -1412,7 +1421,6 @@ struct SpecializationContext newParam, oldWrapExistential->getSlotOperandCount(), oldWrapExistential->getSlotOperands()); - newBodyInsts.add(newWrapExistential); replacementVal = newWrapExistential; } else @@ -1433,8 +1441,20 @@ struct SpecializationContext cloneEnv.mapOldValToNew.add(oldParam, replacementVal); } - // Next we will create the skeleton of the new - // specialized function, including its type. + // The above steps have accomplished the "first phase" + // of cloning the function (since `IRFunc`s have no + // operands). + // + // We can now use the shared IR cloning infrastructure + // to perform the second phase of cloning, which will recursively + // clone any nested decorations, blocks, and instructions. + // + cloneInstDecorationsAndChildren( + &cloneEnv, + builder->getModule(), + oldFunc, + newFunc); + // // In order to construct the type of the new function, we // need to extract the types of all its parameters. @@ -1448,24 +1468,9 @@ struct SpecializationContext newParamTypes.getCount(), newParamTypes.getBuffer(), oldFunc->getResultType()); - IRFunc* newFunc = builder->createFunc(); newFunc->setFullType(newFuncType); - // The above steps have accomplished the "first phase" - // of cloning the function (since `IRFunc`s have no - // operands). - // - // We can now use the shared IR cloning infrastructure - // to perform the second phase of cloning, which will recursively - // clone any nested decorations, blocks, and instructions. - // - cloneInstDecorationsAndChildren( - &cloneEnv, - builder->getModule(), - oldFunc, - newFunc); - - // Now that the main body of existing isntructions have + // Now that the main body of existing instructions have // been cloned into the new function, we can go ahead // and insert all the parameters and body instructions // we built up into the function at the right place. @@ -1474,7 +1479,7 @@ struct SpecializationContext // block (this was an invariant established before // we decided to specialize). // - auto newEntryBlock = newFunc->getFirstBlock(); + auto newEntryBlock = as(cloneEnv.mapOldValToNew[oldFunc->getFirstBlock()]); SLANG_ASSERT(newEntryBlock); // We expect every valid block to have at least one @@ -1497,11 +1502,17 @@ struct SpecializationContext // before the first ordinary instruction (but will come // *after* the parameters by the order of these two loops). // - for (auto newBodyInst : newBodyInsts) + for (auto newBodyInst = tempHeaderBlock->getFirstChild(); newBodyInst;) { + auto next = newBodyInst->next; newBodyInst->insertBefore(newFirstOrdinary); + newBodyInst = next; } + // After moving all param and existential insts in tempHeaderBlock + // it should be empty now and we can remove it. + tempHeaderBlock->removeAndDeallocate(); + // After all this work we have a valid `newFunc` that has been // specialized to match the types at the call site. // diff --git a/tests/language-feature/generics/generic-interface-2.slang b/tests/language-feature/generics/generic-interface-2.slang new file mode 100644 index 0000000000..9a44f679ca --- /dev/null +++ b/tests/language-feature/generics/generic-interface-2.slang @@ -0,0 +1,42 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type + +interface IFoo +{ + static int sum(int arr[n]); +} + +struct MyType : IFoo +{ + static int sum(int arr[n]) + { + int rs = 0; + for (int i = 0; i < n; i++) + rs += arr[i]; + return rs; + } +} + +int test(IFoo foo) +{ + int arr[n]; + for (int i =0; i < n; i++) + arr[i] = i; + return foo.sum(arr); +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain(int3 dispatchThreadID : SV_DispatchThreadID) +{ + MyType<3> t; + test(t); + // CHECK: 3 + outputBuffer[0] = test(t); + + int arr[3] = {1,2,3}; + // CHECK: 3 + outputBuffer[1] = MyType<3>.sum(arr); +}