Skip to content

Commit 5dfb7bb

Browse files
authored
[mlir][linalg] Simplify createWriteOrMaskedWrite (NFC) (#141567)
This patch removes `inputVecSizesForLeadingDims` from the parameter list of `createWriteOrMaskedWrite`. That argument is unnecessary - vector sizes can be obtained from the `vecToStore` parameter. Since this doesn't change behavior or test results, it's marked as NFC. Additional cleanups: * Renamed `vectorToStore` to `vecToStore` for consistency and brevity. * Rewrote a conditional at the end of the function to use early exit, improving readability: ```cpp // BEFORE: if (maskingRequired) { Value maskForWrite = ...; write = maskOperation(write, maskForWrite); } return write; // AFTER if (!maskingRequired) return write; Value maskForWrite = ...; return vector::maskOperation(builder, write, maskForWrite); ```
1 parent 7119a0f commit 5dfb7bb

File tree

1 file changed

+37
-76
lines changed

1 file changed

+37
-76
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 37 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1605,63 +1605,49 @@ static bool isMaskTriviallyFoldable(SmallVector &maskSizes,
16051605
/// Creates an optionally masked TransferWriteOp
16061606
///
16071607
/// Generates the following operation:
1608-
/// %res = vector.transfer_write %vectorToStore into %dest
1608+
/// %res = vector.transfer_write %vecToStore into %dest
16091609
///
1610-
/// If the leading N dimensions of the vector to store do not match
1611-
/// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
1612-
/// masking is applied to ensure correctness:
1610+
/// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
16131611
///
1614-
/// %mask = vector.create_mask(%destShape) : %vectorToStoreShape
1612+
/// %mask = vector.create_mask(%destShape) : %vecToStoreShape
16151613
/// %res = vector.mask %mask {
1616-
/// vector.transfer_write %vectorToStore into %dest
1614+
/// vector.transfer_write %vecToStore into %dest
16171615
/// }
16181616
///
1619-
/// The mask shape is identical to `vectorToStore` (with the element type ==
1617+
/// The mask shape is identical to `vecToStore` (with the element type ==
16201618
/// i1), and the mask values are based on the shape of the `dest` tensor.
16211619
///
16221620
/// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
16231621
/// is used instead of masking:
16241622
///
1625-
/// %write = vector.transfer_write %vectorToStore into %dest
1623+
/// %write = vector.transfer_write %vecToStore into %dest
16261624
/// in_bounds_flags = (...)
16271625
/// %res = vector.transfer_write %input into %dest
16281626
/// {in_bounds = in_bounds_flags}
16291627
///
1630-
/// `writeIndices` specifies the offsets to use. If empty, all indices are set
1631-
/// to 0.
1632-
///
1633-
/// NOTE: When N < rank(vectorToStore), the missing vector sizes are taken from
1634-
/// `valueToStore`.
1635-
/// TODO: `inputVecSizesForLeadingDims` should not be required - these sizes are
1636-
/// already provided in `vectorToStore`.
1628+
/// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
1629+
/// are set to 0.
16371630
static Operation *
1638-
createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
1639-
Value dest,
1640-
ArrayRef<int64_t> inputVecSizesForLeadingDims,
1641-
SmallVector writeIndices = {},
1631+
createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
1632+
Value dest, SmallVector writeIndices = {},
16421633
bool useInBoundsInsteadOfMasking = false) {
16431634

16441635
ShapedType destType = cast(dest.getType());
16451636
int64_t destRank = destType.getRank();
16461637
auto destShape = destType.getShape();
16471638

1648-
VectorType vecToStoreType = cast(vectorToStore.getType());
1639+
VectorType vecToStoreType = cast(vecToStore.getType());
16491640
int64_t vecToStoreRank = vecToStoreType.getRank();
16501641
auto vecToStoreShape = vecToStoreType.getShape();
16511642

16521643
// Compute the in_bounds attribute
16531644
SmallVector<bool> inBoundsVal(vecToStoreRank, true);
16541645
if (useInBoundsInsteadOfMasking) {
1655-
// In this case, assume that all the required vector sizes have been
1656-
// provided.
1657-
assert(inputVecSizesForLeadingDims.size() ==
1658-
static_cast<size_t>(vecToStoreType.getRank()) &&
1659-
"Insufficient number of input vector sizes!");
16601646
// Update the inBounds attribute.
16611647
// FIXME: This computation is too weak - it ignores the write indices.
16621648
for (unsigned i = 0; i < vecToStoreRank; i++)
16631649
inBoundsVal[i] =
1664-
(destShape[i] >= inputVecSizesForLeadingDims[i]) &&
1650+
(destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
16651651
!ShapedType::isDynamic(destShape[destRank - vecToStoreRank + i]);
16661652
}
16671653

@@ -1677,7 +1663,7 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
16771663
// Generate the xfer_write Op
16781664
Operation *write =
16791665
builder.create(loc,
1680-
/*vector=*/vectorToStore,
1666+
/*vector=*/vecToStore,
16811667
/*source=*/dest,
16821668
/*indices=*/writeIndices,
16831669
/*inBounds=*/inBoundsVal);
@@ -1686,46 +1672,25 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
16861672
if (useInBoundsInsteadOfMasking)
16871673
return write;
16881674

1689-
assert(llvm::none_of(
1690-
destShape.drop_front(inputVecSizesForLeadingDims.size()),
1691-
[](int64_t size) { return size == ShapedType::kDynamic; }) &&
1692-
"Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
1693-
1694-
// Check if masking is needed.
1695-
bool needMaskForWrite =
1696-
!llvm::equal(inputVecSizesForLeadingDims,
1697-
destShape.take_front(destRank - vecToStoreRank +
1698-
inputVecSizesForLeadingDims.size()));
1699-
1700-
// If masking is needed, generate the mask and mask the operation.
1701-
if (needMaskForWrite) {
1702-
// Get the mask shape + type. Missing mask dimensions are taken from
1703-
// `vectorToStore`.
1704-
SmallVector<int64_t> writeMaskShape;
1705-
writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
1706-
inputVecSizesForLeadingDims.end());
1707-
if (vecToStoreRank >
1708-
static_cast<int64_t>(inputVecSizesForLeadingDims.size()))
1709-
writeMaskShape.append(vecToStoreShape.begin() +
1710-
inputVecSizesForLeadingDims.size(),
1711-
vecToStoreShape.end());
1712-
auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
1713-
1714-
SmallVector destSizes =
1715-
tensor::getMixedSizes(builder, loc, dest);
1716-
SmallVector maskSizes(destSizes.end() - writeMaskShape.size(),
1717-
destSizes.end());
1718-
1719-
if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
1720-
writeMaskShape))
1721-
return write;
1722-
1723-
Value maskForWrite = builder.createOrFold(
1724-
loc, writeMaskType, maskSizes);
1725-
write = mlir::vector::maskOperation(builder, write, maskForWrite);
1726-
}
1675+
// Check if masking is needed. If not, exit.
1676+
if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
1677+
return write;
1678+
1679+
// Compute the mask and mask the write Op.
1680+
auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
1681+
1682+
SmallVector destSizes =
1683+
tensor::getMixedSizes(builder, loc, dest);
1684+
SmallVector maskSizes(destSizes.end() - vecToStoreRank,
1685+
destSizes.end());
1686+
1687+
if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
1688+
vecToStoreShape))
1689+
return write;
17271690

1728-
return write;
1691+
Value maskForWrite =
1692+
builder.createOrFold(loc, writeMaskType, maskSizes);
1693+
return mlir::vector::maskOperation(builder, write, maskForWrite);
17291694
}
17301695

17311696
/// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
@@ -1825,9 +1790,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
18251790
Value dest = rewriter.create(
18261791
loc, reifiedReturnShapes[0],
18271792
transposeOp.getResult().getType().getElementType());
1828-
Operation *write = createWriteOrMaskedWrite(
1829-
rewriter, loc, transposeOp.getResult(), dest,
1830-
/*inputVecSizesForLeadingDims=*/inputVectorSizes);
1793+
Operation *write =
1794+
createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest);
18311795
newResults.push_back(write->getResult(0));
18321796
return success();
18331797
}
@@ -1965,7 +1929,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19651929
shapeCastOp.getResult().getType().getElementType());
19661930
Operation *write = createWriteOrMaskedWrite(
19671931
rewriter, loc, shapeCastOp.getResult(), dest,
1968-
/*inputVecSizesForLeadingDims=*/writeVectorSizes,
19691932
/*writeIndices=*/{}, useInBoundsInsteadOfMasking);
19701933
newResults.push_back(write->getResult(0));
19711934
return success();
@@ -1998,9 +1961,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
19981961
// Create Xfer write Op
19991962
Value dest = rewriter.create(
20001963
loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
2001-
Operation *write = createWriteOrMaskedWrite(
2002-
rewriter, loc, maskedRead, dest,
2003-
/*inputVecSizesForLeadingDims=*/inputVectorSizes);
1964+
Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
20041965
newResults.push_back(write->getResult(0));
20051966
return success();
20061967
}
@@ -3040,9 +3001,9 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
30403001
// Create write
30413002
auto writeIndices =
30423003
getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
3043-
Operation *write = createWriteOrMaskedWrite(
3044-
rewriter, loc, read, sliceOp.getDest(), vecType.getShape(), writeIndices,
3045-
/*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
3004+
Operation *write =
3005+
createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(),
3006+
writeIndices, inputVectorSizes.empty());
30463007

30473008
// 4. Finalize
30483009
newResults.push_back(write->getResult(0));

0 commit comments

Comments
 (0)