@@ -1605,63 +1605,49 @@ static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
 /// Creates an optionally masked TransferWriteOp
 ///
 /// Generates the following operation:
-///   %res = vector.transfer_write %vectorToStore into %dest
+///   %res = vector.transfer_write %vecToStore into %dest
 ///
-/// If the leading N dimensions of the vector to store do not match
-/// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
-/// masking is applied to ensure correctness:
+/// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
 ///
-///   %mask = vector.create_mask(%destShape) : %vectorToStoreShape
+///   %mask = vector.create_mask(%destShape) : %vecToStoreShape
 ///   %res = vector.mask %mask {
-///     vector.transfer_write %vectorToStore into %dest
+///     vector.transfer_write %vecToStore into %dest
 ///   }
 ///
-/// The mask shape is identical to `vectorToStore` (with the element type ==
+/// The mask shape is identical to `vecToStore` (with the element type ==
 /// i1), and the mask values are based on the shape of the `dest` tensor.
 ///
 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
 /// is used instead of masking:
 ///
-///   %write = vector.transfer_write %vectorToStore into %dest
+///   %write = vector.transfer_write %vecToStore into %dest
 ///       in_bounds_flags = (...)
 ///   %res = vector.transfer_write %input into %dest
 ///       {in_bounds = in_bounds_flags}
 ///
-/// `writeIndices` specifies the offsets to use. If empty, all indices are set
-/// to 0.
-///
-/// NOTE: When N < rank(vectorToStore), the missing vector sizes are taken from
-/// `valueToStore`.
-/// TODO: `inputVecSizesForLeadingDims` should not be required - these sizes are
-/// already provided in `vectorToStore`.
+/// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
+/// are set to 0.
 static Operation *
-createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
-                         Value dest,
-                         ArrayRef<int64_t> inputVecSizesForLeadingDims,
-                         SmallVector<Value> writeIndices = {},
+createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
+                         Value dest, SmallVector<Value> writeIndices = {},
                          bool useInBoundsInsteadOfMasking = false) {
 
   ShapedType destType = cast<ShapedType>(dest.getType());
   int64_t destRank = destType.getRank();
   auto destShape = destType.getShape();
 
-  VectorType vecToStoreType = cast<VectorType>(vectorToStore.getType());
+  VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
   int64_t vecToStoreRank = vecToStoreType.getRank();
   auto vecToStoreShape = vecToStoreType.getShape();
 
   // Compute the in_bounds attribute
   SmallVector<bool> inBoundsVal(vecToStoreRank, true);
   if (useInBoundsInsteadOfMasking) {
-    // In this case, assume that all the required vector sizes have been
-    // provided.
-    assert(inputVecSizesForLeadingDims.size() ==
-               static_cast<size_t>(vecToStoreType.getRank()) &&
-           "Insufficient number of input vector sizes!");
     // Update the inBounds attribute.
     // FIXME: This computation is too weak - it ignores the write indices.
     for (unsigned i = 0; i < vecToStoreRank; i++)
       inBoundsVal[i] =
-          (destShape[i] >= inputVecSizesForLeadingDims[i]) &&
+          (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
          !ShapedType::isDynamic(destShape[destRank - vecToStoreRank + i]);
   }
 
@@ -1677,7 +1663,7 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
   // Generate the xfer_write Op
   Operation *write =
       builder.create<vector::TransferWriteOp>(loc,
-                                              /*vector=*/vectorToStore,
+                                              /*vector=*/vecToStore,
                                               /*source=*/dest,
                                               /*indices=*/writeIndices,
                                               /*inBounds=*/inBoundsVal);
@@ -1686,46 +1672,25 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
   if (useInBoundsInsteadOfMasking)
     return write;
 
-  assert(llvm::none_of(
-             destShape.drop_front(inputVecSizesForLeadingDims.size()),
-             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
-         "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
-
-  // Check if masking is needed.
-  bool needMaskForWrite =
-      !llvm::equal(inputVecSizesForLeadingDims,
-                   destShape.take_front(destRank - vecToStoreRank +
-                                        inputVecSizesForLeadingDims.size()));
-
-  // If masking is needed, generate the mask and mask the operation.
-  if (needMaskForWrite) {
-    // Get the mask shape + type. Missing mask dimensions are taken from
-    // `vectorToStore`.
-    SmallVector<int64_t> writeMaskShape;
-    writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
-                          inputVecSizesForLeadingDims.end());
-    if (vecToStoreRank >
-        static_cast<int64_t>(inputVecSizesForLeadingDims.size()))
-      writeMaskShape.append(vecToStoreShape.begin() +
-                                inputVecSizesForLeadingDims.size(),
-                            vecToStoreShape.end());
-    auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
-
-    SmallVector<OpFoldResult> destSizes =
-        tensor::getMixedSizes(builder, loc, dest);
-    SmallVector<OpFoldResult> maskSizes(destSizes.end() - writeMaskShape.size(),
-                                        destSizes.end());
-
-    if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
-                                writeMaskShape))
-      return write;
-
-    Value maskForWrite = builder.createOrFold<vector::CreateMaskOp>(
-        loc, writeMaskType, maskSizes);
-    write = mlir::vector::maskOperation(builder, write, maskForWrite);
-  }
+  // Check if masking is needed. If not, exit.
+  if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
+    return write;
+
+  // Compute the mask and mask the write Op.
+  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
+
+  SmallVector<OpFoldResult> destSizes =
+      tensor::getMixedSizes(builder, loc, dest);
+  SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
+                                      destSizes.end());
+
+  if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
+                              vecToStoreShape))
+    return write;
 
-  return write;
+  Value maskForWrite =
+      builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
+  return mlir::vector::maskOperation(builder, write, maskForWrite);
 }
 
 /// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
@@ -1825,9 +1790,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0],
       transposeOp.getResult().getType().getElementType());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, transposeOp.getResult(), dest,
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes);
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1965,7 +1929,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
       shapeCastOp.getResult().getType().getElementType());
   Operation *write = createWriteOrMaskedWrite(
       rewriter, loc, shapeCastOp.getResult(), dest,
-      /*inputVecSizesForLeadingDims=*/writeVectorSizes,
       /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
   newResults.push_back(write->getResult(0));
   return success();
@@ -1998,9 +1961,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   // Create Xfer write Op
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, maskedRead, dest,
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes);
+  Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -3040,9 +3001,9 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   // Create write
   auto writeIndices =
       getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, read, sliceOp.getDest(), vecType.getShape(), writeIndices,
-      /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(),
+                               writeIndices, inputVectorSizes.empty());
 
   // 4. Finalize
   newResults.push_back(write->getResult(0));
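
Note: the masked path documented in the updated comment produces IR along the following lines. This is a hand-written sketch for illustration only; the value names and shapes (%vec, %dest, vector<8x16xf32>, tensor<?x?xf32>) are assumptions, not output copied from a test:

  // Mask sizes come from the destination tensor; the mask shape matches the
  // shape of the vector being stored.
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %d0 = tensor.dim %dest, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %dest, %c1 : tensor<?x?xf32>
  %mask = vector.create_mask %d0, %d1 : vector<8x16xi1>
  %res = vector.mask %mask {
    vector.transfer_write %vec, %dest[%c0, %c0]
        : vector<8x16xf32>, tensor<?x?xf32>
  } : vector<8x16xi1> -> tensor<?x?xf32>

With the simplified interface, the mask shape is taken directly from the vector being stored, and the vector.mask wrapper is omitted entirely when that shape already matches the trailing dimensions of the destination (for example, writing a vector<8x16xf32> into a static tensor<8x16xf32>).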