Skip to content

Commit cef5067

Browse files
committed
[mlir][linalg] Simplify createWriteOrMaskedWrite (NFC)
This patch removes `inputVecSizesForLeadingDims` from the parameter list of `createWriteOrMaskedWrite`. That argument is unnecessary — vector sizes can be obtained from the `vecToStore` parameter. Since this doesn't change behavior or test results, it's marked as NFC. Additional cleanups: * Renamed `vectorToStore` to `vecToStore` for consistency and brevity. * Rewrote a conditional at the end of the function to use early exit, improving readability: ```cpp // BEFORE: if (maskingRequired) { Value maskForWrite = ...; write = maskOperation(write, maskForWrite); } return write; // AFTER if (!maskingRequired) return write; Value maskForWrite = ...; return vector::maskOperation(builder, write, maskForWrite); ``` This change addresses a TODO from #141244.
1 parent b4b86a7 commit cef5067

File tree

1 file changed

+37
-76
lines changed

1 file changed

+37
-76
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 37 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1606,63 +1606,49 @@ static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
16061606
/// Creates an optionally masked TransferWriteOp
16071607
///
16081608
/// Generates the following operation:
1609-
/// %res = vector.transfer_write %vectorToStore into %dest
1609+
/// %res = vector.transfer_write %vecToStore into %dest
16101610
///
1611-
/// If the leading N dimensions of the vector to store do not match
1612-
/// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
1613-
/// masking is applied to ensure correctness:
1611+
/// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
16141612
///
1615-
/// %mask = vector.create_mask(%destShape) : %vectorToStoreShape
1613+
/// %mask = vector.create_mask(%destShape) : %vecToStoreShape
16161614
/// %res = vector.mask %mask {
1617-
/// vector.transfer_write %vectorToStore into %dest
1615+
/// vector.transfer_write %vecToStore into %dest
16181616
/// }
16191617
///
1620-
/// The mask shape is identical to `vectorToStore` (with the element type ==
1618+
/// The mask shape is identical to `vecToStore` (with the element type ==
16211619
/// i1), and the mask values are based on the shape of the `dest` tensor.
16221620
///
16231621
/// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
16241622
/// is used instead of masking:
16251623
///
1626-
/// %write = vector.transfer_write %vectorToStore into %dest
1624+
/// %write = vector.transfer_write %vecToStore into %dest
16271625
/// in_bounds_flags = (...)
16281626
/// %res = vector.transfer_write %input into %dest
16291627
/// {in_bounds = in_bounds_flags}
16301628
///
1631-
/// `writeIndices` specifies the offsets to use. If empty, all indices are set
1632-
/// to 0.
1633-
///
1634-
/// NOTE: When N < rank(vectorToStore), the missing vector sizes are taken from
1635-
/// `valueToStore`.
1636-
/// TODO: `inputVecSizesForLeadingDims` should not be required - these sizes are
1637-
/// already provided in `vectorToStore`.
1629+
/// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
1630+
/// are set to 0.
16381631
static Operation *
1639-
createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
1640-
Value dest,
1641-
ArrayRef<int64_t> inputVecSizesForLeadingDims,
1642-
SmallVector<Value> writeIndices = {},
1632+
createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
1633+
Value dest, SmallVector<Value> writeIndices = {},
16431634
bool useInBoundsInsteadOfMasking = false) {
16441635

16451636
ShapedType destType = cast<ShapedType>(dest.getType());
16461637
int64_t destRank = destType.getRank();
16471638
auto destShape = destType.getShape();
16481639

1649-
VectorType vecToStoreType = cast<VectorType>(vectorToStore.getType());
1640+
VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
16501641
int64_t vecToStoreRank = vecToStoreType.getRank();
16511642
auto vecToStoreShape = vecToStoreType.getShape();
16521643

16531644
// Compute the in_bounds attribute
16541645
SmallVector<bool> inBoundsVal(vecToStoreRank, true);
16551646
if (useInBoundsInsteadOfMasking) {
1656-
// In this case, assume that all the required vector sizes have been
1657-
// provided.
1658-
assert(inputVecSizesForLeadingDims.size() ==
1659-
static_cast<size_t>(vecToStoreType.getRank()) &&
1660-
"Insufficient number of input vector sizes!");
16611647
// Update the inBounds attribute.
16621648
// FIXME: This computation is too weak - it ignores the write indices.
16631649
for (unsigned i = 0; i < vecToStoreRank; i++)
16641650
inBoundsVal[i] =
1665-
(destShape[i] >= inputVecSizesForLeadingDims[i]) &&
1651+
(destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
16661652
!ShapedType::isDynamic(destShape[destRank - vecToStoreRank + i]);
16671653
}
16681654

@@ -1678,7 +1664,7 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
16781664
// Generate the xfer_write Op
16791665
Operation *write =
16801666
builder.create<vector::TransferWriteOp>(loc,
1681-
/*vector=*/vectorToStore,
1667+
/*vector=*/vecToStore,
16821668
/*source=*/dest,
16831669
/*indices=*/writeIndices,
16841670
/*inBounds=*/inBoundsVal);
@@ -1687,46 +1673,25 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
16871673
if (useInBoundsInsteadOfMasking)
16881674
return write;
16891675

1690-
assert(llvm::none_of(
1691-
destShape.drop_front(inputVecSizesForLeadingDims.size()),
1692-
[](int64_t size) { return size == ShapedType::kDynamic; }) &&
1693-
"Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
1694-
1695-
// Check if masking is needed.
1696-
bool needMaskForWrite =
1697-
!llvm::equal(inputVecSizesForLeadingDims,
1698-
destShape.take_front(destRank - vecToStoreRank +
1699-
inputVecSizesForLeadingDims.size()));
1700-
1701-
// If masking is needed, generate the mask and mask the operation.
1702-
if (needMaskForWrite) {
1703-
// Get the mask shape + type. Missing mask dimensions are taken from
1704-
// `vectorToStore`.
1705-
SmallVector<int64_t> writeMaskShape;
1706-
writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
1707-
inputVecSizesForLeadingDims.end());
1708-
if (vecToStoreRank >
1709-
static_cast<int64_t>(inputVecSizesForLeadingDims.size()))
1710-
writeMaskShape.append(vecToStoreShape.begin() +
1711-
inputVecSizesForLeadingDims.size(),
1712-
vecToStoreShape.end());
1713-
auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
1714-
1715-
SmallVector<OpFoldResult> destSizes =
1716-
tensor::getMixedSizes(builder, loc, dest);
1717-
SmallVector<OpFoldResult> maskSizes(destSizes.end() - writeMaskShape.size(),
1718-
destSizes.end());
1719-
1720-
if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
1721-
writeMaskShape))
1722-
return write;
1723-
1724-
Value maskForWrite = builder.createOrFold<vector::CreateMaskOp>(
1725-
loc, writeMaskType, maskSizes);
1726-
write = mlir::vector::maskOperation(builder, write, maskForWrite);
1727-
}
1676+
// Check if masking is needed. If not, exit.
1677+
if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
1678+
return write;
1679+
1680+
// Compute the mask and mask the write Op.
1681+
auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
1682+
1683+
SmallVector<OpFoldResult> destSizes =
1684+
tensor::getMixedSizes(builder, loc, dest);
1685+
SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
1686+
destSizes.end());
1687+
1688+
if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
1689+
vecToStoreShape))
1690+
return write;
17281691

1729-
return write;
1692+
Value maskForWrite =
1693+
builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
1694+
return mlir::vector::maskOperation(builder, write, maskForWrite);
17301695
}
17311696

17321697
/// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
@@ -1826,9 +1791,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
18261791
Value dest = rewriter.create<tensor::EmptyOp>(
18271792
loc, reifiedReturnShapes[0],
18281793
transposeOp.getResult().getType().getElementType());
1829-
Operation *write = createWriteOrMaskedWrite(
1830-
rewriter, loc, transposeOp.getResult(), dest,
1831-
/*inputVecSizesForLeadingDims=*/inputVectorSizes);
1794+
Operation *write =
1795+
createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest);
18321796
newResults.push_back(write->getResult(0));
18331797
return success();
18341798
}
@@ -1966,7 +1930,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19661930
shapeCastOp.getResult().getType().getElementType());
19671931
Operation *write = createWriteOrMaskedWrite(
19681932
rewriter, loc, shapeCastOp.getResult(), dest,
1969-
/*inputVecSizesForLeadingDims=*/writeVectorSizes,
19701933
/*writeIndices=*/{}, useInBoundsInsteadOfMasking);
19711934
newResults.push_back(write->getResult(0));
19721935
return success();
@@ -1999,9 +1962,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
19991962
// Create Xfer write Op
20001963
Value dest = rewriter.create<tensor::EmptyOp>(
20011964
loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
2002-
Operation *write = createWriteOrMaskedWrite(
2003-
rewriter, loc, maskedRead, dest,
2004-
/*inputVecSizesForLeadingDims=*/inputVectorSizes);
1965+
Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
20051966
newResults.push_back(write->getResult(0));
20061967
return success();
20071968
}
@@ -3043,9 +3004,9 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
30433004
// Create write
30443005
auto writeIndices =
30453006
getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
3046-
Operation *write = createWriteOrMaskedWrite(
3047-
rewriter, loc, read, sliceOp.getDest(), vecType.getShape(), writeIndices,
3048-
/*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
3007+
Operation *write =
3008+
createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(),
3009+
writeIndices, inputVectorSizes.empty());
30493010

30503011
// 4. Finalize
30513012
newResults.push_back(write->getResult(0));

0 commit comments

Comments
 (0)