@@ -1606,61 +1606,46 @@ static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
 /// Creates an optionally masked TransferWriteOp
 ///
 /// Generates the following operation:
-///   %res = vector.transfer_write %vectorToStore into %dest
+///   %res = vector.transfer_write %vecToStore into %dest
 ///
-/// If the leading N dimensions of the vector to store do not match
-/// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
-/// masking is applied to ensure correctness:
+/// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
 ///
-///   %mask = vector.create_mask(%destShape) : %vectorToStoreShape
+///   %mask = vector.create_mask(%destShape) : %vecToStoreShape
 ///   %res = vector.mask %mask {
-///     vector.transfer_write %vectorToStore into %dest
+///     vector.transfer_write %vecToStore into %dest
 ///   }
 ///
-/// The mask shape is identical to `vectorToStore` (with the element type ==
+/// The mask shape is identical to `vecToStore` (with the element type ==
 /// i1), and the mask values are based on the shape of the `dest` tensor.
 ///
 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
 /// is used instead of masking:
 ///
-///   %write = vector.transfer_write %vectorToStore into %dest
+///   %write = vector.transfer_write %vecToStore into %dest
 ///     in_bounds_flags = (...)
 ///   %res = vector.transfer_write %input into %dest
 ///     {in_bounds = in_bounds_flags}
 ///
-/// `writeIndices` specifies the offsets to use. If empty, all indices are set
-/// to 0.
-///
-/// NOTE: When N < rank(vectorToStore), the missing vector sizes are taken from
-/// `valueToStore`.
-/// TODO: `inputVecSizesForLeadingDims` should not be required - these sizes are
-/// already provided in `vectorToStore`.
+/// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
+/// are set to 0.
 static Operation *
-createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
-                         Value dest,
-                         ArrayRef<int64_t> inputVecSizesForLeadingDims,
-                         SmallVector<Value> writeIndices = {},
+createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
+                         Value dest, SmallVector<Value> writeIndices = {},
                          bool useInBoundsInsteadOfMasking = false) {
 
   ShapedType destType = cast<ShapedType>(dest.getType());
   int64_t destRank = destType.getRank();
   auto destShape = destType.getShape();
 
-  VectorType vecToStoreType = cast<VectorType>(vectorToStore.getType());
+  VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
   int64_t vecToStoreRank = vecToStoreType.getRank();
   auto vecToStoreShape = vecToStoreType.getShape();
 
   // Compute the in_bounds attribute
   SmallVector<bool> inBoundsVal(vecToStoreRank, true);
   if (useInBoundsInsteadOfMasking) {
-    // In this case, assume that all the required vector sizes have been
-    // provided.
-    assert(inputVecSizesForLeadingDims.size() ==
-               static_cast<size_t>(vecToStoreType.getRank()) &&
-           "Insufficient number of input vector sizes!");
-    // Update the inBounds attribute.
     for (unsigned i = 0; i < destRank; i++)
-      inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
+      inBoundsVal[i] = (destShape[i] == vecToStoreShape[i]) &&
                        !ShapedType::isDynamic(destShape[i]);
   }
 
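For illustration, here is a minimal sketch of the masked IR the updated helper is expected to produce when shape(vecToStore) differs from shape(dest). The 4x8 vector shape, the dynamic leading destination dim, and the %vecToStore/%dest names are made up for the example, not taken from this patch:

  %c0 = arith.constant 0 : index
  %c8 = arith.constant 8 : index
  %d0 = tensor.dim %dest, %c0 : tensor<?x8xf32>
  %mask = vector.create_mask %d0, %c8 : vector<4x8xi1>
  %res = vector.mask %mask {
    vector.transfer_write %vecToStore, %dest[%c0, %c0]
      : vector<4x8xf32>, tensor<?x8xf32>
  } : vector<4x8xi1> -> tensor<?x8xf32>

The mask shape follows the vector to store, while the mask sizes come from the destination tensor, matching the doc comment above.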
@@ -1676,7 +1661,7 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
   // Generate the xfer_write Op
   Operation *write =
       builder.create<vector::TransferWriteOp>(loc,
-                                              /*vector=*/vectorToStore,
+                                              /*vector=*/vecToStore,
                                               /*source=*/dest,
                                               /*indices=*/writeIndices,
                                               /*inBounds=*/inBoundsVal);
@@ -1685,46 +1670,25 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
   if (useInBoundsInsteadOfMasking)
     return write;
 
-  assert(llvm::none_of(
-             destShape.drop_front(inputVecSizesForLeadingDims.size()),
-             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
-         "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
-
-  // Check if masking is needed.
-  bool needMaskForWrite =
-      !llvm::equal(inputVecSizesForLeadingDims,
-                   destShape.take_front(destRank - vecToStoreRank +
-                                        inputVecSizesForLeadingDims.size()));
-
-  // If masking is needed, generate the mask and mask the operation.
-  if (needMaskForWrite) {
-    // Get the mask shape + type. Missing mask dimensions are taken from
-    // `vectorToStore`.
-    SmallVector<int64_t> writeMaskShape;
-    writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
-                          inputVecSizesForLeadingDims.end());
-    if (vecToStoreRank >
-        static_cast<int64_t>(inputVecSizesForLeadingDims.size()))
-      writeMaskShape.append(vecToStoreShape.begin() +
-                                inputVecSizesForLeadingDims.size(),
-                            vecToStoreShape.end());
-    auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
-
-    SmallVector<OpFoldResult> destSizes =
-        tensor::getMixedSizes(builder, loc, dest);
-    SmallVector<OpFoldResult> maskSizes(destSizes.end() - writeMaskShape.size(),
-                                        destSizes.end());
-
-    if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
-                                writeMaskShape))
-      return write;
-
-    Value maskForWrite = builder.createOrFold<vector::CreateMaskOp>(
-        loc, writeMaskType, maskSizes);
-    write = mlir::vector::maskOperation(builder, write, maskForWrite);
-  }
+  // Check if masking is needed. If not, exit.
+  if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
+    return write;
+
+  // Compute the mask and mask the write Op.
+  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
+
+  SmallVector<OpFoldResult> destSizes =
+      tensor::getMixedSizes(builder, loc, dest);
+  SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
+                                      destSizes.end());
+
+  if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
+                              vecToStoreShape))
+    return write;
 
-  return write;
+  Value maskForWrite =
+      builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
+  return mlir::vector::maskOperation(builder, write, maskForWrite);
 }
 
 /// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
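Conversely, when the trailing destination dims already match the vector shape (or when `useInBoundsInsteadOfMasking` is set), the helper returns a plain write with no mask. A hedged sketch, again with illustrative shapes and names that are not part of the patch:

  %res = vector.transfer_write %vecToStore, %dest[%c0, %c0]
    {in_bounds = [true, true]} : vector<4x8xf32>, tensor<4x8xf32>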
@@ -1824,10 +1788,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0],
       transposeOp.getResult().getType().getElementType());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, transposeOp.getResult(), dest,
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes, /*writeIndices=*/{},
-      /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest,
+                               /*writeIndices=*/{},
+                               /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1965,7 +1929,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
       shapeCastOp.getResult().getType().getElementType());
   Operation *write = createWriteOrMaskedWrite(
       rewriter, loc, shapeCastOp.getResult(), dest,
-      /*inputVecSizesForLeadingDims=*/writeVectorSizes,
       /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
   newResults.push_back(write->getResult(0));
   return success();
@@ -1998,10 +1961,9 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   // Create Xfer write Op
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, maskedRead, dest,
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes, {},
-      /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest, {},
+                               /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -3057,8 +3019,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   // Create write
   auto writeIndices =
       getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, read, sliceOp.getDest(), vecType.getShape(), writeIndices);
+  Operation *write = createWriteOrMaskedWrite(rewriter, loc, read,
+                                              sliceOp.getDest(), writeIndices);
 
   // 4. Finalize
   newResults.push_back(write->getResult(0));