[mlir][linalg] Simplify createWriteOrMaskedWrite (NFC) #141567

Merged · 1 commit · Jun 8, 2025
113 changes: 37 additions & 76 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1606,63 +1606,49 @@ static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
/// Creates an optionally masked TransferWriteOp
///
/// Generates the following operation:
-///    %res = vector.transfer_write %vectorToStore into %dest
+///    %res = vector.transfer_write %vecToStore into %dest
///
-/// If the leading N dimensions of the vector to store do not match
-/// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
-/// masking is applied to ensure correctness:
+/// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
///
-///    %mask = vector.create_mask(%destShape) : %vectorToStoreShape
+///    %mask = vector.create_mask(%destShape) : %vecToStoreShape
/// %res = vector.mask %mask {
-///      vector.transfer_write %vectorToStore into %dest
+///      vector.transfer_write %vecToStore into %dest
/// }
///
-/// The mask shape is identical to `vectorToStore` (with the element type ==
+/// The mask shape is identical to `vecToStore` (with the element type ==
/// i1), and the mask values are based on the shape of the `dest` tensor.
///
/// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
/// is used instead of masking:
///
-///    %write = vector.transfer_write %vectorToStore into %dest
+///    %write = vector.transfer_write %vecToStore into %dest
/// in_bounds_flags = (...)
/// %res = vector.transfer_write %input into %dest
/// {in_bounds = in_bounds_flags}
///
-/// `writeIndices` specifies the offsets to use. If empty, all indices are set
-/// to 0.
-///
-/// NOTE: When N < rank(vectorToStore), the missing vector sizes are taken from
-/// `valueToStore`.
-/// TODO: `inputVecSizesForLeadingDims` should not be required - these sizes are
-/// already provided in `vectorToStore`.
+/// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
+/// are set to 0.
static Operation *
-createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
-                         Value dest,
-                         ArrayRef<int64_t> inputVecSizesForLeadingDims,
-                         SmallVector<Value> writeIndices = {},
+createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
+                         Value dest, SmallVector<Value> writeIndices = {},
bool useInBoundsInsteadOfMasking = false) {

  ShapedType destType = cast<ShapedType>(dest.getType());
int64_t destRank = destType.getRank();
auto destShape = destType.getShape();

-  VectorType vecToStoreType = cast<VectorType>(vectorToStore.getType());
+  VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
int64_t vecToStoreRank = vecToStoreType.getRank();
auto vecToStoreShape = vecToStoreType.getShape();

// Compute the in_bounds attribute
  SmallVector<bool> inBoundsVal(vecToStoreRank, true);
if (useInBoundsInsteadOfMasking) {
-    // In this case, assume that all the required vector sizes have been
-    // provided.
-    assert(inputVecSizesForLeadingDims.size() ==
-               static_cast<size_t>(vecToStoreType.getRank()) &&
-           "Insufficient number of input vector sizes!");
// Update the inBounds attribute.
// FIXME: This computation is too weak - it ignores the write indices.
for (unsigned i = 0; i < vecToStoreRank; i++)
inBoundsVal[i] =
-          (destShape[i] >= inputVecSizesForLeadingDims[i]) &&
-          !ShapedType::isDynamic(destShape[i]);
+          (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
+          !ShapedType::isDynamic(destShape[destRank - vecToStoreRank + i]);
}

@@ -1678,7 +1664,7 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
// Generate the xfer_write Op
Operation *write =
      builder.create<vector::TransferWriteOp>(loc,
-                                             /*vector=*/vectorToStore,
+                                             /*vector=*/vecToStore,
/*source=*/dest,
/*indices=*/writeIndices,
/*inBounds=*/inBoundsVal);
@@ -1687,46 +1673,25 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
if (useInBoundsInsteadOfMasking)
return write;

-  assert(llvm::none_of(
-             destShape.drop_front(inputVecSizesForLeadingDims.size()),
-             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
-         "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");

-  // Check if masking is needed.
-  bool needMaskForWrite =
-      !llvm::equal(inputVecSizesForLeadingDims,
-                   destShape.take_front(destRank - vecToStoreRank +
-                                        inputVecSizesForLeadingDims.size()));

-  // If masking is needed, generate the mask and mask the operation.
-  if (needMaskForWrite) {
-    // Get the mask shape + type. Missing mask dimensions are taken from
-    // `vectorToStore`.
-    SmallVector<int64_t> writeMaskShape;
-    writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
-                          inputVecSizesForLeadingDims.end());
-    if (vecToStoreRank >
-        static_cast<int64_t>(inputVecSizesForLeadingDims.size()))
-      writeMaskShape.append(vecToStoreShape.begin() +
-                                inputVecSizesForLeadingDims.size(),
-                            vecToStoreShape.end());
-    auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
-
-    SmallVector<OpFoldResult> destSizes =
-        tensor::getMixedSizes(builder, loc, dest);
-    SmallVector<OpFoldResult> maskSizes(destSizes.end() - writeMaskShape.size(),
-                                        destSizes.end());
-
-    if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
-                                writeMaskShape))
-      return write;
-
-    Value maskForWrite = builder.createOrFold<vector::CreateMaskOp>(
-        loc, writeMaskType, maskSizes);
-    write = mlir::vector::maskOperation(builder, write, maskForWrite);
-  }
+  // Check if masking is needed. If not, exit.
+  if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
+    return write;

+  // Compute the mask and mask the write Op.
+  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());

+  SmallVector<OpFoldResult> destSizes =
+      tensor::getMixedSizes(builder, loc, dest);
+  SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
+                                      destSizes.end());

+  if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
+                              vecToStoreShape))
+    return write;

-  return write;
+  Value maskForWrite =
+      builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
+  return mlir::vector::maskOperation(builder, write, maskForWrite);
}
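
For illustration, a hand-written MLIR sketch (assumed value names and shapes — not taken from this PR or its tests) of the masked case the helper now produces: the mask *shape* comes from `vecToStore`, while the mask *sizes* come from the destination's extents.

```mlir
// Assumed: %vec : vector<4xf32>, %dest : tensor<?xf32>.
%c0 = arith.constant 0 : index
%dim = tensor.dim %dest, %c0 : tensor<?xf32>
%mask = vector.create_mask %dim : vector<4xi1>
%res = vector.mask %mask {
  vector.transfer_write %vec, %dest[%c0] : vector<4xf32>, tensor<?xf32>
} : vector<4xi1> -> tensor<?xf32>
```

When `shape(vecToStore)` equals the trailing dimensions of `shape(dest)`, the `llvm::equal` check above exits early and the plain `vector.transfer_write` is returned unmasked.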

/// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
@@ -1826,9 +1791,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
  Value dest = rewriter.create<tensor::EmptyOp>(
loc, reifiedReturnShapes[0],
transposeOp.getResult().getType().getElementType());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, transposeOp.getResult(), dest,
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes);
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest);
newResults.push_back(write->getResult(0));
return success();
}
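
As a reading aid (assumed shapes, not from the PR): with `inputVecSizesForLeadingDims` gone, the pack path lets the helper infer everything from the vector type, so a shape-matched write lowers to a plain unmasked transfer:

```mlir
// Assumed: %v : vector<4x8xf32>, %dest from tensor.empty() : tensor<4x8xf32>.
%c0 = arith.constant 0 : index
%res = vector.transfer_write %v, %dest[%c0, %c0] {in_bounds = [true, true]}
    : vector<4x8xf32>, tensor<4x8xf32>
```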
@@ -1966,7 +1930,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
shapeCastOp.getResult().getType().getElementType());
  Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, shapeCastOp.getResult(), dest,
-      /*inputVecSizesForLeadingDims=*/writeVectorSizes,
      /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
newResults.push_back(write->getResult(0));
return success();
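
The unpack path forwards `useInBoundsInsteadOfMasking`. In that mode the helper's per-dimension loop compares trailing extents — e.g. for `destShape = [5]` and `vecToStoreShape = [8]` it evaluates `destShape[0] >= vecToStoreShape[0]`, which is false — and records the result as `in_bounds` flags rather than materializing a mask. A hand-written sketch under those assumed shapes:

```mlir
// Assumed: %v : vector<8xf32>, %dest : tensor<5xf32>. Out-of-bounds lanes
// are simply not stored.
%c0 = arith.constant 0 : index
%res = vector.transfer_write %v, %dest[%c0] {in_bounds = [false]}
    : vector<8xf32>, tensor<5xf32>
```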
@@ -1999,9 +1962,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
// Create Xfer write Op
  Value dest = rewriter.create<tensor::EmptyOp>(
loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, maskedRead, dest,
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes);
+  Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
newResults.push_back(write->getResult(0));
return success();
}
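
For orientation, a hand-written sketch assuming a 1-D pad from `tensor<5xf32>` to `tensor<8xf32>`: the read is masked separately (the `maskedRead` above), while the write emitted here needs no mask because the stored vector matches the padded result shape:

```mlir
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f32
%c5 = arith.constant 5 : index
%mask = vector.create_mask %c5 : vector<8xi1>
%read = vector.mask %mask {
  vector.transfer_read %src[%c0], %pad : tensor<5xf32>, vector<8xf32>
} : vector<8xi1> -> vector<8xf32>
%init = tensor.empty() : tensor<8xf32>
%res = vector.transfer_write %read, %init[%c0] {in_bounds = [true]}
    : vector<8xf32>, tensor<8xf32>
```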
@@ -3043,9 +3004,9 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
// Create write
auto writeIndices =
getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, read, sliceOp.getDest(), vecType.getShape(), writeIndices,
-      /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(),
+                               writeIndices, inputVectorSizes.empty());

// 4. Finalize
newResults.push_back(write->getResult(0));
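
This call site is the one that passes non-zero `writeIndices` (the slice's mixed offsets), and an empty `inputVectorSizes` flips the helper to the `in_bounds` path. A hand-written sketch, assuming a size-4 slice written at offset 2 into a static `tensor<8xf32>` — a case where `isMaskTriviallyFoldable` should be able to prove the mask all-true (2 + 4 <= 8), so the plain write is returned:

```mlir
// Assumed: %v : vector<4xf32>, %dest : tensor<8xf32>, offset 2.
%c2 = arith.constant 2 : index
%res = vector.transfer_write %v, %dest[%c2] {in_bounds = [true]}
    : vector<4xf32>, tensor<8xf32>
```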