Sunday, August 03, 2008

Expression Templates Demystified: Part 2

In this article, we will see how templates can simplify and generalize the Expression Functors code, and improve performance by eliminating virtual functions.

In the example from the previous segment of this article, the Expression abstract base class represented the family of all types that could be used to represent different kinds of expressions. Two specific classes, Variable and Constant, represented the leaves of an expression tree - numeric literals and single variable names. These could be combined to form other expressions, generally represented by the ComplexExpression class. All these classes implemented the interface defined by the Expression abstract base class.

We now plan to switch from dynamic polymorphism to template-driven static polymorphism. For example, given classes with a common interface but without a common base class, like the ones below:

struct Foo {
    void bar() {
        // implementation
    }
};

struct Foo2 {
    void bar() {
        // implementation
    }
};

we can wrap them in a class template that exposes the same interface:

template<class F>
struct AllFoo {
    AllFoo(F& f) : f_(f) {}

    void bar() {
        f_.bar();  // resolved at compile time for the concrete F
    }

    F& f_;  // the wrapped object
};
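To make the idea concrete, here is a minimal sketch of such a wrapper in use. It departs from the snippet above in a few ways that are my assumptions for the sake of a checkable example: bar() returns a std::string instead of void, the wrapper stores its object by value, and call_bar is a hypothetical helper that does not appear elsewhere in this article.

```cpp
#include <string>

// Two unrelated classes that happen to share the same interface.
struct Foo {
    std::string bar() const { return "Foo::bar"; }
};

struct Foo2 {
    std::string bar() const { return "Foo2::bar"; }
};

// Wrapper in the spirit of AllFoo: no base class, no virtual call.
template <class F>
struct AllFoo {
    explicit AllFoo(const F& f) : f_(f) {}

    // Forwards to F::bar(); the target is known at compile time.
    std::string bar() const { return f_.bar(); }

    F f_;  // stored by value here (an assumption) so the wrapper owns a copy
};

// Hypothetical helper: works for any F that provides a compatible bar().
template <class F>
std::string call_bar(const F& f) {
    return AllFoo<F>(f).bar();
}
```

Calling call_bar(Foo()) and call_bar(Foo2()) picks the right bar() for each type without any run-time dispatch.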

AllFoo can only be instantiated with classes that have a bar() method in their public interface; any other type causes a compile-time error as soon as AllFoo::bar() is used. Applying the same idea to expressions, we write the following alternate definitions:

template<class E>
struct Expr {
    // const reference, so that temporaries such as Constant(3) can be wrapped
    Expr(const E& e) : e_(e) {
    }

    double operator() (double d) {
        return e_(d);
    }

    E e_;
};

Here Expr is a class template that can wrap any expression object, such as the Constant and Variable types below.

struct Constant {
    Constant(double d) : d_(d) { }
    Constant(int d) : d_(d) { }
    double operator() (double) {
        return d_;
    }

    double d_;
};

struct Variable {
    double operator() (double d) {
        return d;
    }
};
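As a quick illustration of how the wrapper and the leaf types fit together, here is a small self-contained sketch. The const-qualified call operators, the explicit constructors and the evaluate_at helper are my additions for this example and are not part of the framework as presented in this article.

```cpp
// Leaf expressions, as in the article but with const-qualified call operators.
struct Constant {
    explicit Constant(double d) : d_(d) {}
    double operator()(double) const { return d_; }  // ignores its argument
    double d_;
};

struct Variable {
    double operator()(double d) const { return d; }  // returns its argument
};

// Wrapper that forwards evaluation to whatever expression it holds.
template <class E>
struct Expr {
    explicit Expr(const E& e) : e_(e) {}
    double operator()(double d) const { return e_(d); }
    E e_;
};

// Hypothetical helper: evaluates any wrapped expression at a given point.
template <class E>
double evaluate_at(const Expr<E>& expr, double point) {
    return expr(point);
}
```

An Expr<Constant> built from Constant(3.0) evaluates to 3 at every point, while an Expr<Variable> evaluates to the point itself.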

Finally, a class to represent non-terminal, complex expressions:

template<class E1, class E2, class Op>
struct ComplexExpr {
    ComplexExpr(Expr<E1> l, Expr<E2> r) : l_(l), r_(r) {
    }

    double operator() (double d) {
        return Op::apply(l_(d), r_(d));
    }

    Expr<E1> l_;
    Expr<E2> r_;
};

We have just carried over the definition of the ComplexExpression class from the previous article and turned it into a class template. The operator classes like Add, Subtract, Multiply and Divide should continue to work unchanged. All that remains is to define the operator overloads (for +, -, * and /), and our framework of algebraic expressions is complete.
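The operator classes themselves are not repeated in this part. The only thing ComplexExpr requires of its Op parameter is a static apply() taking two doubles, so a minimal sketch could look like the following. Treat these definitions as an assumption about their shape rather than a copy of the earlier code; combine is a hypothetical helper that mirrors the call made inside ComplexExpr.

```cpp
// Hypothetical operation policies: each exposes a static apply(), which is
// all that ComplexExpr::operator() needs from its Op parameter.
struct Add      { static double apply(double a, double b) { return a + b; } };
struct Subtract { static double apply(double a, double b) { return a - b; } };
struct Multiply { static double apply(double a, double b) { return a * b; } };
struct Divide   { static double apply(double a, double b) { return a / b; } };

// Generic use of a policy, mirroring the call made inside ComplexExpr.
template <class Op>
double combine(double lhs, double rhs) {
    return Op::apply(lhs, rhs);
}
```

Because apply() is a static member chosen through a template parameter, the compiler knows the exact function at each call site and is free to inline it.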

Now, we want to write expressions such as:

Variable x;
cout << (x+3)(2); // prints 5
cout << ((x*x+3)*(x+3))(2); // prints 35

In the above, x is a Variable, 3 should generate a Constant, and x + 3 should conceptually result in a ComplexExpr<Variable, Constant, Add>. Further, x*x is a ComplexExpr<Variable, Variable, Multiply> and (x*x+3)*(x+3) is a ComplexExpr<ComplexExpr<ComplexExpr<Variable, Variable, Multiply>, Constant, Add>, ComplexExpr<Variable, Constant, Add>, Multiply>. The nested template arguments mirror the sub-trees of the expression tree. (The operators defined below additionally wrap each operand in Expr, which is left out here for readability.)

Clearly, to write an expression like x+3, we need an overload of operator+ between a Variable and an integer. To write an expression like (x*x+3)*(x+3), we need an operator* between two ComplexExpr objects. To write something like x+Constant(3), we need an overload of operator+ between Variable and Constant. Now, just as we defined the basic operators on Expression in the previous article, we can define them here as function templates that wrap their operands in the Expr class template, instead of overloading each operator for every combination of types. For example:

template<class E1, class E2>
Expr<ComplexExpr<Expr<E1>, Expr<E2>, Add> > operator+ (E1 e1, E2 e2) {
    typedef ComplexExpr<Expr<E1>, Expr<E2>, Add> ExprType;
    return Expr<ExprType>( ExprType(Expr<E1>(e1), Expr<E2>(e2)) );
}
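To see the pieces working together, here is a self-contained sketch of a cut-down framework with only operator+. It is a variation on the code in this article rather than a verbatim copy: the namespace et is hypothetical and is there to keep the very general operator+ from being considered for unrelated types, operands are taken by const reference, and ComplexExpr stores its operands directly instead of wrapping them in Expr a second time.

```cpp
namespace et {  // hypothetical namespace, not used elsewhere in the article

struct Add {
    static double apply(double a, double b) { return a + b; }
};

template <class E>
struct Expr {
    Expr(const E& e) : e_(e) {}
    double operator()(double d) const { return e_(d); }
    E e_;
};

struct Constant {
    explicit Constant(double d) : d_(d) {}
    double operator()(double) const { return d_; }
    double d_;
};

struct Variable {
    double operator()(double d) const { return d; }
};

// Non-terminal node: holds both operands by value and combines them with Op.
template <class E1, class E2, class Op>
struct ComplexExpr {
    ComplexExpr(const E1& l, const E2& r) : l_(l), r_(r) {}
    double operator()(double d) const { return Op::apply(l_(d), r_(d)); }
    E1 l_;
    E2 r_;
};

// Generic operator+: the return type records the shape of the expression tree.
template <class E1, class E2>
Expr<ComplexExpr<Expr<E1>, Expr<E2>, Add> > operator+(const E1& e1, const E2& e2) {
    typedef ComplexExpr<Expr<E1>, Expr<E2>, Add> ExprType;
    return Expr<ExprType>(ExprType(Expr<E1>(e1), Expr<E2>(e2)));
}

}  // namespace et
```

With et::Variable x, the expression x + et::Constant(3.0) builds an object whose type encodes the whole tree, and calling it with 2.0 evaluates 2 + 3.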

The other operators are not very difficult to write, following the above code. However, it turns out that while these take care of operators between complex expressions as well as Variables and Constants, they cannot handle real number or integer literals. To handle real number literals, we define the following pair of overloads for each operator.

template<class E1>
Expr<ComplexExpr<Expr<E1>, Expr<Constant>, Multiply> > operator* (E1 e1, double d) {
    typedef ComplexExpr<Expr<E1>, Expr<Constant>, Multiply> ExprType;
    return Expr<ExprType>( ExprType(Expr<E1>(e1), Expr<Constant>(Constant(d))) );
}

template<class E1>
Expr<ComplexExpr<Expr<Constant>, Expr<E1>, Multiply> > operator* (double d, E1 e1) {
    typedef ComplexExpr<Expr<Constant>, Expr<E1>, Multiply> ExprType;
    return Expr<ExprType>( ExprType(Expr<Constant>(Constant(d)), Expr<E1>(e1)) );
}

These take care of all cases, except for a small glitch which we can prepare to live with. We can write expressions such as:

Variable x;
cout << (x+3.0)(2); // prints 5
cout << ((x*x+3.0)*(x+3.0))(2); // prints 35


but not ones like:

Variable x;
cout << (x+3)(2); // does not compile: the literal 3 is an int, not a double
cout << ((x*x+3)*(x+3))(2); // does not compile either


If you need this to work, add further overloads such as:

template<class E1>
Expr<ComplexExpr<Expr<E1>, Expr<Constant>, Multiply> > operator* (E1 e1, int d) {
    typedef ComplexExpr<Expr<E1>, Expr<Constant>, Multiply> ExprType;
    return Expr<ExprType>( ExprType(Expr<E1>(e1), Expr<Constant>(Constant(d))) );
}

template<class E1>
Expr<ComplexExpr<Expr<Constant>, Expr<E1>, Multiply> > operator* (int d, E1 e1) {
    typedef ComplexExpr<Expr<Constant>, Expr<E1>, Multiply> ExprType;
    return Expr<ExprType>( ExprType(Expr<Constant>(Constant(d)), Expr<E1>(e1)) );
}

As should already be apparent, there are no objects allocated on the free store, no reference-counted proxy wrappers, and a fair bit of compile-time type computation and call dispatching. This results in significant savings in the run-time cost of the program, apart from the obvious cut in the number of lines of code. The basic philosophy of the template-based program has not changed from the non-template version: it uses functional composition involving the function call operator, and operator overloading to support naturally combining simple expressions into arbitrarily complex ones. However, the use of templates allows a fair bit of this work to be done at compile time.

That is all there is to expression templates. You can put together the expression framework using the code given here and use the following function to test it.

int main()
{
    Variable x;

    cout << ((2.0*x*x + 3.0*x + 3.0)*(2.0*x*x + 3.0*x + 3.0))(2) << endl;
    cout << integrate((2.0*x*x + 3.0*x + 3.0), 0, 1) << endl;
    cout << integrate((2.0*x*x + 3.0*x + 3.0)*(2.0*x*x + 3.0*x + 3.0), 0, 1) << endl;
    cout << integrate((x/(1.0+x)), 0, 1) << endl;
    cout << integrate(x*x, 0, 7) << endl;

    return 0;
}
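The integrate() function used above is not defined in this part. If you do not have it at hand, the following stand-in should be enough to try things out. It is a sketch under my own assumptions: it uses the composite trapezoidal rule with a fixed default step count, and both the steps parameter and the Square functor are hypothetical names introduced only for this example.

```cpp
#include <cstddef>

// Hypothetical stand-in for integrate(): composite trapezoidal rule over
// [low, high] using a fixed number of equal-width steps.
template <class F>
double integrate(F f, double low, double high, std::size_t steps = 100000) {
    const double width = (high - low) / static_cast<double>(steps);
    double sum = 0.5 * (f(low) + f(high));  // end points carry half weight
    for (std::size_t i = 1; i < steps; ++i) {
        sum += f(low + static_cast<double>(i) * width);
    }
    return sum * width;
}

// Plain functor used only to exercise integrate() in isolation.
struct Square {
    double operator()(double x) const { return x * x; }
};
```

Since F is a template parameter, the same function accepts the expression objects built by the overloaded operators as well as ordinary functors such as Square. Integrating Square from 0 to 7 should come out close to 343/3, about 114.33.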


If you have any trouble understanding and compiling the code, drop me a message.

4 comments:

Raghavender Boorgapally said...

Thanks man, your tutorial is very lucid and easy to understand. I was desperately looking for such a tutorial on expression templates. Thanks again.

Carl said...

Hi,

Thanks for an informative article. I'm curious to know if there is a way to generalize this to multiple variables?

Ex :
Variable x,y;
cout << (x * y)(2.0); //prints 4.0

Is there a way I can make it print 6.0 for :
cout<< (x * y)(2.0, 3.0);

Arindam Mukherjee said...

Hmm, good question and I did ponder about it - if only a little. I haven't had the time or chance to google it - but off the top of my head what I can think is this:

Whenever we talk of a function of N variables, all our expressions should be modelled in that space. So in an (x,y) plane, the expression x can perhaps be thought of as the complex expression f(x,0) where:

f(x,y) = x+y

In other words, we would try to model the expression x as a function of two variables. Likewise for everything. A constant would become:

f(x,y)=4

for all values of x and y. And therefore each and every overloaded operator() should take two double arguments instead of 1. The Variable class could become slightly more complicated:

struct XAxis {};
struct YAxis {};

template <class T>
struct Variable {
};

template<>
struct Variable<XAxis> {
    double operator()(double x, double y) {
        return x;
    }
};

template<>
struct Variable<YAxis> {
    double operator()(double x, double y) {
        return y;
    }
};

So you would declare your basic independent variables as:

Variable<XAxis> x;
Variable<YAxis> y;

Arindam Mukherjee said...

Well, one more important thing that I missed: you will need to change the operator() signature for Constant and also for ComplexExpr:


struct Constant {
    Constant(double d) : d_(d) { }
    Constant(int d) : d_(d) { }
    double operator() (double, double) {
        return d_;
    }

    double d_;
};

And for the ComplexExpr operator(), it should be like:

double operator() (double x, double y) {
    return Op::apply(l_(x, y), r_(x, y));
}

I am yet to test any of this for lack of time; if you find time and are able to ... please let me know. Meanwhile, I will try to put this together for a third part to this series. Thanks a lot!
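
For anyone who wants to experiment with the two-variable idea from this thread, here is a minimal self-contained sketch. It is only one possible way to put the pieces together, not a tested extension of the framework: the namespace et2 and the BinaryExpr name are hypothetical, only + and * are provided, and the Expr wrapper is left out to keep it short.

```cpp
namespace et2 {  // hypothetical namespace for the two-variable variant

struct XAxis {};
struct YAxis {};

template <class Axis> struct Variable;  // only the two specializations exist

template <> struct Variable<XAxis> {
    double operator()(double x, double) const { return x; }
};
template <> struct Variable<YAxis> {
    double operator()(double, double y) const { return y; }
};

struct Constant {
    explicit Constant(double d) : d_(d) {}
    double operator()(double, double) const { return d_; }
    double d_;
};

struct Add      { static double apply(double a, double b) { return a + b; } };
struct Multiply { static double apply(double a, double b) { return a * b; } };

// Non-terminal node: every sub-expression is evaluated at the same (x, y).
template <class E1, class E2, class Op>
struct BinaryExpr {
    BinaryExpr(const E1& l, const E2& r) : l_(l), r_(r) {}
    double operator()(double x, double y) const {
        return Op::apply(l_(x, y), r_(x, y));
    }
    E1 l_;
    E2 r_;
};

template <class E1, class E2>
BinaryExpr<E1, E2, Add> operator+(const E1& l, const E2& r) {
    return BinaryExpr<E1, E2, Add>(l, r);
}

template <class E1, class E2>
BinaryExpr<E1, E2, Multiply> operator*(const E1& l, const E2& r) {
    return BinaryExpr<E1, E2, Multiply>(l, r);
}

}  // namespace et2
```

With et2::Variable<et2::XAxis> x and et2::Variable<et2::YAxis> y, the expression (x * y)(2.0, 3.0) evaluates to 6.0, which is the behaviour asked for above.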