The Tile and Fuse Pass
The tile and fuse pass tiles operations that would otherwise not fit in the available memory. When an operation is tiled, for efficiency we also tile and fuse into the same tiling loop the operations it uses directly, the operations those operations use, and so on. The operation that triggers the tiling is called the consumer, and the operations it uses are called producers.
At the core of the pass is the function
scf::tileConsumerAndFuseProducersUsingSCF. The function takes an operation
that has a TilingInterface implementation (most of the linalg dialect, and
some tensor operations) to be tiled, and an options object that controls the
function’s execution. After tiling the operation, the function recursively fuses
producers, as long as they have a TilingInterface implementation and the
control function allows it.
Since the scf tiling function fuses producers, we should only tile an operation
when we are sure that none of its consumers needs to be tiled. This suggests a
reverse order traversal. Hence, the tile and fuse pass (T&F) iterates over all TilingInterface
operations in the function in reverse order of appearance. This guarantees that
when we tile an operation, we have already considered all its users. An
operation is tiled if it hasn’t already been fused, and it does not fit in
memory.
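The traversal can be illustrated with a toy model in Python. This is only a sketch under stated assumptions, not the pass's implementation: the Op class, its fits flag (standing in for the memory footprint check), its tileable flag (standing in for a TilingInterface implementation), and the tile_and_fuse function are all hypothetical. The real pass delegates fusion to scf::tileConsumerAndFuseProducersUsingSCF and its control function.

```python
from dataclasses import dataclass, field


@dataclass
class Op:
    name: str
    operands: list = field(default_factory=list)  # names of producer ops
    tileable: bool = True  # stands in for "implements TilingInterface"
    fits: bool = True      # stands in for the memory footprint check


def tile_and_fuse(ops):
    """Visit ops in reverse order; tile those that do not fit and fuse producers."""
    by_name = {op.name: op for op in ops}
    fused = set()  # ops already pulled into some tiling loop
    loops = []     # one (consumer, [fused producers]) entry per tiling loop

    for op in reversed(ops):
        # Skip ops that cannot be tiled, were already fused, or fit as they are.
        if not op.tileable or op.name in fused or op.fits:
            continue
        # Collect tileable producers transitively (toy stand-in for the scf helper).
        group, stack = [], list(op.operands)
        while stack:
            producer = by_name[stack.pop()]
            if producer.tileable and producer.name not in group:
                group.append(producer.name)
                stack.extend(producer.operands)
        fused.update(group)
        loops.append((op.name, group))
    return loops


ops = [
    Op("fill"),
    Op("conv", operands=["fill"], fits=False),
    Op("add", operands=["conv"], fits=False),
    Op("relu", operands=["add"]),
]
loops = tile_and_fuse(ops)
print(loops)
```

In this example add is visited before conv, so conv is fused into the tiling loop of add and is not tiled a second time, even though it does not fit in memory on its own.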
Fuse groups
Later in the pipeline, the passes that do the lowering from linalg to TorqHL use
a few rewrite patterns to do their transformations. In order for those passes to
work, T&F has to make sure that when it tiles an operation that is part
of a pattern, all the other operations that belong to the same pattern are fused
in the same tiling loop (note that this implies all those operations must
implement TilingInterface). To achieve that, the relevant rewrite patterns are
executed in two modes:
The original rewrite mode (performs the transformation); and
A marking mode (performs no rewrites) that places an attribute (torq-fuse-group) on the operations that make up the pattern.
The MarkPatternsForSuperTilingPass pass executes the relevant patterns in the
marking mode before the T&F pass.
In most cases a rewrite pattern starts from a principal operation (e.g. a
convolution), walks forward to some output value, and walks backwards to some
set of input values. In the rewrite mode all the operations between those values
are replaced with their TorqHL equivalent. In the marking mode, the function
markFuseGroupBackward is used to mark those operations as belonging to a fuse
group.
In principle, an operation can belong to multiple patterns. Hence, the attribute
we use for marking the groups is an array attribute. Each group is identified by
a UID that is assigned to the principal operation (the one the pattern matching
starts from) at the beginning of the pass. This UID is recorded by the
torq-fuse-group-id attribute.
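The marking mode can be sketched as a backward walk over a toy graph in Python. The function below borrows the name of markFuseGroupBackward, but its signature and the dictionary-based graph are invented for illustration. Each operation is identified with its single result, which is a simplifying assumption.

```python
def mark_fuse_group_backward(producers, attrs, output, inputs, group_id):
    """Append group_id to every op reachable backward from `output`,
    stopping at the values listed in `inputs`."""
    stack, seen = [output], set()
    while stack:
        name = stack.pop()
        if name in seen or name in inputs:
            continue
        seen.add(name)
        # The attribute is a list because an op may belong to several groups.
        groups = attrs.setdefault(name, {}).setdefault("torq-fuse-group", [])
        if group_id not in groups:
            groups.append(group_id)
        stack.extend(producers.get(name, []))
    return attrs


# Toy graph: op name -> names of the values it uses.
producers = {
    "conv": ["input", "weights"],
    "bias_add": ["conv", "bias"],
    "clamp": ["bias_add"],
}
# The principal op of the first pattern already carries its UID.
attrs = {"conv": {"torq-fuse-group-id": 7}}

# First pattern: everything between the pattern inputs and the clamp output.
mark_fuse_group_backward(
    producers, attrs, output="clamp",
    inputs={"input", "weights", "bias"}, group_id=7,
)
# A second, hypothetical pattern that covers only the clamp.
mark_fuse_group_backward(
    producers, attrs, output="clamp", inputs={"bias_add"}, group_id=9,
)
print(attrs)
```

After both calls, clamp carries two group ids while conv and bias_add carry one, which is why an array attribute is needed.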
Memory Footprint Check
Instead of statically approximating the memory footprint, the pass uses a precise approach by executing the actual address assignment pipeline on a cloned subset of the IR.
Extraction: The function extractOpsForMemoryCheck creates a temporary ModuleOp containing the operation to be checked (along with any fused producers).
Simulation: The temporary module is passed to checkModuleFitsInMemory, which runs the AssignLramAddresses pipeline.
Verification: If the address assignment succeeds without errors (e.g., no memory overflow), the configuration is deemed valid.
This ensures that the memory footprint check takes into account actual memory layouts and optimizations performed by later passes.
Finding Optimal Tile Sizes
The pass determines optimal tile sizes for all tileable dimensions.
Tiling Order
The consumer operation determines which dimensions are tiled and in what order.
If the consumer is a member of a pattern fuse group, the principal operation of
the fuse group determines which dimensions are tiled and in what order. This is
abstracted by the getTilingIterDomainsOrder function.
Tiling Search (fitTileToMemory)
Finding the best tile size involves a two-phase search:
Shrink Pass: The pass starts with full iteration domain sizes. If the configuration exceeds memory constraints, it iteratively reduces tile sizes along the target dimensions to 1 until the simulated module fits in memory.
Grow-Back Pass: Starting from the reduced sizes, the pass attempts to maximize the dimensions using a binary search (searchTileSizeForDim) to find the largest factor that still fits in memory.
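The two-phase search can be sketched in Python as follows. This is a minimal model, not the actual fitTileToMemory: the fits callback stands in for the simulated memory check, the footprint is assumed to grow monotonically with tile size, the grow-back order (last shrunk dimension first) is an assumption, and the binary search considers every integer size, whereas the real searchTileSizeForDim may restrict its candidates.

```python
from math import prod


def fit_tile_to_memory(domain, order, fits):
    """Return tile sizes for `domain`, or None if no tile fits."""
    tile = list(domain)

    # Shrink pass: collapse dimensions to 1, in tiling order, until the tile fits.
    shrunk = []
    for dim in order:
        if fits(tile):
            break
        tile[dim] = 1
        shrunk.append(dim)
    if not fits(tile):
        return None

    # Grow-back pass: binary-search the largest size that still fits.
    for dim in reversed(shrunk):
        lo, hi = 1, domain[dim]  # invariant: size `lo` fits
        while lo < hi:
            mid = (lo + hi + 1) // 2
            tile[dim] = mid
            if fits(tile):
                lo = mid
            else:
                hi = mid - 1
        tile[dim] = lo
    return tile


LIMIT = 10000  # hypothetical memory budget, in elements


def fits(tile):
    # Toy footprint model: the tile's element count must not exceed LIMIT.
    return prod(tile) <= LIMIT


tile = fit_tile_to_memory([8, 64, 64], order=[0, 1, 2], fits=fits)
print(tile)
```

Here the shrink pass reduces only the first dimension to 1, and the grow-back pass then raises it to 2, the largest size for which the toy footprint stays within the budget.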
Producers Fuse Mode
The command line option torq-tile-and-fuse-producers-fuse-mode controls how
aggressively producers are fused into the consumer’s tiling loop. It supports
the following modes:
max-producers (Default): Prefers fusing more producers over maximizing tile size. The tile size is determined by the biggest size that can accommodate all producers.
max-size: Prefers a larger tile size over fusing more producers. The tile size is set to maximize the consumer’s footprint in memory, and producers are fused only if they fit within that tile size.
only-patterns: Fuses producers only when strictly required to preserve pattern fuse groups.
no-fuse: Disables producer fusion entirely.
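The trade-off between the modes can be illustrated with a one-dimensional toy model in Python. Everything here is an assumption made for illustration: the plan_fusion function, the linear footprint model (cost per element times tile size), and the choice to treat pattern producers as mandatory in max-size mode. It is one possible reading of the mode descriptions, not the pass's actual logic.

```python
def plan_fusion(mode, consumer_cost, producers, full_size, limit):
    """Return (tile_size, fused_producer_names) for a 1-D toy model.

    `producers` is a list of (name, cost_per_element, in_pattern) tuples.
    """
    def largest_tile(cost_per_element):
        return min(full_size, limit // cost_per_element)

    required = [p for p in producers if p[2]]      # needed by a fuse group
    optional = [p for p in producers if not p[2]]

    if mode == "no-fuse":
        return largest_tile(consumer_cost), []
    if mode == "only-patterns":
        cost = consumer_cost + sum(p[1] for p in required)
        return largest_tile(cost), [p[0] for p in required]
    if mode == "max-producers":
        cost = consumer_cost + sum(p[1] for p in producers)
        return largest_tile(cost), [p[0] for p in producers]
    if mode == "max-size":
        base = consumer_cost + sum(p[1] for p in required)
        size = largest_tile(base)
        used, fused = base * size, [p[0] for p in required]
        for name, cost, _ in optional:
            # Fuse an optional producer only if it fits at the chosen size.
            if used + cost * size <= limit:
                used += cost * size
                fused.append(name)
        return size, fused
    raise ValueError(f"unknown mode: {mode}")


producers = [("conv", 4, True), ("pad", 8, False)]
plans = {
    mode: plan_fusion(mode, 4, producers, full_size=1024, limit=4096)
    for mode in ("max-producers", "max-size", "only-patterns", "no-fuse")
}
print(plans)
```

In this example max-producers ends up with a tile of 256 and both producers fused, while max-size keeps a tile of 512 and leaves pad outside the loop.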
Loop Reduction for Memory Simulation
During the memory footprint check, the pass clones the tiled IR and runs a subset of the lowering pipeline to determine whether the tiles fit in LRAM. As we are only concerned with the memory usage of the IR at this point, we do not need to compile every tile when all tiles have the same memory requirements. Therefore, to reduce compilation time, after the peeling pass, once the tiling loops no longer have any dynamic shapes, we replace each of the remaining loops with its first iteration.
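The loop replacement can be pictured with a small Python model of nested loops. The IR encoding used here (strings for operations, dictionaries for loops) and the keep_first_iteration helper are hypothetical. The sketch only shows the idea of substituting the lower bound for the induction variable and inlining the body once.

```python
def substitute(op, iv, value):
    """Replace induction variable `iv` with `value` in an op or a nested loop."""
    if isinstance(op, str):
        return op.replace(iv, str(value))
    return {**op, "body": [substitute(o, iv, value) for o in op["body"]]}


def keep_first_iteration(ops):
    """Replace every loop in `ops` with the body of its first iteration."""
    flat = []
    for op in ops:
        if isinstance(op, dict):  # a loop: {"iv", "lb", "ub", "step", "body"}
            if op["lb"] < op["ub"]:  # the loop runs at least once
                body = [substitute(o, op["iv"], op["lb"]) for o in op["body"]]
                flat.extend(keep_first_iteration(body))
        else:
            flat.append(op)
    return flat


ir = [
    "alloc",
    {
        "iv": "%row", "lb": 0, "ub": 64, "step": 16,
        "body": [
            "slice in[%row]",
            {
                "iv": "%col", "lb": 0, "ub": 32, "step": 8,
                "body": ["conv tile[%row, %col]"],
            },
        ],
    },
    "dealloc",
]
reduced = keep_first_iteration(ir)
print(reduced)
```

The reduced IR contains a single copy of the tile body, so the address assignment runs on one tile instead of sixteen.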