@@ -1606,63 +1606,49 @@ static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
 /// Creates an optionally masked TransferWriteOp
 ///
 /// Generates the following operation:
-///   %res = vector.transfer_write %vectorToStore into %dest
+///   %res = vector.transfer_write %vecToStore into %dest
 ///
-/// If the leading N dimensions of the vector to store do not match
-/// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
-/// masking is applied to ensure correctness:
+/// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
 ///
-///   %mask = vector.create_mask(%destShape) : %vectorToStoreShape
+///   %mask = vector.create_mask(%destShape) : %vecToStoreShape
 ///   %res = vector.mask %mask {
-///     vector.transfer_write %vectorToStore into %dest
+///     vector.transfer_write %vecToStore into %dest
 ///   }
 ///
-/// The mask shape is identical to `vectorToStore` (with the element type ==
+/// The mask shape is identical to `vecToStore` (with the element type ==
 /// i1), and the mask values are based on the shape of the `dest` tensor.
 ///
 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
 /// is used instead of masking:
 ///
-///   %write = vector.transfer_write %vectorToStore into %dest
+///   %write = vector.transfer_write %vecToStore into %dest
 ///       in_bounds_flags = (...)
 ///   %res = vector.transfer_write %input into %dest
 ///       {in_bounds = in_bounds_flags}
 ///
-/// `writeIndices` specifies the offsets to use. If empty, all indices are set
-/// to 0.
-///
-/// NOTE: When N < rank(vectorToStore), the missing vector sizes are taken from
-/// `valueToStore`.
-/// TODO: `inputVecSizesForLeadingDims` should not be required - these sizes are
-/// already provided in `vectorToStore`.
+/// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
+/// are set to 0.
 static Operation *
-createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
-                         Value dest,
-                         ArrayRef<int64_t> inputVecSizesForLeadingDims,
-                         SmallVector<Value> writeIndices = {},
+createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
+                         Value dest, SmallVector<Value> writeIndices = {},
                          bool useInBoundsInsteadOfMasking = false) {
 
   ShapedType destType = cast<ShapedType>(dest.getType());
   int64_t destRank = destType.getRank();
   auto destShape = destType.getShape();
 
-  VectorType vecToStoreType = cast<VectorType>(vectorToStore.getType());
+  VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
   int64_t vecToStoreRank = vecToStoreType.getRank();
   auto vecToStoreShape = vecToStoreType.getShape();
 
   // Compute the in_bounds attribute
   SmallVector<bool> inBoundsVal(vecToStoreRank, true);
   if (useInBoundsInsteadOfMasking) {
-    // In this case, assume that all the required vector sizes have been
-    // provided.
-    assert(inputVecSizesForLeadingDims.size() ==
-               static_cast<size_t>(vecToStoreType.getRank()) &&
-           "Insufficient number of input vector sizes!");
     // Update the inBounds attribute.
     // FIXME: This computation is too weak - it ignores the write indices.
     for (unsigned i = 0; i < vecToStoreRank; i++)
       inBoundsVal[i] =
-          (destShape[i] >= inputVecSizesForLeadingDims[i]) &&
+          (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
           !ShapedType::isDynamic(destShape[destRank - vecToStoreRank + i]);
   }
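For reference, a minimal sketch of the in_bounds variant described above; the shapes, constants and value names are assumptions for illustration, not taken from the patch. A dimension's flag is true only when the destination extent is static and at least as large as the vector's, so a dynamic trailing destination dim yields false:

  %c0 = arith.constant 0 : index
  %res = vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, false]}
      : vector<8x16xf32>, tensor<8x?xf32>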
@@ -1678,7 +1664,7 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
   // Generate the xfer_write Op
   Operation *write =
       builder.create<vector::TransferWriteOp>(loc,
-                                              /*vector=*/vectorToStore,
+                                              /*vector=*/vecToStore,
                                               /*source=*/dest,
                                               /*indices=*/writeIndices,
                                               /*inBounds=*/inBoundsVal);
@@ -1687,46 +1673,25 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
   if (useInBoundsInsteadOfMasking)
     return write;
 
-  assert(llvm::none_of(
-             destShape.drop_front(inputVecSizesForLeadingDims.size()),
-             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
-         "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
-
-  // Check if masking is needed.
-  bool needMaskForWrite =
-      !llvm::equal(inputVecSizesForLeadingDims,
-                   destShape.take_front(destRank - vecToStoreRank +
-                                        inputVecSizesForLeadingDims.size()));
-
-  // If masking is needed, generate the mask and mask the operation.
-  if (needMaskForWrite) {
-    // Get the mask shape + type. Missing mask dimensions are taken from
-    // `vectorToStore`.
-    SmallVector<int64_t> writeMaskShape;
-    writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
-                          inputVecSizesForLeadingDims.end());
-    if (vecToStoreRank >
-        static_cast<int64_t>(inputVecSizesForLeadingDims.size()))
-      writeMaskShape.append(vecToStoreShape.begin() +
-                                inputVecSizesForLeadingDims.size(),
-                            vecToStoreShape.end());
-    auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
-
-    SmallVector<OpFoldResult> destSizes =
-        tensor::getMixedSizes(builder, loc, dest);
-    SmallVector<OpFoldResult> maskSizes(destSizes.end() - writeMaskShape.size(),
-                                        destSizes.end());
-
-    if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
-                                writeMaskShape))
-      return write;
-
-    Value maskForWrite = builder.createOrFold<vector::CreateMaskOp>(
-        loc, writeMaskType, maskSizes);
-    write = mlir::vector::maskOperation(builder, write, maskForWrite);
-  }
+  // Check if masking is needed. If not, exit.
+  if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
+    return write;
+
+  // Compute the mask and mask the write Op.
+  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
+
+  SmallVector<OpFoldResult> destSizes =
+      tensor::getMixedSizes(builder, loc, dest);
+  SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
+                                      destSizes.end());
+
+  if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
+                              vecToStoreShape))
+    return write;
 
-  return write;
+  Value maskForWrite =
+      builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
+  return mlir::vector::maskOperation(builder, write, maskForWrite);
 }
 
 /// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
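A hedged sketch of the masked form assembled in the hunk above; shapes, constants and value names are illustrative assumptions. For a vector<8x16> store into a tensor<4x7> destination, the mask type follows the vector shape while the mask bounds come from the trailing destination sizes:

  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
  %c7 = arith.constant 7 : index
  %mask = vector.create_mask %c4, %c7 : vector<8x16xi1>
  %res = vector.mask %mask {
    vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]}
        : vector<8x16xf32>, tensor<4x7xf32>
  } : vector<8x16xi1> -> tensor<4x7xf32>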
@@ -1826,9 +1791,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0],
       transposeOp.getResult().getType().getElementType());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, transposeOp.getResult(), dest,
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes);
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1966,7 +1930,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
       shapeCastOp.getResult().getType().getElementType());
   Operation *write = createWriteOrMaskedWrite(
       rewriter, loc, shapeCastOp.getResult(), dest,
-      /*inputVecSizesForLeadingDims=*/writeVectorSizes,
       /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
   newResults.push_back(write->getResult(0));
   return success();
@@ -1999,9 +1962,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   // Create Xfer write Op
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, maskedRead, dest,
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes);
+  Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -3043,9 +3004,9 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   // Create write
   auto writeIndices =
       getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, read, sliceOp.getDest(), vecType.getShape(), writeIndices,
-      /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(),
+                               writeIndices, inputVectorSizes.empty());
 
   // 4. Finalize
   newResults.push_back(write->getResult(0));
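The insert_slice path above is the caller here that passes non-zero write indices, which is where isMaskTriviallyFoldable helps: when the offsets plus the vector extents are statically known to fit the destination, the mask is skipped and a plain in-bounds write is emitted. A sketch with assumed offsets and shapes (not from the patch):

  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  %res = vector.transfer_write %vec, %dest[%c2, %c3] {in_bounds = [true, true]}
      : vector<4x4xf32>, tensor<16x16xf32>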