The Super Tiling Pass

The super tiling pass tiles operations that would otherwise not fit in the available memory. For efficiency, when an operation is tiled we also tile, where possible and in the same tiling loop, the operations that directly feed it, the operations that directly feed those, and so on. The feeding operations are called producers, and tiling them in the same loop as the original operation is called fusion.

At the core of the pass is the function scf::tileConsumerAndFuseProducersUsingSCF. The function takes the operation to be tiled, which must implement mlir::TilingInterface, and an options object that controls the function’s execution. After tiling the operation, the function recursively fuses its producers, as long as they also implement mlir::TilingInterface.

Since the scf tiling function fuses producers, we should only tile an operation when we are sure that none of its consumers needs to be tiled. This calls for a post-order traversal. Super tiling does a DFS over the operations’ data-flow DAG, following edges from producers to consumers, and processes the operations in post-order. As a result, every operation is processed after all of its consumers. An operation is tiled if it has not already been fused and it does not fit in memory (more on that later).
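The traversal can be sketched as two phases: first collect the operations in post-order, then process them. This is a minimal Python model, not the pass itself. Operations are plain strings, and users maps each operation to the operations that consume its results. The callbacks fits_in_memory and tile_and_fuse are hypothetical. They stand in for the footprint check and for scf::tileConsumerAndFuseProducersUsingSCF, and tile_and_fuse returns the set of producers that got fused.

```python
from typing import Callable, Dict, Iterable, List, Set


def collect_post_order(users: Dict[str, List[str]], roots: Iterable[str]) -> List[str]:
    """DFS over producer -> consumer edges; each op is emitted after all its consumers."""
    order: List[str] = []
    visited: Set[str] = set()

    def visit(op: str) -> None:
        if op in visited:
            return
        visited.add(op)
        for consumer in users.get(op, []):
            visit(consumer)
        order.append(op)  # post-order: all (transitive) consumers are already in `order`

    for root in roots:
        visit(root)
    return order


def super_tile(
    order: List[str],
    fits_in_memory: Callable[[str], bool],
    tile_and_fuse: Callable[[str], Set[str]],
) -> List[str]:
    """Hypothetical model of the processing phase; returns the ops that were tiled."""
    fused: Set[str] = set()
    tiled: List[str] = []
    for op in order:
        if op in fused or fits_in_memory(op):
            continue  # already inside some tiling loop, or small enough as is
        tiled.append(op)
        fused |= tile_and_fuse(op)  # producers pulled into op's tiling loop
    return tiled


# Example: x feeds y and z, and y also feeds z.
users = {"x": ["y", "z"], "y": ["z"]}
order = collect_post_order(users, roots=["x"])
print(order)  # ['z', 'y', 'x']
tiled = super_tile(order, fits_in_memory=lambda op: False,
                   tile_and_fuse=lambda op: {"x", "y"} if op == "z" else set())
print(tiled)  # ['z']
```

Post-order over producer-to-consumer edges emits an operation only after all of its consumers. A producer is therefore always reached after any consumer that may have fused it, and the fused set lets it be skipped.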

Attention

TODO Currently we first collect all the TilingInterface operations in post-order, and then we process them. We can’t do the processing inline, as we modify the graph while walking it. I think MLIR has pattern drivers, such as mlir::walkAndApplyPatterns, that can handle the walk for us, but I’m not sure they walk in the right order. The documentation does say the walk is post-order, but it does not say over which graph (data-flow?).

Attention

TODO What if an operation feeds multiple consumers? Once an operation has been fused, it is no longer considered for tiling.

  [a]
  / \
[b] [c]

If a and b implement mlir::TilingInterface and both need to be tiled to fit in memory, we will tile b first and fuse a into it. Operation a still needs to be tiled in order to feed c, but currently we will not tile a again, which causes the compiler to fail later due to memory overflow. Note that if c also implements mlir::TilingInterface and is too big to fit in memory, we will tile c and fuse a into it as well, avoiding the overflow. The problem only arises when a feeds an operation that is not going to be tiled.
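The scenario above can be reproduced with a small Python simulation. It is a hypothetical model that makes three assumptions. First, tiling an operation fuses all of its transitive producers that implement mlir::TilingInterface (the set tileable) and stops at producers that do not. Second, operations are visited consumers-first. Third, an operation is skipped once it has been fused. The helper names fusable_producers and simulate are illustrative.

```python
from typing import Dict, List, Sequence, Set


def fusable_producers(op: str, producers: Dict[str, List[str]], tileable: Set[str]) -> Set[str]:
    """Transitive producers a tile-and-fuse call could pull in (stops at non-tileable ops)."""
    result: Set[str] = set()
    stack = list(producers.get(op, []))
    while stack:
        p = stack.pop()
        if p in result or p not in tileable:
            continue
        result.add(p)
        stack.extend(producers.get(p, []))
    return result


def simulate(order: Sequence[str], producers: Dict[str, List[str]],
             tileable: Set[str], too_big: Set[str]) -> List[str]:
    """Return the ops that get tiled, skipping ops that were already fused."""
    fused: Set[str] = set()
    tiled: List[str] = []
    for op in order:
        if op not in tileable or op in fused or op not in too_big:
            continue
        tiled.append(op)
        fused |= fusable_producers(op, producers, tileable)
    return tiled


producers = {"b": ["a"], "c": ["a"]}  # a feeds both b and c
order = ["b", "c", "a"]               # consumers before producers
print(simulate(order, producers, tileable={"a", "b"}, too_big={"a", "b"}))  # ['b']
print(simulate(order, producers, tileable={"a", "b", "c"}, too_big={"a", "b", "c"}))  # ['b', 'c']
```

In the first run, a is only fused into b's loop, so the untiled a that still feeds c is never tiled. In the second run, c is tiled too, and a is fused into c's loop as well.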

Fuse groups

Later in the pipeline, the passes that lower linalg to TorqHL use a few rewrite patterns to do their transformations. For those passes to work, super tiling has to make sure that when it tiles an operation that is part of a pattern, all the other operations that belong to the same pattern are fused into the same tiling loop (note that this implies all those operations must implement mlir::TilingInterface). To achieve that, we have modified the relevant rewrite patterns to operate in two modes: the original rewrite mode, and a new marking mode. In marking mode no rewrites are done, except for placing an attribute (torq-fuse-group) on the operations that make up the pattern. The MarkPatternsForSuperTilingPass pass executes the patterns in marking mode before the super tiling pass runs.

In most cases, a rewrite pattern starts from a principal operation (e.g. a convolution), walks forward to some output value, and walks backwards to some set of input values. In rewrite mode, all the operations between those values are replaced by their TorqHL counterpart. In marking mode, we use the function markFuseGroupBackward to mark those operations as belonging to a fuse group.
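The backward marking step can be modeled in Python. This sketch only guesses at the shape of markFuseGroupBackward, since its real signature is not given here. The function mark_fuse_group_backward, the Op class, and its fuse_groups field are all hypothetical. The walk starts at the operation producing the pattern's output value and stops at the operations whose results are the pattern's input values.

```python
from dataclasses import dataclass, field
from typing import List, Set


@dataclass
class Op:
    name: str
    operands: List["Op"] = field(default_factory=list)    # producers of this op's operands
    fuse_groups: List[int] = field(default_factory=list)  # models the torq-fuse-group attribute


def mark_fuse_group_backward(output: Op, inputs: Set[str], group_id: int) -> List[Op]:
    """Hypothetical: mark every op between the pattern's inputs and its output.

    `inputs` holds the names of the ops producing the pattern's input values;
    the walk stops there, so those ops are not marked.
    """
    marked: List[Op] = []
    seen: Set[str] = set()
    stack = [output]
    while stack:
        op = stack.pop()
        if op.name in inputs or op.name in seen:
            continue
        seen.add(op.name)
        op.fuse_groups.append(group_id)
        marked.append(op)
        stack.extend(op.operands)  # keep walking towards the inputs
    return marked


# Hypothetical pattern rooted at a convolution: conv -> add(bias) -> clamp.
x, w, b = Op("x"), Op("w"), Op("b")
conv = Op("conv", [x, w])
add = Op("add", [conv, b])
clamp = Op("clamp", [add])
marked = mark_fuse_group_backward(clamp, inputs={"x", "w", "b"}, group_id=0)
print(sorted(op.name for op in marked))  # ['add', 'clamp', 'conv']
```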

In principle, an operation can belong to multiple patterns, so the attribute we use for marking the groups is an array attribute. Each group is identified by a UID that is assigned, at the beginning of the pass, to the principal operation (the one the pattern matching starts from). This UID is recorded in the torq-fuse-group-id attribute.
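The group bookkeeping can be modeled in Python, assuming the attributes of each operation are stored in a plain dict. Each principal operation receives a unique id in torq-fuse-group-id. Marking appends that id to the torq-fuse-group list, so an operation shared by two patterns ends up carrying both ids. The helper names assign_group_ids and add_to_group are hypothetical.

```python
import itertools
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class Op:
    name: str
    attrs: Dict[str, Any] = field(default_factory=dict)


def assign_group_ids(principals: List[Op]) -> None:
    """At the start of the pass, give each principal op a unique group id."""
    uids = itertools.count()
    for op in principals:
        op.attrs["torq-fuse-group-id"] = next(uids)


def add_to_group(op: Op, group_id: int) -> None:
    """Record membership in an array-valued attribute, so an op can join several groups."""
    groups = op.attrs.setdefault("torq-fuse-group", [])
    if group_id not in groups:
        groups.append(group_id)


conv1, conv2, shared = Op("conv1"), Op("conv2"), Op("shared")
assign_group_ids([conv1, conv2])
for principal in (conv1, conv2):
    gid = principal.attrs["torq-fuse-group-id"]
    add_to_group(principal, gid)
    add_to_group(shared, gid)  # `shared` is matched by both patterns
print(shared.attrs["torq-fuse-group"])  # [0, 1]
```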

Memory footprint check

Note that there’s an inherent problem in approximating the memory footprint at this early stage of the pipeline. The graph will be optimized in later stages, which leads to a smaller footprint, so the estimates made here tend to be pessimistic.

We approximate the memory footprint of an operation in the function getOperationDataSize. For operations that do not belong to a fuse group, we generally sum the data sizes of all the tensor operands (ultimately using mlir::syna::getShapeTypeDataSize). This is done in an mlir::TypeSwitch, which allows the computation to be specialized for different operations.
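The default case can be sketched in Python, with a dict-based dispatch standing in for the mlir::TypeSwitch. Everything here is a simplified model, and only statically shaped tensors are handled. TensorType, shape_type_data_size, and the example.first_operand_only specialization are hypothetical. shape_type_data_size stands in for mlir::syna::getShapeTypeDataSize, and its handling of sub-byte element types is not modeled.

```python
from dataclasses import dataclass
from math import prod
from typing import Callable, Dict, List, Optional, Tuple


@dataclass
class TensorType:
    shape: Tuple[int, ...]  # static shape only
    element_bytes: int


@dataclass
class Operation:
    kind: str
    operand_types: List[Optional[TensorType]]  # None marks a non-tensor operand


def shape_type_data_size(t: TensorType) -> int:
    """Bytes occupied by a statically shaped tensor."""
    return prod(t.shape) * t.element_bytes


def sum_tensor_operands(op: Operation) -> int:
    """Default estimate: total size of all tensor operands."""
    return sum(shape_type_data_size(t) for t in op.operand_types if t is not None)


# Per-operation overrides, playing the role of the cases in a TypeSwitch.
SPECIALIZATIONS: Dict[str, Callable[[Operation], int]] = {
    # Hypothetical override that only counts the first operand.
    "example.first_operand_only": lambda op: shape_type_data_size(op.operand_types[0]),
}


def get_operation_data_size(op: Operation) -> int:
    return SPECIALIZATIONS.get(op.kind, sum_tensor_operands)(op)


conv = Operation("example.conv", [TensorType((1, 16, 32, 32), 1),
                                  TensorType((16, 16, 3, 3), 1), None])
print(get_operation_data_size(conv))  # 18688
```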

For operations that belong to a fuse group, we return a non-zero value only when the operation has no consumers in the same fuse group, and 0 otherwise. This ensures that a fuse group operation is tiled only when it is the bottom-most operation of the group, so that all the other operations of the group are fused into its tiling loop. The value we return for that operation approximates the footprint of the TorqHL operation that will replace the whole group. To compute it, we walk backwards to find the inputs of the fuse group and sum their data sizes.
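The fuse group rule can be sketched in Python on a toy graph. This is a hypothetical model: Node, result_size, and fuse_group_data_size are illustrative names, and graph inputs are represented as nodes with no groups. An operand counts as a group input when its producer shares no fuse group with the operation being measured.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Set


@dataclass
class Node:
    name: str
    groups: Set[int] = field(default_factory=set)      # models torq-fuse-group
    operands: List[str] = field(default_factory=list)  # names of producer nodes
    result_size: int = 0                               # bytes of the produced value


def fuse_group_data_size(name: str, graph: Dict[str, Node]) -> int:
    node = graph[name]
    users = [n for n in graph.values() if name in n.operands]
    if any(u.groups & node.groups for u in users):
        return 0  # a consumer in the same group will be tiled instead
    # Bottom-most op of the group: walk backwards and sum the group's inputs.
    inputs: Set[str] = set()
    seen: Set[str] = set()
    stack = [name]
    while stack:
        current = stack.pop()
        if current in seen:
            continue
        seen.add(current)
        for operand in graph[current].operands:
            if graph[operand].groups & node.groups:
                stack.append(operand)  # still inside the group
            else:
                inputs.add(operand)  # value enters the group from outside
    return sum(graph[i].result_size for i in inputs)


nodes = [
    Node("x", result_size=1000), Node("w", result_size=200), Node("b", result_size=50),
    Node("conv", {0}, ["x", "w"], 4000), Node("add", {0}, ["conv", "b"], 4000),
    Node("clamp", {0}, ["add"], 4000), Node("next", set(), ["clamp"], 4000),
]
graph = {n.name: n for n in nodes}
print(fuse_group_data_size("clamp", graph))  # 1250
print(fuse_group_data_size("conv", graph))   # 0
```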

Finding a good tiling factor

We currently always tile the second dimension of an operation, unless the operation has only one dimension, in which case we tile that dimension.

The initial tiling factor is computed by dividing the approximate memory footprint of the operation by the available memory. We assume that operations belonging to a fuse group will be executed on the NSS, so for them we get the available memory from TorqHw::get().getAvailableMemoryForTiling(). For all other operations we use a fixed 10k.
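Choosing the tiled dimension and the initial factor can be sketched as below. Several details are assumptions:

- The factor is rounded up with integer ceiling division.
- The 10k limit is treated as a plain number with an unspecified unit.
- nss_available stands in for TorqHw::get().getAvailableMemoryForTiling().
- The factor becomes a tile size by dividing the extent of the tiled dimension.

The convention that a tile size of 0 leaves a dimension untiled is the one used by MLIR's SCF tiling options.

```python
from typing import List

NON_GROUP_AVAILABLE_MEMORY = 10_000  # the "10k" from the text; unit assumed


def ceil_div(a: int, b: int) -> int:
    return -(-a // b)


def available_memory(in_fuse_group: bool, nss_available: int) -> int:
    """Fuse group ops are assumed to run on the NSS; everything else gets 10k."""
    return nss_available if in_fuse_group else NON_GROUP_AVAILABLE_MEMORY


def initial_tiling_factor(footprint: int, available: int) -> int:
    return max(1, ceil_div(footprint, available))


def tiled_dimension(rank: int) -> int:
    """Tile the second dimension, or the only one for rank-1 operations."""
    if rank < 1:
        raise ValueError("nothing to tile")
    return 1 if rank > 1 else 0


def tile_sizes(extents: List[int], factor: int) -> List[int]:
    """Hypothetical mapping from a tiling factor to per-dimension tile sizes."""
    sizes = [0] * len(extents)  # 0 = leave this dimension untiled
    dim = tiled_dimension(len(extents))
    sizes[dim] = ceil_div(extents[dim], factor)
    return sizes


factor = initial_tiling_factor(18688, available_memory(False, nss_available=0))
print(factor, tile_sizes([1, 16, 32, 32], factor))  # 2 [0, 8, 0, 0]
```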

After calling scf::tileConsumerAndFuseProducersUsingSCF, we know which operations were fused into the tiling loop. We compute a memory footprint for the tiled operation and for each of the fused operations, take the maximum, divide it by the tiling factor, and use the result as the approximate tile footprint. If this approximation fits in the available memory, we are done with the operation. Otherwise, we call scf::tileConsumerAndFuseProducersUsingSCF again, this time with a factor computed from the new (maximum) memory footprint.
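The refinement loop can be sketched as follows. It is a hypothetical model with three assumptions. First, tile_and_fuse(factor) stands in for a call to scf::tileConsumerAndFuseProducersUsingSCF and returns the footprints of the original (untiled) operations that ended up in the loop, the tiled operation included. Second, divisions round up. Third, max_attempts is a safety bound that the text does not mention.

```python
from typing import Callable, List, Tuple


def ceil_div(a: int, b: int) -> int:
    return -(-a // b)


def tile_until_fits(
    op_footprint: int,
    available: int,
    tile_and_fuse: Callable[[int], List[int]],
    max_attempts: int = 4,
) -> Tuple[int, int]:
    """Return (factor, estimated tile footprint) once the estimate fits."""
    factor = max(1, ceil_div(op_footprint, available))
    for _ in range(max_attempts):
        footprints = tile_and_fuse(factor)  # tiled op + every fused producer
        largest = max(footprints)
        tile_footprint = ceil_div(largest, factor)
        if tile_footprint <= available:
            return factor, tile_footprint
        # Retile with a factor derived from the largest footprint in the loop.
        factor = max(factor + 1, ceil_div(largest, available))
    raise RuntimeError("no tiling factor found within max_attempts")


calls: List[int] = []


def fake_tile_and_fuse(factor: int) -> List[int]:
    calls.append(factor)
    return [18688, 30000]  # a fused producer is bigger than the tiled op


print(tile_until_fits(18688, 10_000, fake_tile_and_fuse), calls)  # (3, 10000) [2, 3]
```

A factor of ceil(largest / available) always yields a fitting estimate, unless the set of fused operations changes between calls. In this model the loop therefore usually ends after at most one retry.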

Note that if we tried to approximate the memory footprint of the tiled operation by inspecting the resulting loop, we would not be able to leverage the fuse group information. Hence, we collect the footprints of the original operations and take their maximum. We divide it by the factor to estimate the tile footprint, or by the available memory to compute a new factor.

Attention

TODO Taking the maximum over all the operations (tiled and fused) assumes they are applied in sequence. This should be fixed to take the real hierarchy into account.