
Commit ca24a26

[mlir][linalg] Simplify createWriteOrMaskedWrite (NFC)
This patch removes `inputVecSizesForLeadingDims` from the parameter list of
`createWriteOrMaskedWrite`. That argument is unnecessary: the vector sizes can
be obtained from the `vecToStore` parameter. Since this doesn't change behavior
or test results, the patch is marked as NFC.

Additional cleanups:

* Renamed `vectorToStore` to `vecToStore` for consistency and brevity.
* Rewrote the conditional at the end of the function to use an early exit,
  improving readability:

```cpp
// BEFORE:
if (maskingRequired) {
  Value maskForWrite = ...;
  write = vector::maskOperation(builder, write, maskForWrite);
}
return write;

// AFTER:
if (!maskingRequired)
  return write;

Value maskForWrite = ...;
return vector::maskOperation(builder, write, maskForWrite);
```

This change addresses a TODO from #141244.
1 parent 82cc2fe commit ca24a26
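
For readers skimming the diff, here is a compact, consolidated sketch of the new masking path. It is illustrative only: the helper name `maskWriteFromVecToStore` is invented for this note (the real logic lives inline in `createWriteOrMaskedWrite`), it assumes the headers and using-directives of `Vectorization.cpp`, and it omits the `isMaskTriviallyFoldable` early exit for brevity. The point it shows is that both the masking decision and the mask shape are now derived from the type of `vecToStore`, so callers no longer pass vector sizes separately.

```cpp
// Illustrative sketch only -- not a verbatim excerpt of the patch. Assumes the
// usual Vectorization.cpp environment (MLIR headers, `using namespace mlir;`).
// `write` is the freshly created (unmasked) vector.transfer_write op.
static Operation *maskWriteFromVecToStore(OpBuilder &builder, Location loc,
                                          Value vecToStore, Value dest,
                                          Operation *write) {
  auto vecToStoreType = cast<VectorType>(vecToStore.getType());
  ArrayRef<int64_t> vecToStoreShape = vecToStoreType.getShape();
  ArrayRef<int64_t> destShape = cast<ShapedType>(dest.getType()).getShape();
  int64_t vecToStoreRank = vecToStoreType.getRank();

  // No mask needed: the trailing dims of `dest` are static and already equal
  // to the shape of the vector being stored.
  if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
    return write;

  // Mask shape == shape of `vecToStore` (with i1 elements); the mask values
  // come from the (possibly dynamic) trailing sizes of `dest`.
  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
  SmallVector<OpFoldResult> destSizes =
      tensor::getMixedSizes(builder, loc, dest);
  SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
                                      destSizes.end());
  Value maskForWrite =
      builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
  return vector::maskOperation(builder, write, maskForWrite);
}
```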

File tree

1 file changed: +40 -78 lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 40 additions & 78 deletions
```diff
@@ -1606,61 +1606,46 @@ static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
 /// Creates an optionally masked TransferWriteOp
 ///
 /// Generates the following operation:
-///   %res = vector.transfer_write %vectorToStore into %dest
+///   %res = vector.transfer_write %vecToStore into %dest
 ///
-/// If the leading N dimensions of the vector to store do not match
-/// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
-/// masking is applied to ensure correctness:
+/// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
 ///
-///   %mask = vector.create_mask(%destShape) : %vectorToStoreShape
+///   %mask = vector.create_mask(%destShape) : %vecToStoreShape
 ///   %res = vector.mask %mask {
-///     vector.transfer_write %vectorToStore into %dest
+///     vector.transfer_write %vecToStore into %dest
 ///   }
 ///
-/// The mask shape is identical to `vectorToStore` (with the element type ==
+/// The mask shape is identical to `vecToStore` (with the element type ==
 /// i1), and the mask values are based on the shape of the `dest` tensor.
 ///
 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
 /// is used instead of masking:
 ///
-///   %write = vector.transfer_write %vectorToStore into %dest
+///   %write = vector.transfer_write %vecToStore into %dest
 ///       in_bounds_flags = (...)
 ///   %res = vector.transfer_write %input into %dest
 ///       {in_bounds = in_bounds_flags}
 ///
-/// `writeIndices` specifies the offsets to use. If empty, all indices are set
-/// to 0.
-///
-/// NOTE: When N < rank(vectorToStore), the missing vector sizes are taken from
-/// `valueToStore`.
-/// TODO: `inputVecSizesForLeadingDims` should not be required - these sizes are
-/// already provided in `vectorToStore`.
+/// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
+/// are set to 0.
 static Operation *
-createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
-                         Value dest,
-                         ArrayRef<int64_t> inputVecSizesForLeadingDims,
-                         SmallVector<Value> writeIndices = {},
+createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
+                         Value dest, SmallVector<Value> writeIndices = {},
                          bool useInBoundsInsteadOfMasking = false) {
 
   ShapedType destType = cast<ShapedType>(dest.getType());
   int64_t destRank = destType.getRank();
   auto destShape = destType.getShape();
 
-  VectorType vecToStoreType = cast<VectorType>(vectorToStore.getType());
+  VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
   int64_t vecToStoreRank = vecToStoreType.getRank();
   auto vecToStoreShape = vecToStoreType.getShape();
 
   // Compute the in_bounds attribute
   SmallVector<bool> inBoundsVal(vecToStoreRank, true);
   if (useInBoundsInsteadOfMasking) {
-    // In this case, assume that all the required vector sizes have been
-    // provided.
-    assert(inputVecSizesForLeadingDims.size() ==
-               static_cast<size_t>(vecToStoreType.getRank()) &&
-           "Insufficient number of input vector sizes!");
-    // Update the inBounds attribute.
     for (unsigned i = 0; i < destRank; i++)
-      inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
+      inBoundsVal[i] = (destShape[i] == vecToStoreShape[i]) &&
                        !ShapedType::isDynamic(destShape[i]);
   }
 
@@ -1676,7 +1661,7 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
   // Generate the xfer_write Op
   Operation *write =
       builder.create<vector::TransferWriteOp>(loc,
-                                              /*vector=*/vectorToStore,
+                                              /*vector=*/vecToStore,
                                               /*source=*/dest,
                                               /*indices=*/writeIndices,
                                               /*inBounds=*/inBoundsVal);
@@ -1685,46 +1670,25 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
   if (useInBoundsInsteadOfMasking)
     return write;
 
-  assert(llvm::none_of(
-             destShape.drop_front(inputVecSizesForLeadingDims.size()),
-             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
-         "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
-
-  // Check if masking is needed.
-  bool needMaskForWrite =
-      !llvm::equal(inputVecSizesForLeadingDims,
-                   destShape.take_front(destRank - vecToStoreRank +
-                                        inputVecSizesForLeadingDims.size()));
-
-  // If masking is needed, generate the mask and mask the operation.
-  if (needMaskForWrite) {
-    // Get the mask shape + type. Missing mask dimensions are taken from
-    // `vectorToStore`.
-    SmallVector<int64_t> writeMaskShape;
-    writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
-                          inputVecSizesForLeadingDims.end());
-    if (vecToStoreRank >
-        static_cast<int64_t>(inputVecSizesForLeadingDims.size()))
-      writeMaskShape.append(vecToStoreShape.begin() +
-                                inputVecSizesForLeadingDims.size(),
-                            vecToStoreShape.end());
-    auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
-
-    SmallVector<OpFoldResult> destSizes =
-        tensor::getMixedSizes(builder, loc, dest);
-    SmallVector<OpFoldResult> maskSizes(destSizes.end() - writeMaskShape.size(),
-                                        destSizes.end());
-
-    if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
-                                writeMaskShape))
-      return write;
-
-    Value maskForWrite = builder.createOrFold<vector::CreateMaskOp>(
-        loc, writeMaskType, maskSizes);
-    write = mlir::vector::maskOperation(builder, write, maskForWrite);
-  }
+  // Check if masking is needed. If not, exit.
+  if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
+    return write;
+
+  // Compute the mask and mask the write Op.
+  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
+
+  SmallVector<OpFoldResult> destSizes =
+      tensor::getMixedSizes(builder, loc, dest);
+  SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
+                                      destSizes.end());
+
+  if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
+                              vecToStoreShape))
+    return write;
 
-  return write;
+  Value maskForWrite =
+      builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
+  return mlir::vector::maskOperation(builder, write, maskForWrite);
 }
 
 /// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
@@ -1824,10 +1788,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0],
       transposeOp.getResult().getType().getElementType());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, transposeOp.getResult(), dest,
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes, /*writeIndices=*/{},
-      /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest,
+                               /*writeIndices=*/{},
+                               /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1965,7 +1929,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
       shapeCastOp.getResult().getType().getElementType());
   Operation *write = createWriteOrMaskedWrite(
       rewriter, loc, shapeCastOp.getResult(), dest,
-      /*inputVecSizesForLeadingDims=*/writeVectorSizes,
       /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
   newResults.push_back(write->getResult(0));
   return success();
@@ -1998,10 +1961,9 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   // Create Xfer write Op
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, maskedRead, dest,
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes, {},
-      /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest, {},
+                               /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -3057,8 +3019,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   // Create write
   auto writeIndices =
       getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, read, sliceOp.getDest(), vecType.getShape(), writeIndices);
+  Operation *write = createWriteOrMaskedWrite(rewriter, loc, read,
+                                              sliceOp.getDest(), writeIndices);
 
   // 4. Finalize
   newResults.push_back(write->getResult(0));
```
