Extending expression templates for automatic differentiation
I am trying to understand expression templates so that I can efficiently evaluate multivariate polynomials such as:
// Inputs of the polynomial; in the real application they are only known at run time.
double PAx, PAy, aAB;
// Target: (PAx^4 + 6*PAx^2*aAB + 3*aAB^2) * PAy^4 (was misspelled PAX, an undeclared name).
double val = (PAx*PAx*PAx*PAx + 6*PAx*PAx*aAB + 3*aAB*aAB)*(PAy*PAy*PAy*PAy);
using expression templates, if possible. How do I modify the following code snippet to achieve this goal?
#include <utility>
#include <iostream>
template<typename T1, typename T2>
struct plus;
template<typename T1, typename T2>
struct times;
// a + b and a * b do not compute anything: they build a lazy plus/times node
// that is evaluated later by calling it.
//
// The defaulted template parameters `int T::*` / `int U::*` are only
// well-formed when the operand is a class type, so these global templates
// drop out of overload resolution for arithmetic and pointer operands.
// Without the guard they are an exact match for expressions such as
// `vec.begin() + 1` (int vs. the iterator's difference_type) and would
// silently hijack them.
template<typename T, typename U, typename = int T::*, typename = int U::*>
constexpr auto operator+(T l, U r){ return plus(l, r); }
template<typename T, typename U, typename = int T::*, typename = int U::*>
constexpr auto operator*(T l, U r){ return times(l, r); }
// Lazy sum node: calling it with an argument t calls both children with that
// same t and adds the two results.
template<typename T1, typename T2>
struct plus{
    constexpr plus(T1 t1_, T2 t2_):t1{std::move(t1_)},t2{std::move(t2_)}{}
    T1 t1;
    T2 t2;
    // t is deliberately passed as an lvalue to both children. Forwarding it
    // to both would let one child move from it before the other reads it,
    // and the order of the two calls is unspecified.
    template<typename T>
    constexpr auto operator()(T&& t) const {return t1(t)+t2(t);}
};
// Trait: true only for specialisations of plus.
template<typename T>
struct is_plus : std::false_type{};
template<typename T1, typename T2>
struct is_plus<plus<T1, T2>> : std::true_type{};
// Lazy product node: calling it with an argument t calls both children with
// that same t and multiplies the two results.
template<typename T1, typename T2>
struct times{
    constexpr times(T1 t1_, T2 t2_):t1{std::move(t1_)},t2{std::move(t2_)}{}
    T1 t1;
    T2 t2;
    // t is deliberately passed as an lvalue to both children. Forwarding it
    // to both would let one child move from it before the other reads it,
    // and the order of the two calls is unspecified.
    template<typename T>
    constexpr auto operator()(T&& t) const {return t1(t)*t2(t);}
};
// Trait: true only for specialisations of times.
template<typename T>
struct is_times : std::false_type{};
template<typename T1, typename T2>
struct is_times<times<T1, T2>> : std::true_type{};
// same<Lhs, Rhs> is std::is_same applied after std::decay, so references,
// cv-qualifiers and array/function-to-pointer differences are ignored.
template <class Lhs, class Rhs>
using same = std::is_same<std::decay_t<Lhs>, std::decay_t<Rhs>>;
int main(){
    // Every expression is a curried function of two arguments, f(t)(t2):
    // x returns the first argument, y the second, c(k) the constant k,
    // and zero/one the constants 0 and 1.
    constexpr auto x = [](auto t){ return [t](auto t2){return t ;};};
    constexpr auto y = [](auto t){ return [t](auto t2){return t2;};};
    constexpr auto c = [](auto t){ return [t](auto){ return [t](auto){return t;};};};
    constexpr auto zero = [](auto t){ return [t](auto t2){return 0 ;};};
    constexpr auto one = [](auto t){ return [t](auto t2){return 1 ;};};
    // Symbolic derivative of expression t with respect to the variable tag
    // (x or y), returned as a new lazy expression. self is this lambda,
    // passed explicitly so that it can recurse.
    constexpr auto recursion = [one, zero, x, y](auto tag, auto t, auto self){
        using t_t = decltype(t);
        using x_t = decltype(x);
        using y_t = decltype(y);
        using tag_t = decltype(tag);
        if constexpr (same<t_t, x_t>::value) {
            // dx/d(tag): 1 when differentiating with respect to x, else 0.
            if constexpr (same<tag_t, x_t>::value)
                return one;
            else
                return zero;
        } else if constexpr (same<t_t, y_t>::value) {
            // dy/d(tag): 1 when differentiating with respect to y, else 0.
            if constexpr (same<tag_t, y_t>::value)
                return one;
            else
                return zero;
        } else if constexpr (is_plus<std::decay_t<t_t>>::value) {
            // Sum rule: D(a + b) = D(a) + D(b).
            return plus{self(tag, t.t1, self), self(tag, t.t2, self)};
        } else if constexpr (is_times<std::decay_t<t_t>>::value) {
            // Product rule: D(a * b) = D(a) * b + a * D(b).
            return plus{times{self(tag, t.t1, self), t.t2}, times{t.t1, self(tag, t.t2, self)}};
        } else {
            // Anything else (e.g. c(k)) is treated as a constant.
            return zero;
        }
    };
    constexpr auto Dx = [recursion,x](auto t){return recursion(x, t, recursion);};
    constexpr auto Dy = [recursion,y](auto t){return recursion(y, t, recursion);};
    // Ultimately I am interested in evaluating expressions like
    // (PAx*PAx*PAx*PAx + 6*PAx*PAx*aAB + 3*aAB*aAB)*(PAy*PAy*PAy*PAy),
    // where PAx, PAy and aAB are only known at run time.
    auto constexpr ex = Dy(x*x*c(4.)+y*x*c(4.));
    auto constexpr val1=ex(5.);
    auto constexpr val=(val1+x(1.))(0.);
    std::cout<<val<<"\n";
}
Could the fact that, in my application, `PAx`, `PAy` and `aAB` are only known at execution time reduce the performance improvement I am expecting? Can lazy evaluation be preserved?
The snippet is based on this tutorial. I tested the code above on godbolt.org; here you can see the instantiated template specializations.
Update: taking into account the suggestions in the comments below, I modified the code as follows:
#include <cmath>
#include <cstddef>
#include <functional>
#include <iostream>
#include <tuple>
#include <type_traits>
#include <utility>
namespace expr_types {

// CRTP base of every expression node. Calling a node with arguments
// evaluates the whole tree; argument I is the value bound to var<I>, so
// f(PAx, PAy, aAB) works with values that are only known at run time.
template<class D>
struct formula {
    template<class... Args>
    constexpr auto operator()(Args... args) const {
        // eval() is found through ADL on the concrete node type D and yields
        // a callable that computes the node's value from the arguments.
        return eval(*static_cast<D const*>(this))(args...);
    }
};

// Leaf for the I-th independent variable: evaluates to the I-th argument.
template<std::size_t I>
struct var_t : formula<var_t<I>> {
    friend constexpr auto eval(var_t) {
        return [](auto... args) { return std::get<I>(std::make_tuple(args...)); };
    }
};

template<std::size_t N>
constexpr var_t<N> var = {};

// Leaf holding the constant k.
template<class T = double>
struct val_t : formula<val_t<T>> {
    T k = {};
    constexpr val_t(T in) : k(in) {}
    friend constexpr auto eval(val_t self) {
        return [self](auto...) { return self.k; };
    }
    // The derivative of a constant is zero, whatever the variable.
    template<class Diff>
    friend constexpr auto diff(Diff, val_t) {
        return val_t<T>{0};
    }
};

template<class T>
val_t(T) -> val_t<T>;

// Interior node: applies the binary functor Op to the values of lhs and rhs.
template<class Lhs, class Op, class Rhs>
struct tree : formula<tree<Lhs, Op, Rhs>> {
    Lhs lhs;
    Op op;
    Rhs rhs;
    constexpr tree(Lhs l, Op o, Rhs r) : lhs(l), op(o), rhs(r) {}
    friend constexpr auto eval(tree self) {
        return eval(self.lhs, self.op, self.rhs);
    }
    // Dispatches to the differentiation rule registered for Op below.
    template<class Diff>
    friend constexpr auto diff(Diff D, tree self) {
        return diff(D, self.lhs, self.op, self.rhs);
    }
};

// True for the node types above (everything deriving from formula<T>).
template<class T>
inline constexpr bool is_formula_v = std::is_base_of<formula<T>, T>::value;

// Operands accepted by the arithmetic operators: nodes and plain numbers.
template<class T>
inline constexpr bool is_operand_v = is_formula_v<T> || std::is_arithmetic<T>::value;

// Enables the operators only when at least one side is a node and the other
// side is a node or a number, so they are never picked for unrelated types
// that merely have expr_types as an associated namespace.
template<class L, class R>
using enable_operator_t = std::enable_if_t<
    (is_formula_v<L> || is_formula_v<R>) && is_operand_v<L> && is_operand_v<R>, int>;

// Wraps a plain number into a constant leaf and passes nodes through
// unchanged, so literals can be mixed with variables, e.g. 6*x*x*a.
template<class T>
constexpr auto lift(T t) {
    if constexpr (is_formula_v<T>) {
        return t;
    } else {
        return val_t<T>{t};
    }
}

template<class L, class R, enable_operator_t<L, R> = 0>
constexpr auto operator+(L lhs, R rhs) {
    return tree{lift(lhs), std::plus<>{}, lift(rhs)};
}

template<class L, class R, enable_operator_t<L, R> = 0>
constexpr auto operator-(L lhs, R rhs) {
    return tree{lift(lhs), std::minus<>{}, lift(rhs)};
}

template<class L, class R, enable_operator_t<L, R> = 0>
constexpr auto operator*(L lhs, R rhs) {
    return tree{lift(lhs), std::multiplies<>{}, lift(rhs)};
}

template<class L, class R, enable_operator_t<L, R> = 0>
constexpr auto operator/(L lhs, R rhs) {
    return tree{lift(lhs), std::divides<>{}, lift(rhs)};
}

// Functor stored in power nodes. Not constexpr because std::pow is not.
struct pow_t {
    template<class B, class E>
    auto operator()(B base, E exponent) const {
        return std::pow(base, exponent);
    }
};

// a^b builds a power node. Beware: ^ binds more loosely than + and * in C++,
// so write (x^2) + y; x^2 + y means x^(2 + y).
template<class L, class R, enable_operator_t<L, R> = 0>
constexpr auto operator^(L lhs, R rhs) {
    return tree{lift(lhs), pow_t{}, lift(rhs)};
}

// Evaluation rule shared by every interior node: evaluate both children with
// the same arguments and combine the results with the node's functor.
template<class L, class Op, class R>
constexpr auto eval(L lhs, Op op, R rhs) {
    return [=](auto... args) {
        return op(eval(lhs)(args...), eval(rhs)(args...));
    };
}

// Product rule: D(a*b) = D(a)*b + a*D(b).
template<class Diff, class L, class R>
constexpr auto diff(Diff D, L lhs, std::multiplies<>, R rhs) {
    return D(lhs) * rhs + lhs * D(rhs);
}

// Quotient rule: D(a/b) = (D(a)*b - a*D(b)) / (b*b). The square is written
// as a product: exact, and no std::pow call at every evaluation.
template<class Diff, class L, class R>
constexpr auto diff(Diff D, L lhs, std::divides<>, R rhs) {
    return (D(lhs) * rhs - lhs * D(rhs)) / (rhs * rhs);
}

// Sum rule: D(a+b) = D(a) + D(b).
template<class Diff, class L, class R>
constexpr auto diff(Diff D, L lhs, std::plus<>, R rhs) {
    return D(lhs) + D(rhs);
}

// Difference rule: D(a-b) = D(a) - D(b).
template<class Diff, class L, class R>
constexpr auto diff(Diff D, L lhs, std::minus<>, R rhs) {
    return D(lhs) - D(rhs);
}

// Power rule for a constant exponent c: D(a^c) = c * a^(c-1) * D(a).
// Note: for c == 0 this evaluates 0 * a^-1 * D(a), which is NaN where a == 0.
template<class Diff, class L, class E>
constexpr auto diff(Diff D, L lhs, pow_t, val_t<E> rhs) {
    return rhs * (lhs ^ val_t<E>(rhs.k - 1)) * D(lhs);
}

template<class>
inline constexpr bool always_false_v = false;

// A non-constant exponent would need a logarithm node, which this library
// does not provide; reject it with a readable diagnostic.
template<class Diff, class L, class R>
constexpr auto diff(Diff, L, pow_t, R) {
    static_assert(always_false_v<R>, "D(a^b) is only supported for a constant exponent b");
    return val_t<>{0};
}

// Partial-derivative operator with respect to the variable V (a var_t);
// T is the value type of the 0/1 constants produced for variable leaves.
template<class V, class T = double>
struct differential_t {
    template<class F>
    constexpr auto operator()(F f) const {
        return diff(*this, f);
    }
};

// D_{var<N>}(var<M>) is 1 when N == M and 0 otherwise.
template<std::size_t N, class T, std::size_t M>
constexpr val_t<T> diff(differential_t<var_t<N>, T>, var_t<M>) {
    return val_t<T>{N == M};
}

}  // namespace expr_types
namespace expr {
constexpr auto x = expr_types::var<0>;
constexpr auto y = expr_types::var<1>;
constexpr auto val(auto x){
using namespace expr_types;
return val_t(x);
}
constexpr auto one = val(1);
constexpr auto zero = val(0);
template<auto V, class T=double>
constexpr expr_types::differential_t<decltype(V),T> D = {};
}
int main()
{
    using namespace expr;

    // f(x, y) = x*1 + x*x + x*x*y + 0, printed together with df/dx at (1, 2).
    const auto f = x*one + x*x + x*x*y + zero;
    const auto df = D<x>(f);
    std::cout << f(1.0, 2.0) << "\n";
    std::cout << df(1.0, 2.0) << "\n";

    // g(x, y) = (1 + x)*(1 + y), printed together with dg/dy at (1, 1).
    const auto g = (one + x)*(one + y);
    std::cout << g(1.0, 1.0) << "\n";
    std::cout << D<y>(g)(1.0, 1.0) << "\n";

    // h(x) = x^2; parenthesized because ^ binds more loosely than + and *.
    const auto h = (x^val(2));
    std::cout << h(2.0, 2.0) << "\n";
}
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|