Automatic differentiation using expression templates in C++
Introduction
I am trying to learn about expression templates because they seem to be a very powerful technique for a wide range of calculations. I looked at different examples online (e.g. Wikipedia) and wrote a bunch of small programs that do different calculations. However, I ran into a problem that I can't solve: assigning an arbitrary expression to a variable, evaluating it lazily, and later reassigning the variable to another arbitrary expression.
One can use auto to assign an expression to a variable, but reassigning that variable to a different expression is impossible, because the deduced type is fixed at the declaration (you can have a look at the entire library, sadET, to fully understand what I am trying to do). Assignment and reassignment can also be done using CRTP and an overloaded operator=. However, the expression is then evaluated during the assignment, and we basically lose all information about what the expression is.
Therefore, I tried to use polymorphic cloning + CRTP (e.g. see this), which kind of works, but I get a segfault when I try to reassign the variable to an expression that contains the same variable.
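To make the auto limitation concrete, here is a minimal sketch. The names ExprTag, Num, Add and auto_demo are hypothetical (they are not taken from sadET), and operands are stored by value purely to keep the example free of dangling references. Every expression has its own static type, so a variable declared with auto cannot later hold an expression of a different shape:

```cpp
#include <cassert>
#include <type_traits>

struct ExprTag {};  // hypothetical marker base, used only to constrain operator+

struct Num : ExprTag {
    double val;
    explicit Num(double v) : val(v) {}
    double evaluate() const { return val; }
};

template <typename L, typename R>
struct Add : ExprTag {
    L lhs;  // stored by value in this sketch
    R rhs;
    Add(const L& l, const R& r) : lhs(l), rhs(r) {}
    double evaluate() const { return lhs.evaluate() + rhs.evaluate(); }
};

template <typename L, typename R,
          typename = typename std::enable_if<std::is_base_of<ExprTag, L>::value &&
                                             std::is_base_of<ExprTag, R>::value>::type>
Add<L, R> operator+(const L& l, const R& r) { return Add<L, R>(l, r); }

double auto_demo() {
    Num x(1.5), y(2.0);
    auto e = x + y;  // the type of e is deduced once: Add<Num, Num>
    static_assert(std::is_same<decltype(e), Add<Num, Num>>::value, "type is fixed");
    // e + x has type Add<Add<Num, Num>, Num>, so `e = e + x;` would not compile.
    static_assert(!std::is_assignable<decltype(e)&, decltype(e + x)>::value,
                  "a differently shaped expression cannot be assigned");
    assert(e.evaluate() == 3.5);
    auto e2 = e + x;  // a new variable with a new type is needed instead
    return e2.evaluate();
}
```

Calling auto_demo() evaluates (x + y) + x through a second variable; reusing e for that expression is rejected at compile time, which is the limitation described above.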
Code
In the code below, I show a simplified version of how I implement the addition expression template. In the code that follows, msg is a macro for std::cout<<typeid(*this).name()<<std::endl; that keeps track of what is being evaluated.
This is the (pure) base class of all expressions, which will allow me to assign a variable to a generic expression (using polymorphic cloning):
struct BaseExpression{
BaseExpression()=default;
virtual double evaluate()const =0 ;
};
This is the class from which all expressions inherit; it allows me to use CRTP:
template<typename subExpr>
struct GenericExpression:BaseExpression{
const subExpr& self() const {return static_cast<const subExpr&>(*this);}
subExpr& self() {return static_cast<subExpr&>(*this);}
double evaluate() const { msg; return self().evaluate(); };
};
The idea is that all expressions are made of numbers, elementary functions, and operators. So, I write a Number class as follows:
class Number: public GenericExpression<Number>{
double val;
public:
Number()=default;
Number(const double &x):val(x){}
Number(const Number &x):val(x.evaluate()){}
double evaluate()const { msg; return val;}
double& evaluate() { msg; return val;}
};
Following the ideas of expression templates, then, the addition is
template<typename leftHand,typename rightHand>
class Addition:public GenericExpression<Addition<leftHand,rightHand>>{
const leftHand &LH;
const rightHand &RH;
public:
Addition(const leftHand &LH, const rightHand &RH):LH(LH),RH(RH){}
double evaluate() const {msg; return LH.evaluate() + RH.evaluate();}
};
template<typename leftHand,typename rightHand>
Addition<leftHand,rightHand>
operator+(const GenericExpression<leftHand> &LH, const GenericExpression<rightHand> &RH){
return Addition<leftHand,rightHand>(LH.self(),RH.self());
}
In order to be able to utilize BaseExpression, I also write an Expression class that will be used to assign instances of GenericExpression to Expression variables.
class Expression: public GenericExpression<Expression>{
public:
BaseExpression *baseExpr;
Expression()=default;
Expression(const Expression &E){baseExpr = E.baseExpr;};
double evaluate() const {msg; return baseExpr->evaluate();}
template<typename subExpr>
void assign(const GenericExpression<subExpr> &RH){
baseExpr = new subExpr(RH.self());
}
};
In this class, the important bit is the pointer baseExpr, which allows us to switch the evaluate function to that of the given GenericExpression when we call assign.
Some examples
In order to test whether this worked, I declare the following variables:
Number x(3.2);
Number y(-2.3);
Expression z,w;
Then, we can see that the following things work
//assignment to Number
z.assign(x);
cout<<z.evaluate()<<endl;
//assignment to Addition<Number,Number>
z.assign(x+y);
cout<<z.evaluate()<<endl;
//assignment to Addition<Expression,Number>
w.assign(z+y);
cout<<w.evaluate()<<endl;
However, when I do the following, I get an infinite recursion when I run z.evaluate(), since z.baseExpr now points to an Addition whose left operand refers back to z itself.
// Segmentation fault of z.evaluate() due to infinite recursion between
// LH.evaluate() (in Addition<Expression,Number>::evaluate()) and
// baseExpr->evaluate() (in Expression::evaluate())
z.assign(z+x);
cout<<z.evaluate()<<endl;
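One possible way around this cycle, sketched under assumptions that differ from the code above: operands are stored by value instead of by reference, CRTP is dropped for brevity, and Expression owns its node through a std::shared_ptr. With these changes, z+x holds a copy of the old z (which still points to the old node) before z is redirected, so nothing refers back to z itself. The function self_assign_demo is a hypothetical helper used only for illustration.

```cpp
#include <cassert>
#include <memory>
#include <type_traits>

struct BaseExpression {
    virtual ~BaseExpression() = default;
    virtual double evaluate() const = 0;
};

class Number : public BaseExpression {
    double val;
public:
    explicit Number(double x) : val(x) {}
    double evaluate() const override { return val; }
};

// Handle with shared ownership: copying it shares the node it currently points to.
class Expression : public BaseExpression {
    std::shared_ptr<const BaseExpression> node;  // null until assign() is called
public:
    Expression() = default;
    double evaluate() const override { return node->evaluate(); }

    template <typename E>
    void assign(const E& rhs) {
        // rhs is fully built (and already holds a copy of the old node)
        // before the pointer is overwritten.
        node = std::make_shared<E>(rhs);
    }
};

// Operands are stored BY VALUE: the sum keeps a snapshot, not a reference.
template <typename L, typename R>
class Addition : public BaseExpression {
    L lhs;
    R rhs;
public:
    Addition(const L& l, const R& r) : lhs(l), rhs(r) {}
    double evaluate() const override { return lhs.evaluate() + rhs.evaluate(); }
};

template <typename L, typename R,
          typename = typename std::enable_if<std::is_base_of<BaseExpression, L>::value &&
                                             std::is_base_of<BaseExpression, R>::value>::type>
Addition<L, R> operator+(const L& l, const R& r) { return Addition<L, R>(l, r); }

double self_assign_demo() {
    Number x(3.0);
    Expression z;
    z.assign(x);      // z -> copy of Number(3)
    assert(z.evaluate() == 3.0);
    z.assign(z + x);  // z -> Addition{copy of old z, copy of x}: no cycle
    return z.evaluate();
}
```

The trade-off of this sketch is that a Number is copied into the expression, so later changes to x are not seen by z; sharing variables through pointers would be needed if that behavior is wanted.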
The entire code that reproduces the behavior I describe:
#include<iostream>
#include<cmath>
#include<typeinfo>
// message to print during evaluation, in order to track the evaluation path.
#define msg std::cout<<typeid(*this).name()<<std::endl;
struct BaseExpression{
BaseExpression()=default;
virtual double evaluate()const =0 ;
};
template<typename subExpr>
struct GenericExpression:BaseExpression{
const subExpr& self() const {return static_cast<const subExpr&>(*this);}
subExpr& self() {return static_cast<subExpr&>(*this);}
double evaluate() const { msg; return self().evaluate(); };
};
class Number: public GenericExpression<Number>{
double val;
public:
Number()=default;
Number(const double &x):val(x){}
Number(const Number &x):val(x.evaluate()){}
double evaluate()const { msg; return val;}
double& evaluate() { msg; return val;}
};
template<typename leftHand,typename rightHand>
class Addition:public GenericExpression<Addition<leftHand,rightHand>>{
const leftHand &LH;
const rightHand &RH;
public:
Addition(const leftHand &LH, const rightHand &RH):LH(LH),RH(RH){}
double evaluate() const {msg; return LH.evaluate() + RH.evaluate();}
};
template<typename leftHand,typename rightHand>
Addition<leftHand,rightHand>
operator+(const GenericExpression<leftHand> &LH, const GenericExpression<rightHand> &RH){
return Addition<leftHand,rightHand>(LH.self(),RH.self());
}
class Expression: public GenericExpression<Expression>{
public:
BaseExpression *baseExpr;
Expression()=default;
Expression(const Expression &E){baseExpr = E.baseExpr;};
// Expression(Expression *E){baseExpr = E->baseExpr;};
double evaluate() const {msg; return baseExpr->evaluate();}
template<typename subExpr>
void assign(const GenericExpression<subExpr> &RH){
baseExpr = new subExpr(RH.self());
}
};
using std::cout;
using std::endl;
int main(){
Number x(3.2);
Number y(-2.3);
Expression z,w;
// works fine!
z.assign(x);
cout<<z.evaluate()<<endl;
// works fine!
z.assign(x+y);
cout<<z.evaluate()<<endl;
// works fine!
w.assign(z+y);
cout<<w.evaluate()<<endl;
// Segmentation fault of z.evaluate() due to infinite recursion between
// LH.evaluate() (in Addition<Expression,Number>::evaluate()) and
// baseExpr->evaluate() (in Expression::evaluate())
z.assign(z+x);
cout<<z.evaluate()<<endl;
return 0;
}
Moving to autodiff
Implementing automatic differentiation is not that different from the example above. The most important changes are the introduction of a Constant class and of a derivative member function in all classes derived from GenericExpression, which returns an arbitrary expression. That is, we add
class Constant: public GenericExpression<Constant>{
double val;
public:
Constant()=default;
Constant(const double &x):val(x){}
Constant(const Constant &x):val(x.evaluate()){}
double evaluate()const { msg; return val;}
double& evaluate() { msg; return val;}
auto derivative(){return Constant(0);}
};
The derivative member in the other classes should look like:
auto Number::derivative(){return Constant(1);}
template<typename leftHand,typename rightHand>
auto Addition<leftHand,rightHand>::derivative(){return LH.derivative() + RH.derivative();}
The problem
Even if we find a way to avoid (or ignore) the segfault I showed previously, I see no way to make derivative a virtual member function of BaseExpression, since it returns a different kind of expression that, in principle, we only know when we call it (hence the auto).
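One way derivative could become virtual is to give it a single, type-erased return type. The sketch below is not based on expression templates: it assumes a purely dynamic design in which every expression is a shared pointer to an abstract node, and it differentiates with respect to a single variable only. All names (Node, Expr, Variable, Sum, Product, derivative_demo) are hypothetical.

```cpp
#include <cassert>
#include <memory>
#include <utility>

struct Node;
using Expr = std::shared_ptr<const Node>;  // type-erased handle to any expression

struct Node {
    virtual ~Node() = default;
    virtual double evaluate() const = 0;
    virtual Expr derivative() const = 0;  // one fixed return type, so it can be virtual
};

struct Constant : Node {
    double val;
    explicit Constant(double v) : val(v) {}
    double evaluate() const override { return val; }
    Expr derivative() const override { return std::make_shared<Constant>(0.0); }
};

struct Variable : Node {  // plays the role of Number in the question
    double val;
    explicit Variable(double v) : val(v) {}
    double evaluate() const override { return val; }
    Expr derivative() const override { return std::make_shared<Constant>(1.0); }
};

struct Sum : Node {
    Expr lhs, rhs;
    Sum(Expr l, Expr r) : lhs(std::move(l)), rhs(std::move(r)) {}
    double evaluate() const override { return lhs->evaluate() + rhs->evaluate(); }
    Expr derivative() const override {
        return std::make_shared<Sum>(lhs->derivative(), rhs->derivative());
    }
};

struct Product : Node {
    Expr lhs, rhs;
    Product(Expr l, Expr r) : lhs(std::move(l)), rhs(std::move(r)) {}
    double evaluate() const override { return lhs->evaluate() * rhs->evaluate(); }
    Expr derivative() const override {
        // product rule: (u v)' = u' v + u v'
        return std::make_shared<Sum>(
            std::make_shared<Product>(lhs->derivative(), rhs),
            std::make_shared<Product>(lhs, rhs->derivative()));
    }
};

Expr operator+(const Expr& l, const Expr& r) { return std::make_shared<Sum>(l, r); }
Expr operator*(const Expr& l, const Expr& r) { return std::make_shared<Product>(l, r); }

double derivative_demo() {
    Expr x = std::make_shared<Variable>(5.0);
    Expr z = x;
    z = z + x;  // rebinds the handle; the old node stays alive inside the Sum
    assert(z->evaluate() == 10.0);
    return z->derivative()->evaluate();  // d(x + x)/dx
}
```

The price of this design is a heap allocation and a virtual call per node, which is what expression templates are meant to avoid; it is shown only to illustrate that a uniform return type removes the need for auto.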
Final question
In the context I tried to describe, is there any way I can do the following
Number x(5);
Expression z;
z.assign(x);
z.assign(z+x);
cout<<z.evaluate()<<endl;
cout<<z.derivative().evaluate()<<endl;
preferably without segfault?
I would appreciate the input or insight of the community!
Edit
I simplified the code in order to make things a bit easier and to avoid potentially dangerous pointers.
In this example, I made the member function evaluate a std::function that holds a lambda. This way, if I understand correctly, I copy the ...::evaluate directly into Expression::evaluate. However, I still get segfaults...
Curiously, I also get a segfault when using a recursive function to sum over a vector.
New code:
#include<iostream>
#include<functional>
#include<vector>
#include<typeinfo>
// message to print during evaluation, in order to track the evaluation path.
#define msg std::cout<<typeid(*this).name()<<std::endl
class Expression{
public:
std::function<double(void)> evaluate;
template<typename T>
Expression(const T &RH){evaluate = RH.evaluate;}
template<typename T>
Expression& operator=(const T &RH){evaluate = RH.evaluate; return *this;}
};
class Number{
double val;
public:
Number()=default;
std::function<double()> evaluate;
Number(const double &x):val(x){ evaluate = [this](){msg; return this->val;}; }
Number(const Number &x):val(x.evaluate()){ evaluate = [this](){msg; return this->val;}; }
};
template<typename leftHand,typename rightHand>
class Addition{
const leftHand &LH;
const rightHand &RH;
public:
std::function<double()> evaluate;
Addition(const leftHand &LH, const rightHand &RH):LH(LH),RH(RH)
{evaluate = [this](){msg; return this->LH.evaluate() + this->RH.evaluate();};}
};
template<typename leftHand,typename rightHand>
auto operator+(const leftHand &LH, const rightHand &RH){return Addition<leftHand,rightHand>(LH,RH); }
using std::cout;
using std::endl;
inline Expression func (std::vector<Expression> x, int i){
cout<<i<<endl;
if (i==0){return static_cast<Expression>(x[0]);}
return static_cast<Expression>( x[i] + func(x,i-1) ) ;
};
int main(){
Number y(-2.);
Number x(1.33);
Expression z(y+x);
Expression w(x+y+x);
// works
z =x;
cout<<z.evaluate()<<endl;
cout<<(z+z+z+z).evaluate()<<endl;
// Segfault due to recursion
// z =z+x;
// cout<<z.evaluate()<<endl;
// Unknown Segfault
// z = x+y ;
// cout<<(z+z).evaluate()<<endl;
// cout<<typeid(z+z).name()<<endl;
// Unknown Segfault
// z = w+y+x+x;
// cout<<z.evaluate()<<endl;
// Unknown Segfault
// std::vector<Expression> X={x,y,x,y,x,y,x,y};
// cout << typeid(func(X,X.size()-1)).name() << endl;
// cout << (func(X,X.size()-1)).evaluate() << endl;
return 0;
}
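A likely cause of the crashes in the code above is that each lambda captures this, and Addition additionally stores references to its operands. When Expression copies evaluate out of a temporary such as x+y, the copied std::function still points at the temporary, which is destroyed at the end of the statement. A minimal sketch of an alternative, assuming that copying operands into the closure is acceptable (the helper names sum and closure_demo are hypothetical):

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

class Expression {
    std::function<double()> eval;  // owns everything it needs to evaluate
public:
    Expression(double v) : eval([v] { return v; }) {}
    explicit Expression(std::function<double()> f) : eval(std::move(f)) {}
    double evaluate() const { return eval(); }

    friend Expression operator+(const Expression& a, const Expression& b) {
        // Both operands are copied INTO the closure, so nothing dangles
        // after the temporaries of the surrounding expression are gone.
        return Expression([a, b] { return a.evaluate() + b.evaluate(); });
    }
};

// Recursive sum over a vector, in the spirit of func from the question.
Expression sum(const std::vector<Expression>& xs, std::size_t i) {
    if (i == 0) return xs[0];
    return xs[i] + sum(xs, i - 1);
}

double closure_demo() {
    Expression x(1.5), y(-2.0);
    Expression z = x;
    z = z + x;  // the closure holds a copy of the old z: no dangling pointer, no cycle
    assert(z.evaluate() == 3.0);
    std::vector<Expression> xs = {x, y, x, y};
    return sum(xs, xs.size() - 1).evaluate();
}
```

Copying a std::function copies its closure, so building a long expression this way copies the whole tree at each step; holding the operands through std::shared_ptr would avoid that cost.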
Update
After around a year, I am revisiting this problem, and I think I have found a solution: use a pure abstract class from which every expression is derived. Then everything (from variables to the return types of operators and functions) is a pointer to this abstract class. The code I am currently working on is here. Notice, however, that I don't use expression templates in the new design, but I expect to do so at some point in the near future.
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow