8.1.1 Modularity: Adding IR Node Types
We observe that trait Expressions provides no concrete definition classes. Providing meaningful data types is the responsibility of other traits that implement the interfaces defined previously (Base and its descendants).
Trait BaseExp forms the root of the implementation hierarchy and installs atomic expressions as the representation of staged values by defining Rep[T] = Exp[T]:
trait BaseExp extends Base with Expressions {
  type Rep[T] = Exp[T]
}
For each interface trait, there is one corresponding core implementation trait. Below, we show traits ArithExp and IfThenElseExp as running examples. Both traits define one definition class for each operation defined by Arith and IfThenElse, respectively, and implement the corresponding interface methods to create instances of those classes.
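trait ArithExp extends BaseExp with Arith {
  case class Plus(x: Exp[Double], y: Exp[Double]) extends Def[Double]
  case class Times(x: Exp[Double], y: Exp[Double]) extends Def[Double]
  // Plus and Times are Defs; the declared result type Rep[Double] makes
  // the implicit toAtom conversion turn them into symbols right here
  def infix_+(x: Rep[Double], y: Rep[Double]): Rep[Double] = Plus(x, y)
  def infix_*(x: Rep[Double], y: Rep[Double]): Rep[Double] = Times(x, y)
  ...
}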
trait IfThenElseExp extends BaseExp with IfThenElse {
  case class IfThenElse[T](c: Exp[Boolean], a: Block[T], b: Block[T]) extends Def[T]
  def __ifThenElse[T](c: Rep[Boolean], a: => Rep[T], b: => Rep[T]): Rep[T] =
    IfThenElse(c, reifyBlock(a), reifyBlock(b))
}
The framework ensures that code containing staging operations is always executed within the dynamic scope of at least one invocation of reifyBlock. This method takes as a call-by-name argument the present-stage expression that will compute the staged block result, and returns a block object. Block objects can be part of definitions, e.g. for loops or conditionals.
Since all operations in interface traits such as Arith return Rep types, equating Rep[T] and Exp[T] in trait BaseExp means that conversion to symbols already takes place within those methods. This fact is important because it establishes our correspondence between the evaluation order of the program generator and the evaluation order of the generated program: at the point where the generator calls toAtom, the composite definition is turned into an atomic value via reflectStm, i.e. its evaluation is recorded now and played back later in the same relative order with respect to other statements within the closest reifyBlock invocation.
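To make the ordering argument concrete, here is a minimal sketch of the reification mechanism. It is not the framework's actual code. It assumes a simplified, untyped IR restricted to Double values (Const, Sym, Plus, Times, Stm, Block), a mutable buffer standing in for the dynamic scope of reifyBlock, and hypothetical methods plus and times standing in for the interface methods of Arith.

```scala
import scala.collection.mutable.ListBuffer
import scala.language.implicitConversions

// Simplified, untyped IR (an assumption for this sketch).
sealed trait Exp
case class Const(x: Double) extends Exp
case class Sym(id: Int) extends Exp

sealed trait Def
case class Plus(x: Exp, y: Exp) extends Def
case class Times(x: Exp, y: Exp) extends Def

case class Stm(lhs: Sym, rhs: Def)
case class Block(stms: List[Stm], res: Exp)

object Staging {
  private var nextId = 0
  // Statements recorded by the innermost active reifyBlock call.
  private var scope: ListBuffer[Stm] = null

  // Record d as a statement now and return a fresh symbol standing for it.
  def reflectStm(d: Def): Sym = {
    require(scope != null, "staging operation outside reifyBlock")
    val s = Sym(nextId)
    nextId += 1
    scope += Stm(s, d)
    s
  }

  // Converting a definition to an atom is what triggers recording.
  implicit def toAtom(d: Def): Exp = reflectStm(d)

  // Run the call-by-name body in a fresh scope and package its statements.
  def reifyBlock(body: => Exp): Block = {
    val saved = scope
    scope = ListBuffer()
    try { val r = body; Block(scope.toList, r) }
    finally scope = saved
  }

  // Hypothetical stand-ins for Arith's methods: the declared result type Exp
  // forces toAtom to run inside the method, i.e. at generator evaluation time.
  def plus(x: Exp, y: Exp): Exp = Plus(x, y)
  def times(x: Exp, y: Exp): Exp = Times(x, y)
}
```

For example, with x = Const(2.0), reifyBlock(times(plus(x, x), x)) records the Plus statement as Sym(0) before the Times statement as Sym(1), because the argument plus(x, x) is evaluated before times is entered.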
8.2 Enabling Analysis and Transformation
8.2.1 Modularity: Adding Traversal Passes
All that is needed to define a generic in-order traversal is a way to access all blocks immediately contained in a definition:
def blocks(x: Any): List[Block[Any]]
For example, applying blocks to an IfThenElse node will return the then and else blocks. Since definitions are case classes, this method is easy to implement by using the Product interface that all case classes implement.
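One possible way to implement blocks via the Product interface is sketched below. Block, IfThenElse and Loop are simplified, hypothetical stand-ins for the framework's classes. The function recurses through the fields of any product (and through other collections) but stops at the first Block it reaches, so only immediately contained blocks are returned.

```scala
// Hypothetical, simplified stand-ins for the IR classes.
case class Block(stms: List[Any], res: Any)
case class IfThenElse(c: Any, a: Block, b: Block)
case class Loop(n: Int, body: Block)

// Collect the blocks immediately contained in x, using the Product
// interface that every case class implements. Blocks are not entered.
def blocks(x: Any): List[Block] = x match {
  case b: Block        => List(b)
  case p: Product      => p.productIterator.toList.flatMap(blocks)
  case xs: Iterable[_] => xs.toList.flatMap(blocks)
  case _               => Nil
}
```

Since a statement is itself a case class (a product), in this sketch the recursion also finds the blocks of a statement's right-hand side, so blocks can be applied to a whole statement.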
The basic structural in-order traversal is then defined like this:
trait ForwardTraversal {
  val IR: Expressions
  import IR._
  def traverseBlock[T](b: Block[T]): Unit = b.stms.foreach(traverseStm)
  def traverseStm[T](s: Stm[T]): Unit = blocks(s).foreach(traverseBlock)
}
Custom traversals can be implemented in a modular way by extending the ForwardTraversal trait:
trait MyTraversalBase extends ForwardTraversal {
  val IR: BaseExp
  import IR._
  override def traverseStm[T](s: Stm[T]) = s match {
    // custom base case or delegate to super
    case _ => super.traverseStm(s)
  }
}
trait MyTraversalArith extends MyTraversalBase {
  val IR: ArithExp
  import IR._
  override def traverseStm[T](s: Stm[T]) = s match {
    case Stm(_, Plus(x, y)) => ... // handle specific nodes (match on the statement's rhs)
    case _ => super.traverseStm(s)
  }
}
For each unit of functionality, such as Arith or IfThenElse, the traversal actions can be defined separately as MyTraversalArith and MyTraversalIfThenElse.
Finally, we can use our traversal as follows:
trait Prog extends Arith {
def main = ... // program code here
}
val impl = new Prog with ArithExp
val res = impl.reifyBlock(impl.main)
val inspect = new MyTraversalArith { val IR: impl.type = impl }
inspect.traverseBlock(res)
8.2.2 Solving the “Expression Problem”
In essence, traversals confront us with the classic “expression problem” of independently extending a data model with new data variants and new operations [152]. There are many solutions to this problem, but most of them are rather heavyweight. More lightweight implementations are possible in languages that support multi-methods, i.e. that dispatch method calls dynamically based on the actual types of all the arguments. We can achieve essentially the same using pattern matching and mixin composition, making use of the fact that composing traits is subject to linearization [100]. We package each set of specific traversal rules into its own trait, e.g. MyTraversalArith, which inherits from MyTraversalBase and overrides traverseStm. When the arguments do not match the rewriting pattern, the overridden method invokes the “parent” implementation using super. When several such traits are combined, the super calls traverse the overridden method implementations according to the linearization order of their containing traits. The use of pattern matching and super calls is similar to earlier work on extensible algebraic data types with defaults [160], which supported linear extensions but not composition of independent extensions.
Implementing multi-methods in a statically typed setting usually poses three problems: separate type checking/compilation, ensuring non-ambiguity, and ensuring exhaustiveness. The described encoding supports separate type checking and compilation insofar as traits do. Ambiguity is ruled out by always following the linearization order and the first-match semantics of pattern matching. Exhaustiveness is ensured at the type level by requiring a default implementation, although no guarantee can be made that the default will not throw an exception at runtime. In the particular case of traversals, the default is always safe and simply continues the structural traversal.
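The following sketch illustrates the encoding with two independent extensions composed by mixin. The node classes and trait names (Node, Lit, Plus, Cond, TraversalBase, TraversalArith, TraversalCond, Combined) are hypothetical, and each per-node action merely logs a tag. Each trait handles its own case and delegates everything else to super, with a catch-all default at the base.

```scala
import scala.collection.mutable.ListBuffer

// Hypothetical node types contributed by two independent modules.
trait Node
case class Lit(v: Int) extends Node
case class Plus(x: Node, y: Node) extends Node          // "arith" module
case class Cond(c: Node, a: Node, b: Node) extends Node // "conditional" module

trait TraversalBase {
  val visited = ListBuffer[String]()
  // Catch-all default: always safe, so every node is handled somewhere.
  def traverse(n: Node): Unit = visited += "default"
}

trait TraversalArith extends TraversalBase {
  override def traverse(n: Node): Unit = n match {
    case Plus(x, y) => visited += "plus"; traverse(x); traverse(y)
    case _          => super.traverse(n) // not ours: defer to the next trait
  }
}

trait TraversalCond extends TraversalBase {
  override def traverse(n: Node): Unit = n match {
    case Cond(c, a, b) => visited += "cond"; traverse(c); traverse(a); traverse(b)
    case _             => super.traverse(n)
  }
}

// Composing the independent extensions; neither knows about the other.
object Combined extends TraversalArith with TraversalCond
```

For Combined, the linearization is Combined, TraversalCond, TraversalArith, TraversalBase, so a node falls through the extensions in that order until one pattern matches or the default is reached. Recursive calls go through this, so nested nodes are again dispatched from the top of the stack.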
8.2.3 Generating Code
Code generation is just a traversal pass that prints code. Compiling and executing code can use the same mechanism as described in Section 6.2.6.
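As an illustration of code generation as a traversal, the sketch below emits Scala-like source for a small straight-line block. The IR classes and the names Traversal, ScalaCodegen, quote and emitBlock are assumptions made for this example; a real generator would also need to handle nested blocks and types.

```scala
import scala.collection.mutable.ListBuffer

// Simplified, untyped IR (an assumption for this sketch).
sealed trait Exp
case class Const(x: Double) extends Exp
case class Sym(id: Int) extends Exp
sealed trait Def
case class Plus(x: Exp, y: Exp) extends Def
case class Times(x: Exp, y: Exp) extends Def
case class Stm(lhs: Sym, rhs: Def)
case class Block(stms: List[Stm], res: Exp)

// Generic in-order traversal over straight-line blocks.
trait Traversal {
  def traverseBlock(b: Block): Unit = b.stms.foreach(traverseStm)
  def traverseStm(s: Stm): Unit = ()
}

// Code generation: the per-statement action prints one line of code.
class ScalaCodegen extends Traversal {
  val out = ListBuffer[String]()
  def quote(e: Exp): String = e match {
    case Const(x) => x.toString
    case Sym(id)  => s"x$id"
  }
  override def traverseStm(s: Stm): Unit = s.rhs match {
    case Plus(a, b)  => out += s"val ${quote(s.lhs)} = ${quote(a)} + ${quote(b)}"
    case Times(a, b) => out += s"val ${quote(s.lhs)} = ${quote(a)} * ${quote(b)}"
  }
  def emitBlock(b: Block): String = {
    traverseBlock(b)
    out += quote(b.res) // the block's result expression comes last
    out.mkString("\n")
  }
}
```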
8.2.4 Modularity: Adding Transformations
Transformations work very similarly to traversals. One option is to traverse and transform an existing program more or less in place: rather than actually modifying data, we attach new Defs to existing Syms:
trait SimpleTransformer {
  val IR: Expressions
  import IR._
  // existing symbols are kept, so expressions map to themselves
  def transformExp[T](e: Exp[T]): Exp[T] = e
  def transformBlock[T](b: Block[T]): Block[T] =
    Block(b.stms.flatMap(transformStm), transformExp(b.res))
  def transformStm[T](s: Stm[T]): List[Stm[_]] =
    List(Stm(s.lhs, transformDef(s.rhs))) // preserve existing symbol s.lhs
  def transformDef[T](d: Def[T]): Def[T] = ... // default: use reflection
                                              // to map over case classes
}
An implementation is straightforward:
trait MySimpleTransformer extends SimpleTransformer {
  val IR: IfThenElseExp
  import IR._
  // override transformDef for each Def subclass
  override def transformDef[T](d: Def[T]): Def[T] = d match {
    case IfThenElse(c, a, b) =>
      IfThenElse(transformExp(c), transformBlock(a), transformBlock(b))
    case _ => super.transformDef(d)
  }
}
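The following self-contained sketch shows the in-place style with a hypothetical rewrite rule, StrengthReduce, that replaces x * 2.0 by x + x. It assumes a simplified, untyped IR, and unlike the text, the default transformDef is written out by hand instead of using reflection. Every statement keeps its original symbol, so transformExp is the identity.

```scala
// Simplified, untyped IR (an assumption for this sketch).
sealed trait Exp
case class Const(x: Double) extends Exp
case class Sym(id: Int) extends Exp
sealed trait Def
case class Plus(x: Exp, y: Exp) extends Def
case class Times(x: Exp, y: Exp) extends Def
case class IfThenElse(c: Exp, a: Block, b: Block) extends Def
case class Stm(lhs: Sym, rhs: Def)
case class Block(stms: List[Stm], res: Exp)

// In-place transformer: each statement keeps its symbol (lhs);
// only the right-hand side definition is replaced.
trait SimpleTransformer {
  def transformExp(e: Exp): Exp = e
  def transformBlock(b: Block): Block =
    Block(b.stms.flatMap(transformStm), transformExp(b.res))
  def transformStm(s: Stm): List[Stm] = List(Stm(s.lhs, transformDef(s.rhs)))
  // Default: rebuild each node structurally (by hand, not via reflection).
  def transformDef(d: Def): Def = d match {
    case Plus(x, y)          => Plus(transformExp(x), transformExp(y))
    case Times(x, y)         => Times(transformExp(x), transformExp(y))
    case IfThenElse(c, a, b) => IfThenElse(transformExp(c), transformBlock(a), transformBlock(b))
  }
}

// Hypothetical rewrite rule: strength-reduce x * 2.0 to x + x.
trait StrengthReduce extends SimpleTransformer {
  override def transformDef(d: Def): Def = d match {
    case Times(x, Const(2.0)) => Plus(transformExp(x), transformExp(x))
    case _                    => super.transformDef(d)
  }
}
```

Applied to a conditional whose branches compute Sym(1) = Sym(0) * 2.0, the transformer returns the same symbols with Sym(1) now defined as Sym(0) + Sym(0); the rule reaches the nested blocks through the default case for IfThenElse.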
8.2.5 Transformation by Iterated Staging
Another option, which is more principled and in line with the idea of making compiler transforms programmable through the use of staging, is to traverse the old program and create a new program. Effectively, we implement an IR interpreter that executes staging commands, which greatly simplifies the implementation of the transform and removes the need for low-level IR manipulation.
In the implementation, we create new symbols instead of reusing existing ones, so we need to maintain a substitution that maps old to new Syms. The core implementation is given below:
trait ForwardTransformer extends ForwardTraversal {
  val IR: Expressions
  import IR._
  var subst: Map[Exp[_], Exp[_]] = Map()
  def transformExp[T](s: Exp[T]): Exp[T] = ... // lookup s in subst
  def transformDef[T](d: Def[T]): Exp[T] // default
  def transformStm[T](s: Stm[T]): Exp[T] = {
    val e = transformDef(s.rhs); subst += (s.sym -> e); e
  }
  override def traverseStm[T](s: Stm[T]): Unit = {
    transformStm(s)
  }
  def reflectBlock[T](b: Block[T]): Exp[T] = withSubstScope {
    traverseBlock(b); transformExp(b.res)
  }
  def transformBlock[T](b: Block[T]): Block[T] = {
    reifyBlock(reflectBlock(b))
  }
}
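A simplified, runnable rendering of this scheme is sketched below. It assumes an untyped IR without nested blocks (so withSubstScope is omitted) and a hypothetical Staging class whose plus and times operations fold constants; re-staging through such operations shows how a rewrite can happen as a side effect of replaying the old program. New symbols start at 100 only to make them easy to distinguish from old ones.

```scala
import scala.collection.mutable.ListBuffer

// Simplified, untyped IR (an assumption for this sketch).
sealed trait Exp
case class Const(x: Double) extends Exp
case class Sym(id: Int) extends Exp
sealed trait Def
case class Plus(x: Exp, y: Exp) extends Def
case class Times(x: Exp, y: Exp) extends Def
case class Stm(lhs: Sym, rhs: Def)
case class Block(stms: List[Stm], res: Exp)

// Hypothetical staging core: fresh symbols plus a statement buffer.
class Staging {
  private var nextId = 100 // numbered apart from old symbols for readability
  private var scope: ListBuffer[Stm] = null
  def reflectStm(d: Def): Exp = { val s = Sym(nextId); nextId += 1; scope += Stm(s, d); s }
  def reifyBlock(body: => Exp): Block = {
    val saved = scope; scope = ListBuffer()
    try { val r = body; Block(scope.toList, r) } finally scope = saved
  }
  // Staging operations that fold constants (an illustrative assumption).
  def plus(x: Exp, y: Exp): Exp = (x, y) match {
    case (Const(a), Const(b)) => Const(a + b)
    case _                    => reflectStm(Plus(x, y))
  }
  def times(x: Exp, y: Exp): Exp = (x, y) match {
    case (Const(a), Const(b)) => Const(a * b)
    case _                    => reflectStm(Times(x, y))
  }
}

// Interprets the old IR by issuing staging commands on IR and records
// which new expression each old symbol maps to.
class ForwardTransformer(val IR: Staging) {
  var subst: Map[Exp, Exp] = Map()
  def transformExp(e: Exp): Exp = subst.getOrElse(e, e)
  def transformDef(d: Def): Exp = d match {
    case Plus(x, y)  => IR.plus(transformExp(x), transformExp(y))
    case Times(x, y) => IR.times(transformExp(x), transformExp(y))
  }
  def transformStm(s: Stm): Exp = { val e = transformDef(s.rhs); subst += (s.lhs -> e); e }
  def traverseBlock(b: Block): Unit = b.stms.foreach(s => transformStm(s))
  def reflectBlock(b: Block): Exp = { traverseBlock(b); transformExp(b.res) }
  def transformBlock(b: Block): Block = IR.reifyBlock(reflectBlock(b))
}
```

Transforming the old block [Sym(0) = 1.0 + 2.0; Sym(1) = Sym(0) * Sym(9)] yields a single statement Sym(100) = 3.0 * Sym(9): the old Sym(0) disappears because it folded to Const(3.0), and the old Sym(1) maps to the fresh Sym(100).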
Here is a simple identity transformer implementation for conditionals and array construction: