[mlir][linalg] Simplify createWriteOrMaskedWrite (NFC) #141567


Merged (1 commit, Jun 8, 2025)

Conversation

banach-space (Contributor) commented May 27, 2025

This patch removes inputVecSizesForLeadingDims from the parameter list
of createWriteOrMaskedWrite. That argument is unnecessary — vector sizes
can be obtained from the vecToStore parameter. Since this doesn't change
behavior or test results, it's marked as NFC.

Additional cleanups:

  • Renamed vectorToStore to vecToStore for consistency and brevity.
  • Rewrote a conditional at the end of the function to use early exit,
    improving readability:
  // BEFORE:
  if (maskingRequired) {
    Value maskForWrite = ...;
    write = maskOperation(write, maskForWrite);
  }
  return write;

  // AFTER:
  if (!maskingRequired)
    return write;

  Value maskForWrite = ...;
  return vector::maskOperation(builder, write, maskForWrite);

@banach-space banach-space changed the base branch from main to users/banach-space/vector/update_vectorize_insert_slice May 27, 2025 09:12
@llvmbot (Member) commented May 27, 2025

@llvm/pr-subscribers-mlir-llvm

Author: Andrzej Warzyński (banach-space)

Changes
  • [mlir][linalg] Refactor vectorization hooks to improve code reuse
  • [mlir][linalg] Simplify createWriteOrMaskedWrite (NFC)

Full diff: https://github.com/llvm/llvm-project/pull/141567.diff

1 File Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+40-78)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0113ba86a5ae3..2abb2f0ea467c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1590,61 +1590,46 @@ static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
 /// Creates an optionally masked TransferWriteOp
 ///
 /// Generates the following operation:
-///   %res = vector.transfer_write %vectorToStore into %dest
+///   %res = vector.transfer_write %vecToStore into %dest
 ///
-/// If the leading N dimensions of the vector to store do not match
-/// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
-/// masking is applied to ensure correctness:
+/// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
 ///
-///   %mask = vector.create_mask(%destShape) : %vectorToStoreShape
+///   %mask = vector.create_mask(%destShape) : %vecToStoreShape
 ///   %res = vector.mask %mask {
-///     vector.transfer_write %vectorToStore into %dest
+///     vector.transfer_write %vecToStore into %dest
 ///   }
 ///
-/// The mask shape is identical to `vectorToStore` (with the element type ==
+/// The mask shape is identical to `vecToStore` (with the element type ==
 /// i1), and the mask values are based on the shape of the `dest` tensor.
 ///
 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
 /// is used instead of masking:
 ///
-///   %write = vector.transfer_write %vectorToStore into %dest
+///   %write = vector.transfer_write %vecToStore into %dest
 ///   in_bounds_flags = (...)
 ///   %res = vector.transfer_write %input into %dest
 ///       {in_bounds = in_bounds_flags}
 ///
-/// `writeIndices` specifies the offsets to use. If empty, all indices are set
-/// to 0.
-///
-/// NOTE: When N < rank(vectorToStore), the missing vector sizes are taken from
-/// `valueToStore`.
-/// TODO: `inputVecSizesForLeadingDims` should not be required - these sizes are
-/// already provided in `vectorToStore`.
+/// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
+/// are set to 0.
 static Operation *
-createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
-                         Value dest,
-                         ArrayRef<int64_t> inputVecSizesForLeadingDims,
-                         SmallVector<Value> writeIndices = {},
+createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
+                         Value dest, SmallVector<Value> writeIndices = {},
                          bool useInBoundsInsteadOfMasking = false) {
 
   ShapedType destType = cast<ShapedType>(dest.getType());
   int64_t destRank = destType.getRank();
   auto destShape = destType.getShape();
 
-  VectorType vecToStoreType = cast<VectorType>(vectorToStore.getType());
+  VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
   int64_t vecToStoreRank = vecToStoreType.getRank();
   auto vecToStoreShape = vecToStoreType.getShape();
 
   // Compute the in_bounds attribute
   SmallVector<bool> inBoundsVal(vecToStoreRank, true);
   if (useInBoundsInsteadOfMasking) {
-    // In this case, assume that all the required vector sizes have been
-    // provided.
-    assert(inputVecSizesForLeadingDims.size() ==
-               static_cast<size_t>(vecToStoreType.getRank()) &&
-           "Insufficient number of input vector sizes!");
-    // Update the inBounds attribute.
     for (unsigned i = 0; i < destRank; i++)
-      inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
+      inBoundsVal[i] = (destShape[i] == vecToStoreShape[i]) &&
                        !ShapedType::isDynamic(destShape[i]);
   }
 
@@ -1660,7 +1645,7 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
   // Generate the xfer_write Op
   Operation *write =
       builder.create<vector::TransferWriteOp>(loc,
-                                              /*vector=*/vectorToStore,
+                                              /*vector=*/vecToStore,
                                               /*source=*/dest,
                                               /*indices=*/writeIndices,
                                               /*inBounds=*/inBoundsVal);
@@ -1669,46 +1654,25 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
   if (useInBoundsInsteadOfMasking)
     return write;
 
-  assert(llvm::none_of(
-             destShape.drop_front(inputVecSizesForLeadingDims.size()),
-             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
-         "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
-
-  // Check if masking is needed.
-  bool needMaskForWrite =
-      !llvm::equal(inputVecSizesForLeadingDims,
-                   destShape.take_front(destRank - vecToStoreRank +
-                                        inputVecSizesForLeadingDims.size()));
-
-  // If masking is needed, generate the mask and mask the operation.
-  if (needMaskForWrite) {
-    // Get the mask shape + type. Missing mask dimensions are taken from
-    // `vectorToStore`.
-    SmallVector<int64_t> writeMaskShape;
-    writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
-                          inputVecSizesForLeadingDims.end());
-    if (vecToStoreRank >
-        static_cast<int64_t>(inputVecSizesForLeadingDims.size()))
-      writeMaskShape.append(vecToStoreShape.begin() +
-                                inputVecSizesForLeadingDims.size(),
-                            vecToStoreShape.end());
-    auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
-
-    SmallVector<OpFoldResult> destSizes =
-        tensor::getMixedSizes(builder, loc, dest);
-    SmallVector<OpFoldResult> maskSizes(destSizes.end() - writeMaskShape.size(),
-                                        destSizes.end());
-
-    if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
-                                writeMaskShape))
-      return write;
-
-    Value maskForWrite = builder.createOrFold<vector::CreateMaskOp>(
-        loc, writeMaskType, maskSizes);
-    write = mlir::vector::maskOperation(builder, write, maskForWrite);
-  }
+  // Check if masking is needed. If not, exit.
+  if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
+    return write;
+
+  // Compute the mask and mask the write Op.
+  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
+
+  SmallVector<OpFoldResult> destSizes =
+      tensor::getMixedSizes(builder, loc, dest);
+  SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
+                                      destSizes.end());
+
+  if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
+                              vecToStoreShape))
+    return write;
 
-  return write;
+  Value maskForWrite =
+      builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
+  return mlir::vector::maskOperation(builder, write, maskForWrite);
 }
 
 /// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
@@ -1808,10 +1772,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0],
       transposeOp.getResult().getType().getElementType());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, transposeOp.getResult(), dest,
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes, /*writeIndices=*/{},
-      /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest,
+                               /*writeIndices=*/{},
+                               /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1949,7 +1913,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
       shapeCastOp.getResult().getType().getElementType());
   Operation *write = createWriteOrMaskedWrite(
       rewriter, loc, shapeCastOp.getResult(), dest,
-      /*inputVecSizesForLeadingDims=*/writeVectorSizes,
       /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
   newResults.push_back(write->getResult(0));
   return success();
@@ -1982,10 +1945,9 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   // Create Xfer write Op
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, maskedRead, dest,
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes, {},
-      /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest, {},
+                               /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -3041,8 +3003,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   // Create write
   auto writeIndices =
       getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, read, sliceOp.getDest(), vecType.getShape(), writeIndices);
+  Operation *write = createWriteOrMaskedWrite(rewriter, loc, read,
+                                              sliceOp.getDest(), writeIndices);
 
   // 4. Finalize
   newResults.push_back(write->getResult(0));

@llvmbot
Copy link
Member

llvmbot commented May 27, 2025

@llvm/pr-subscribers-mlir-linalg

Author: Andrzej Warzyński (banach-space)

Changes
  • [[mlir][linalg] Refactor vectorization hooks to improve code reuse
  • [mlir][linalg] Simplify createWriteOrMaskedWrite (NFC)

Full diff: https://github.com/llvm/llvm-project/pull/141567.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+40-78)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0113ba86a5ae3..2abb2f0ea467c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1590,61 +1590,46 @@ static bool isMaskTriviallyFoldable(SmallVector &maskSizes,
 /// Creates an optionally masked TransferWriteOp
 ///
 /// Generates the following operation:
-///   %res = vector.transfer_write %vectorToStore into %dest
+///   %res = vector.transfer_write %vecToStore into %dest
 ///
-/// If the leading N dimensions of the vector to store do not match
-/// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
-/// masking is applied to ensure correctness:
+/// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
 ///
-///   %mask = vector.create_mask(%destShape) : %vectorToStoreShape
+///   %mask = vector.create_mask(%destShape) : %vecToStoreShape
 ///   %res = vector.mask %mask {
-///     vector.transfer_write %vectorToStore into %dest
+///     vector.transfer_write %vecToStore into %dest
 ///   }
 ///
-/// The mask shape is identical to `vectorToStore` (with the element type ==
+/// The mask shape is identical to `vecToStore` (with the element type ==
 /// i1), and the mask values are based on the shape of the `dest` tensor.
 ///
 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
 /// is used instead of masking:
 ///
-///   %write = vector.transfer_write %vectorToStore into %dest
+///   %write = vector.transfer_write %vecToStore into %dest
 ///   in_bounds_flags = (...)
 ///   %res = vector.transfer_write %input into %dest
 ///       {in_bounds = in_bounds_flags}
 ///
-/// `writeIndices` specifies the offsets to use. If empty, all indices are set
-/// to 0.
-///
-/// NOTE: When N < rank(vectorToStore), the missing vector sizes are taken from
-/// `valueToStore`.
-/// TODO: `inputVecSizesForLeadingDims` should not be required - these sizes are
-/// already provided in `vectorToStore`.
+/// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
+/// are set to 0.
 static Operation *
-createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
-                         Value dest,
-                         ArrayRef inputVecSizesForLeadingDims,
-                         SmallVector writeIndices = {},
+createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
+                         Value dest, SmallVector writeIndices = {},
                          bool useInBoundsInsteadOfMasking = false) {
 
   ShapedType destType = cast(dest.getType());
   int64_t destRank = destType.getRank();
   auto destShape = destType.getShape();
 
-  VectorType vecToStoreType = cast(vectorToStore.getType());
+  VectorType vecToStoreType = cast(vecToStore.getType());
   int64_t vecToStoreRank = vecToStoreType.getRank();
   auto vecToStoreShape = vecToStoreType.getShape();
 
   // Compute the in_bounds attribute
   SmallVector inBoundsVal(vecToStoreRank, true);
   if (useInBoundsInsteadOfMasking) {
-    // In this case, assume that all the required vector sizes have been
-    // provided.
-    assert(inputVecSizesForLeadingDims.size() ==
-               static_cast(vecToStoreType.getRank()) &&
-           "Insufficient number of input vector sizes!");
-    // Update the inBounds attribute.
     for (unsigned i = 0; i < destRank; i++)
-      inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
+      inBoundsVal[i] = (destShape[i] == vecToStoreShape[i]) &&
                        !ShapedType::isDynamic(destShape[i]);
   }
 
@@ -1660,7 +1645,7 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
   // Generate the xfer_write Op
   Operation *write =
       builder.create(loc,
-                                              /*vector=*/vectorToStore,
+                                              /*vector=*/vecToStore,
                                               /*source=*/dest,
                                               /*indices=*/writeIndices,
                                               /*inBounds=*/inBoundsVal);
@@ -1669,46 +1654,25 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
   if (useInBoundsInsteadOfMasking)
     return write;
 
-  assert(llvm::none_of(
-             destShape.drop_front(inputVecSizesForLeadingDims.size()),
-             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
-         "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
-
-  // Check if masking is needed.
-  bool needMaskForWrite =
-      !llvm::equal(inputVecSizesForLeadingDims,
-                   destShape.take_front(destRank - vecToStoreRank +
-                                        inputVecSizesForLeadingDims.size()));
-
-  // If masking is needed, generate the mask and mask the operation.
-  if (needMaskForWrite) {
-    // Get the mask shape + type. Missing mask dimensions are taken from
-    // `vectorToStore`.
-    SmallVector writeMaskShape;
-    writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
-                          inputVecSizesForLeadingDims.end());
-    if (vecToStoreRank >
-        static_cast(inputVecSizesForLeadingDims.size()))
-      writeMaskShape.append(vecToStoreShape.begin() +
-                                inputVecSizesForLeadingDims.size(),
-                            vecToStoreShape.end());
-    auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
-
-    SmallVector destSizes =
-        tensor::getMixedSizes(builder, loc, dest);
-    SmallVector maskSizes(destSizes.end() - writeMaskShape.size(),
-                                        destSizes.end());
-
-    if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
-                                writeMaskShape))
-      return write;
-
-    Value maskForWrite = builder.createOrFold(
-        loc, writeMaskType, maskSizes);
-    write = mlir::vector::maskOperation(builder, write, maskForWrite);
-  }
+  // Check if masking is needed. If not, exit.
+  if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
+    return write;
+
+  // Compute the mask and mask the write Op.
+  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
+
+  SmallVector destSizes =
+      tensor::getMixedSizes(builder, loc, dest);
+  SmallVector maskSizes(destSizes.end() - vecToStoreRank,
+                                      destSizes.end());
+
+  if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
+                              vecToStoreShape))
+    return write;
 
-  return write;
+  Value maskForWrite =
+      builder.createOrFold(loc, writeMaskType, maskSizes);
+  return mlir::vector::maskOperation(builder, write, maskForWrite);
 }
 
 /// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
@@ -1808,10 +1772,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   Value dest = rewriter.create(
       loc, reifiedReturnShapes[0],
       transposeOp.getResult().getType().getElementType());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, transposeOp.getResult(), dest,
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes, /*writeIndices=*/{},
-      /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest,
+                               /*writeIndices=*/{},
+                               /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1949,7 +1913,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
       shapeCastOp.getResult().getType().getElementType());
   Operation *write = createWriteOrMaskedWrite(
       rewriter, loc, shapeCastOp.getResult(), dest,
-      /*inputVecSizesForLeadingDims=*/writeVectorSizes,
       /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
   newResults.push_back(write->getResult(0));
   return success();
@@ -1982,10 +1945,9 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   // Create Xfer write Op
   Value dest = rewriter.create(
       loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, maskedRead, dest,
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes, {},
-      /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest, {},
+                               /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -3041,8 +3003,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   // Create write
   auto writeIndices =
       getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, read, sliceOp.getDest(), vecType.getShape(), writeIndices);
+  Operation *write = createWriteOrMaskedWrite(rewriter, loc, read,
+                                              sliceOp.getDest(), writeIndices);
 
   // 4. Finalize
   newResults.push_back(write->getResult(0));

@llvmbot
Copy link
Member

llvmbot commented May 27, 2025

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes
  • [[mlir][linalg] Refactor vectorization hooks to improve code reuse
  • [mlir][linalg] Simplify createWriteOrMaskedWrite (NFC)

Full diff: https://github.com/llvm/llvm-project/pull/141567.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+40-78)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0113ba86a5ae3..2abb2f0ea467c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1590,61 +1590,46 @@ static bool isMaskTriviallyFoldable(SmallVector &maskSizes,
 /// Creates an optionally masked TransferWriteOp
 ///
 /// Generates the following operation:
-///   %res = vector.transfer_write %vectorToStore into %dest
+///   %res = vector.transfer_write %vecToStore into %dest
 ///
-/// If the leading N dimensions of the vector to store do not match
-/// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
-/// masking is applied to ensure correctness:
+/// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
 ///
-///   %mask = vector.create_mask(%destShape) : %vectorToStoreShape
+///   %mask = vector.create_mask(%destShape) : %vecToStoreShape
 ///   %res = vector.mask %mask {
-///     vector.transfer_write %vectorToStore into %dest
+///     vector.transfer_write %vecToStore into %dest
 ///   }
 ///
-/// The mask shape is identical to `vectorToStore` (with the element type ==
+/// The mask shape is identical to `vecToStore` (with the element type ==
 /// i1), and the mask values are based on the shape of the `dest` tensor.
 ///
 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
 /// is used instead of masking:
 ///
-///   %write = vector.transfer_write %vectorToStore into %dest
+///   %write = vector.transfer_write %vecToStore into %dest
 ///   in_bounds_flags = (...)
 ///   %res = vector.transfer_write %input into %dest
 ///       {in_bounds = in_bounds_flags}
 ///
-/// `writeIndices` specifies the offsets to use. If empty, all indices are set
-/// to 0.
-///
-/// NOTE: When N < rank(vectorToStore), the missing vector sizes are taken from
-/// `valueToStore`.
-/// TODO: `inputVecSizesForLeadingDims` should not be required - these sizes are
-/// already provided in `vectorToStore`.
+/// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
+/// are set to 0.
 static Operation *
-createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
-                         Value dest,
-                         ArrayRef inputVecSizesForLeadingDims,
-                         SmallVector writeIndices = {},
+createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
+                         Value dest, SmallVector writeIndices = {},
                          bool useInBoundsInsteadOfMasking = false) {
 
   ShapedType destType = cast(dest.getType());
   int64_t destRank = destType.getRank();
   auto destShape = destType.getShape();
 
-  VectorType vecToStoreType = cast(vectorToStore.getType());
+  VectorType vecToStoreType = cast(vecToStore.getType());
   int64_t vecToStoreRank = vecToStoreType.getRank();
   auto vecToStoreShape = vecToStoreType.getShape();
 
   // Compute the in_bounds attribute
   SmallVector inBoundsVal(vecToStoreRank, true);
   if (useInBoundsInsteadOfMasking) {
-    // In this case, assume that all the required vector sizes have been
-    // provided.
-    assert(inputVecSizesForLeadingDims.size() ==
-               static_cast(vecToStoreType.getRank()) &&
-           "Insufficient number of input vector sizes!");
-    // Update the inBounds attribute.
     for (unsigned i = 0; i < destRank; i++)
-      inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
+      inBoundsVal[i] = (destShape[i] == vecToStoreShape[i]) &&
                        !ShapedType::isDynamic(destShape[i]);
   }
 
@@ -1660,7 +1645,7 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
   // Generate the xfer_write Op
   Operation *write =
       builder.create(loc,
-                                              /*vector=*/vectorToStore,
+                                              /*vector=*/vecToStore,
                                               /*source=*/dest,
                                               /*indices=*/writeIndices,
                                               /*inBounds=*/inBoundsVal);
@@ -1669,46 +1654,25 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
   if (useInBoundsInsteadOfMasking)
     return write;
 
-  assert(llvm::none_of(
-             destShape.drop_front(inputVecSizesForLeadingDims.size()),
-             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
-         "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
-
-  // Check if masking is needed.
-  bool needMaskForWrite =
-      !llvm::equal(inputVecSizesForLeadingDims,
-                   destShape.take_front(destRank - vecToStoreRank +
-                                        inputVecSizesForLeadingDims.size()));
-
-  // If masking is needed, generate the mask and mask the operation.
-  if (needMaskForWrite) {
-    // Get the mask shape + type. Missing mask dimensions are taken from
-    // `vectorToStore`.
-    SmallVector writeMaskShape;
-    writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
-                          inputVecSizesForLeadingDims.end());
-    if (vecToStoreRank >
-        static_cast(inputVecSizesForLeadingDims.size()))
-      writeMaskShape.append(vecToStoreShape.begin() +
-                                inputVecSizesForLeadingDims.size(),
-                            vecToStoreShape.end());
-    auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
-
-    SmallVector destSizes =
-        tensor::getMixedSizes(builder, loc, dest);
-    SmallVector maskSizes(destSizes.end() - writeMaskShape.size(),
-                                        destSizes.end());
-
-    if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
-                                writeMaskShape))
-      return write;
-
-    Value maskForWrite = builder.createOrFold(
-        loc, writeMaskType, maskSizes);
-    write = mlir::vector::maskOperation(builder, write, maskForWrite);
-  }
+  // Check if masking is needed. If not, exit.
+  if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
+    return write;
+
+  // Compute the mask and mask the write Op.
+  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
+
+  SmallVector destSizes =
+      tensor::getMixedSizes(builder, loc, dest);
+  SmallVector maskSizes(destSizes.end() - vecToStoreRank,
+                                      destSizes.end());
+
+  if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
+                              vecToStoreShape))
+    return write;
 
-  return write;
+  Value maskForWrite =
+      builder.createOrFold(loc, writeMaskType, maskSizes);
+  return mlir::vector::maskOperation(builder, write, maskForWrite);
 }
 
 /// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
@@ -1808,10 +1772,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   Value dest = rewriter.create(
       loc, reifiedReturnShapes[0],
       transposeOp.getResult().getType().getElementType());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, transposeOp.getResult(), dest,
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes, /*writeIndices=*/{},
-      /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest,
+                               /*writeIndices=*/{},
+                               /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1949,7 +1913,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
       shapeCastOp.getResult().getType().getElementType());
   Operation *write = createWriteOrMaskedWrite(
       rewriter, loc, shapeCastOp.getResult(), dest,
-      /*inputVecSizesForLeadingDims=*/writeVectorSizes,
       /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
   newResults.push_back(write->getResult(0));
   return success();
@@ -1982,10 +1945,9 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   // Create Xfer write Op
   Value dest = rewriter.create(
       loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, maskedRead, dest,
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes, {},
-      /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest, {},
+                               /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -3041,8 +3003,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   // Create write
   auto writeIndices =
       getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, read, sliceOp.getDest(), vecType.getShape(), writeIndices);
+  Operation *write = createWriteOrMaskedWrite(rewriter, loc, read,
+                                              sliceOp.getDest(), writeIndices);
 
   // 4. Finalize
   newResults.push_back(write->getResult(0));

@llvmbot
Copy link
Member

llvmbot commented May 27, 2025

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes
  • [[mlir][linalg] Refactor vectorization hooks to improve code reuse
  • [mlir][linalg] Simplify createWriteOrMaskedWrite (NFC)

Full diff: https://github.com/llvm/llvm-project/pull/141567.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+40-78)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0113ba86a5ae3..2abb2f0ea467c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1590,61 +1590,46 @@ static bool isMaskTriviallyFoldable(SmallVector &maskSizes,
 /// Creates an optionally masked TransferWriteOp
 ///
 /// Generates the following operation:
-///   %res = vector.transfer_write %vectorToStore into %dest
+///   %res = vector.transfer_write %vecToStore into %dest
 ///
-/// If the leading N dimensions of the vector to store do not match
-/// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
-/// masking is applied to ensure correctness:
+/// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
 ///
-///   %mask = vector.create_mask(%destShape) : %vectorToStoreShape
+///   %mask = vector.create_mask(%destShape) : %vecToStoreShape
 ///   %res = vector.mask %mask {
-///     vector.transfer_write %vectorToStore into %dest
+///     vector.transfer_write %vecToStore into %dest
 ///   }
 ///
-/// The mask shape is identical to `vectorToStore` (with the element type ==
+/// The mask shape is identical to `vecToStore` (with the element type ==
 /// i1), and the mask values are based on the shape of the `dest` tensor.
 ///
 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
 /// is used instead of masking:
 ///
-///   %write = vector.transfer_write %vectorToStore into %dest
+///   %write = vector.transfer_write %vecToStore into %dest
 ///   in_bounds_flags = (...)
 ///   %res = vector.transfer_write %input into %dest
 ///       {in_bounds = in_bounds_flags}
 ///
-/// `writeIndices` specifies the offsets to use. If empty, all indices are set
-/// to 0.
-///
-/// NOTE: When N < rank(vectorToStore), the missing vector sizes are taken from
-/// `valueToStore`.
-/// TODO: `inputVecSizesForLeadingDims` should not be required - these sizes are
-/// already provided in `vectorToStore`.
+/// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
+/// are set to 0.
 static Operation *
-createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
-                         Value dest,
-                         ArrayRef<int64_t> inputVecSizesForLeadingDims,
-                         SmallVector<Value> writeIndices = {},
+createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
+                         Value dest, SmallVector<Value> writeIndices = {},
                          bool useInBoundsInsteadOfMasking = false) {
 
   ShapedType destType = cast<ShapedType>(dest.getType());
   int64_t destRank = destType.getRank();
   auto destShape = destType.getShape();
 
-  VectorType vecToStoreType = cast<VectorType>(vectorToStore.getType());
+  VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
   int64_t vecToStoreRank = vecToStoreType.getRank();
   auto vecToStoreShape = vecToStoreType.getShape();
 
   // Compute the in_bounds attribute
   SmallVector<bool> inBoundsVal(vecToStoreRank, true);
   if (useInBoundsInsteadOfMasking) {
-    // In this case, assume that all the required vector sizes have been
-    // provided.
-    assert(inputVecSizesForLeadingDims.size() ==
-               static_cast<size_t>(vecToStoreType.getRank()) &&
-           "Insufficient number of input vector sizes!");
-    // Update the inBounds attribute.
     for (unsigned i = 0; i < destRank; i++)
-      inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
+      inBoundsVal[i] = (destShape[i] == vecToStoreShape[i]) &&
                        !ShapedType::isDynamic(destShape[i]);
   }
 
@@ -1660,7 +1645,7 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
   // Generate the xfer_write Op
   Operation *write =
       builder.create<vector::TransferWriteOp>(loc,
-                                              /*vector=*/vectorToStore,
+                                              /*vector=*/vecToStore,
                                               /*source=*/dest,
                                               /*indices=*/writeIndices,
                                               /*inBounds=*/inBoundsVal);
@@ -1669,46 +1654,25 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
   if (useInBoundsInsteadOfMasking)
     return write;
 
-  assert(llvm::none_of(
-             destShape.drop_front(inputVecSizesForLeadingDims.size()),
-             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
-         "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
-
-  // Check if masking is needed.
-  bool needMaskForWrite =
-      !llvm::equal(inputVecSizesForLeadingDims,
-                   destShape.take_front(destRank - vecToStoreRank +
-                                        inputVecSizesForLeadingDims.size()));
-
-  // If masking is needed, generate the mask and mask the operation.
-  if (needMaskForWrite) {
-    // Get the mask shape + type. Missing mask dimensions are taken from
-    // `vectorToStore`.
-    SmallVector<int64_t> writeMaskShape;
-    writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
-                          inputVecSizesForLeadingDims.end());
-    if (vecToStoreRank >
-        static_cast<int64_t>(inputVecSizesForLeadingDims.size()))
-      writeMaskShape.append(vecToStoreShape.begin() +
-                                inputVecSizesForLeadingDims.size(),
-                            vecToStoreShape.end());
-    auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
-
-    SmallVector<OpFoldResult> destSizes =
-        tensor::getMixedSizes(builder, loc, dest);
-    SmallVector<OpFoldResult> maskSizes(destSizes.end() - writeMaskShape.size(),
-                                        destSizes.end());
-
-    if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
-                                writeMaskShape))
-      return write;
-
-    Value maskForWrite = builder.createOrFold<vector::CreateMaskOp>(
-        loc, writeMaskType, maskSizes);
-    write = mlir::vector::maskOperation(builder, write, maskForWrite);
-  }
+  // Check if masking is needed. If not, exit.
+  if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
+    return write;
+
+  // Compute the mask and mask the write Op.
+  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
+
+  SmallVector<OpFoldResult> destSizes =
+      tensor::getMixedSizes(builder, loc, dest);
+  SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
+                                      destSizes.end());
+
+  if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
+                              vecToStoreShape))
+    return write;
 
-  return write;
+  Value maskForWrite =
+      builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
+  return mlir::vector::maskOperation(builder, write, maskForWrite);
 }
 
 /// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
@@ -1808,10 +1772,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0],
       transposeOp.getResult().getType().getElementType());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, transposeOp.getResult(), dest,
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes, /*writeIndices=*/{},
-      /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest,
+                               /*writeIndices=*/{},
+                               /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1949,7 +1913,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
       shapeCastOp.getResult().getType().getElementType());
   Operation *write = createWriteOrMaskedWrite(
       rewriter, loc, shapeCastOp.getResult(), dest,
-      /*inputVecSizesForLeadingDims=*/writeVectorSizes,
       /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
   newResults.push_back(write->getResult(0));
   return success();
@@ -1982,10 +1945,9 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   // Create Xfer write Op
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, maskedRead, dest,
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes, {},
-      /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest, {},
+                               /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -3041,8 +3003,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   // Create write
   auto writeIndices =
       getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, read, sliceOp.getDest(), vecType.getShape(), writeIndices);
+  Operation *write = createWriteOrMaskedWrite(rewriter, loc, read,
+                                              sliceOp.getDest(), writeIndices);
 
   // 4. Finalize
   newResults.push_back(write->getResult(0));

@banach-space banach-space changed the title users/banach space/vector/update create write [mlir][linalg] Simplify createWriteOrMaskedWrite (NFC) May 27, 2025
@banach-space banach-space requested a review from Max191 May 27, 2025 09:13
Member

@rengolin rengolin left a comment


Nice clean up, thanks!

@banach-space banach-space force-pushed the users/banach-space/vector/update_vectorize_insert_slice branch from 42b1783 to 373036e Compare May 30, 2025 10:17
Base automatically changed from users/banach-space/vector/update_vectorize_insert_slice to main June 7, 2025 18:25
@banach-space banach-space force-pushed the users/banach-space/vector/update_create_write branch from ca24a26 to edcc604 Compare June 7, 2025 19:33

github-actions bot commented Jun 7, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

This patch removes `inputVecSizesForLeadingDims` from the parameter list
of `createWriteOrMaskedWrite`. That argument is unnecessary — vector sizes
can be obtained from the `vecToStore` parameter. Since this doesn't change
behavior or test results, it's marked as NFC.

Additional cleanups:
  * Renamed `vectorToStore` to `vecToStore` for consistency and brevity.
  * Rewrote a conditional at the end of the function to use early exit,
    improving readability:

```cpp
  // BEFORE:
  if (maskingRequired) {
    Value maskForWrite = ...;
    write = maskOperation(write, maskForWrite);
  }
  return write;

  // AFTER
  if (!maskingRequired)
    return write;

  Value maskForWrite = ...;
  return vector::maskOperation(builder, write, maskForWrite);
```

This change addresses a TODO from #141244.
@banach-space banach-space force-pushed the users/banach-space/vector/update_create_write branch from edcc604 to cef5067 Compare June 7, 2025 20:07
@banach-space banach-space merged commit 5dfb7bb into main Jun 8, 2025
7 checks passed
@banach-space banach-space deleted the users/banach-space/vector/update_create_write branch June 8, 2025 11:36
omkar-mohanty pushed a commit to omkar-mohanty/llvm-project that referenced this pull request Jun 8, 2025
rorth pushed a commit to rorth/llvm-project that referenced this pull request Jun 11, 2025
DhruvSrivastavaX pushed a commit to DhruvSrivastavaX/lldb-for-aix that referenced this pull request Jun 12, 2025