Monday, December 26, 2016

Understanding Fold Expressions

C++17 has an interesting new feature called fold expressions. Fold expressions offer a compact syntax to apply a binary operation to the elements of a parameter pack. Here’s an example.
// Sums all arguments with a unary left fold: addall(a, b, c) is ((a + b) + c).
// There is no init value, so calling addall() with no arguments does not compile.
template <typename... Ts>
auto addall(Ts... values)
{
  return (... + values);
}
addall(1,2,3,4,5); // returns 15.
This particular example is a unary left fold. It's equivalent to ((((1+2)+3)+4)+5). It reduces (folds) the parameter pack of integers into a single integer by applying the binary operator successively. It's unary because it does not explicitly specify an init (a.k.a. identity) argument. So, let's add one.
// Sums all arguments with a binary left fold seeded with 0:
// addall(a, b, c) is (((0 + a) + b) + c), and addall() simply returns 0.
template <typename... Ts>
auto addall(Ts... values)
{
  return (0 + ... + values);
}
addall(1,2,3,4,5); // returns 15.
This version of addall is a binary left fold. The init argument is 0 and it's redundant (in this case). That's because this fold expression is equivalent to (((((0+1)+2)+3)+4)+5). Explicit identity elements will come in handy a little later---when we have empty parameter packs or if we use user-defined types in fold expressions.

Fold expressions can be defined over a number of operators. 32 to be precise. They are + - * / % ^ & | = < > << >> += -= *= /= %= ^= &= |= <<= >>= == != <= >= && || , .* ->*.

In this post you will see an example of each and see how each one behaves. So here's the whole enchilada.
#include <iostream>
#include <iomanip>

// UNARY_LEFT_FOLD(FN, BINOP) defines a variadic function template FN that folds
// its arguments with the binary operator BINOP as a unary left fold:
//   FN(a)       -> a itself (a single element is not combined with anything)
//   FN(a, b, c) -> ((a BINOP b) BINOP c)
// A call with no arguments compiles only for the && and || instances (land, lor).
// Arguments are taken by forwarding reference, so an lvalue first argument can be
// the target of the assignment operators. The result is returned by value (auto).
// The comma operator cannot be passed through this macro (it would be read as a
// macro-argument separator), so comma() is written out by hand instead.
#define UNARY_LEFT_FOLD(FN, BINOP)   \
template<typename... Ts>             \
auto FN(Ts&&... operands) {          \
  return (... BINOP operands);       \
}

// Arithmetic
UNARY_LEFT_FOLD(add, +)
UNARY_LEFT_FOLD(sub, -)
UNARY_LEFT_FOLD(mul, *)
UNARY_LEFT_FOLD(divide, /)
UNARY_LEFT_FOLD(mod, %)

// Bitwise
UNARY_LEFT_FOLD(bxor, ^)
UNARY_LEFT_FOLD(band, &)
UNARY_LEFT_FOLD(bor, |)

// Plain assignment
UNARY_LEFT_FOLD(assign, =)

// Comparison and shift. The Clang version this was written against rejected
// folds over > and >>, so gt and rshift are only generated for other compilers.
UNARY_LEFT_FOLD(lt, <)
#ifndef __clang__
UNARY_LEFT_FOLD(gt, >)
UNARY_LEFT_FOLD(rshift, >>)
#endif
UNARY_LEFT_FOLD(lshift, <<)

// Compound assignment
UNARY_LEFT_FOLD(addassign, +=)
UNARY_LEFT_FOLD(subassign, -=)
UNARY_LEFT_FOLD(mulassign, *=)
UNARY_LEFT_FOLD(divassign, /=)
UNARY_LEFT_FOLD(modassign, %=)
UNARY_LEFT_FOLD(bxorassign, ^=)
UNARY_LEFT_FOLD(bandassign, &=)
UNARY_LEFT_FOLD(borassign, |=)
UNARY_LEFT_FOLD(lshiftassign, <<=)
UNARY_LEFT_FOLD(rshiftassign, >>=)

// Equality and relational
UNARY_LEFT_FOLD(equals, ==)
UNARY_LEFT_FOLD(nequals, !=)
UNARY_LEFT_FOLD(lte, <=)
UNARY_LEFT_FOLD(gte, >=)

// Logical
UNARY_LEFT_FOLD(land, &&)
UNARY_LEFT_FOLD(lor, ||)

// Pointer to member
UNARY_LEFT_FOLD(objptrmem, .*)
UNARY_LEFT_FOLD(ptrptrmem, ->*)

// Unary left fold over the comma operator: comma(a, b, c) is ((a , b) , c),
// which yields the last argument (returned by value). This one is written by
// hand because a bare ',' cannot be passed as a macro argument to UNARY_LEFT_FOLD.
template<typename... Ts>
auto comma(Ts&&... operands) {
  return (... , operands);
}

// Small aggregates used to demonstrate folds over pointers to members:
// a Person holds a Phone, which holds an extension number.
struct Phone {
  int ext;
};

struct Person {
  Phone phone;
};

// Exercises the fold functions defined above. Most lines print the result for a
// single-element pack followed by the result for a multi-element pack; land()
// and lor() are called with empty packs instead.
// NOTE(review): bor is generated above but never exercised here.
int main(void) 
{
  // Print bool results as "true"/"false" rather than 1/0.
  std::cout << std::boolalpha;
  std::cout << "add "            << add(1)           << " " << add(1,2,3)        << "\n";
  std::cout << "sub "            << sub(1)           << " " << sub(1,2,3)        << "\n";
  std::cout << "mul "            << mul(1)           << " " << mul(1,2,3)        << "\n";
  std::cout << "divide "         << divide(1)        << " " << divide(18,2,3)    << "\n";
  std::cout << "mod "            << mod(1)           << " " << mod(23, 3,2)      << "\n";
  std::cout << "bxor "           << bxor(1)          << " " << bxor(1,2,4)       << "\n";
  std::cout << "band "           << band(1)          << " " << band(1,3,7)       << "\n";
  // assign(1,2,4) assigns to the function's own parameters. That compiles because
  // a named rvalue-reference parameter is an lvalue expression.
  std::cout << "assign "         << assign(1)        << " " << assign(1,2,4)     << "\n";
    
  // With an lvalue first argument, assign(a,2,4) is (a=2)=4 and modifies a.
  // The three insertions are separate statements, so assign(a) is evaluated
  // before a is modified, and a is read only after assign(a,2,4) has run.
  // (Before C++17 the operands of a single << chain had unspecified evaluation order.)
  auto a = 99; 
  std::cout << "assign-a "       << assign(a);
  std::cout << " "               << assign(a,2,4);
  std::cout << " "               << a << "\n";
    
  // gt and rshift exist only when not compiling with Clang; see the matching
  // #ifndef around their UNARY_LEFT_FOLD invocations.
  #ifndef __clang__ 
  std::cout << "gt "             << gt(1)          << " " << gt(3,2,0)         << "\n"; 
  std::cout << "rshift "         << rshift(1)        << " " << rshift(32,2,2)    << "\n"; 
  #endif

  std::cout << "lt "             << lt(1)            << " " << lt(1,2,-1)         << "\n"; 
  std::cout << "lshift "         << lshift(1)        << " " << lshift(1,2,3)     << "\n";
  std::cout << "addassign "      << addassign(1)     << " " << addassign(2,3,2)  << "\n";
  std::cout << "subassign "      << subassign(1)     << " " << subassign(7,2)    << "\n";
  std::cout << "mulassign "      << mulassign(1)     << " " << mulassign(2,3,2)  << "\n";
  std::cout << "divassign "      << divassign(1)     << " " << divassign(7,2)    << "\n";
  std::cout << "modassign "      << modassign(1)     << " " << modassign(23,3,2) << "\n";
  std::cout << "bxorassign "     << bxorassign(1)    << " " << bxorassign(7,2)   << "\n";
  std::cout << "bandassign "     << bandassign(1)    << " " << bandassign(7,6)   << "\n";
  std::cout << "borassign "      << borassign(1)     << " " << borassign(1,2,4,8) << "\n";
  std::cout << "lshiftassign "   << lshiftassign(1)  << " " << lshiftassign(8,2)  << "\n";
  std::cout << "rshiftassign "   << rshiftassign(1)  << " " << rshiftassign(16,1,2)   << "\n";
  std::cout << "equals "         << equals(1)        << " " << equals(8,3,2)     << "\n";
  std::cout << "nequals "        << nequals(1)       << " " << nequals(7,2,0)    << "\n";
  std::cout << "lte "            << lte(1)           << " " << lte(7,2,0)        << "\n";
  std::cout << "gte "            << gte(1)           << " " << gte(7,3,1)        << "\n";
  // Empty packs: a unary && fold yields true and a unary || fold yields false.
  std::cout << "land "           << land()           << " " << land(7,2)         << "\n";
  std::cout << "lor "            << lor()            << " " << lor(7,2)          << "\n";
  std::cout << "comma "          << comma(1)         << " " << comma(8,3,2)      << "\n";
  
  // Pointer-to-member folds. objptrmem(p,phoneptr,extptr) is
  // (p.*phoneptr).*extptr, and ptrptrmem(&p,phoneptr) is (&p)->*phoneptr, which
  // returns the Phone member by value (the generated functions return auto).
  // Each result is printed next to the equivalent hand-written expression.
  auto phoneptr = &Person::phone;
  auto extptr = &Phone::ext;
  Person p { { 999 } };
  Person * pptr = &p;
  std::cout << "objptrmem "                   << objptrmem(p,phoneptr,extptr)       << "\n";
  std::cout << "p.*phoneptr.*extptr "         << p.*phoneptr.*extptr                << "\n";
  std::cout << "ptrptrmem(&p,phoneptr).ext "  << ptrptrmem(&p,phoneptr).ext         << "\n";  
  std::cout << "&(pptr->*phoneptr)->*extptr " << (&(pptr->*phoneptr))->*extptr      << "\n";

}
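The output looks something like the following. It was built with GCC; under Clang the gt and rshift lines are missing because of the #ifndef __clang__ guards.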
add 1 6
sub 1 -4
mul 1 6
divide 1 3
mod 1 0
bxor 1 7
band 1 1
assign 1 4
assign-a 99 4 4
gt 1 true
rshift 1 2
lt 1 false
lshift 1 32
addassign 1 7
subassign 1 5
mulassign 1 12
divassign 1 3
modassign 1 0
bxorassign 1 5
bandassign 1 6
borassign 1 15
lshiftassign 1 32
rshiftassign 1 2
equals 1 false
nequals 1 true
lte 1 true
gte 1 true
land true true
lor false true
comma 1 2
objptrmem 999
p.*phoneptr.*extptr 999
ptrptrmem(&p,phoneptr).ext 999
&(pptr->*phoneptr)->*extptr 999
There are a number of observations.
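  1. Clang (at least the version used for this post) rejects folds over the > and >> operators, even though both are on the list of 32 fold operators above. GCC is fine. That's why gt and rshift are wrapped in #ifndef __clang__.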
  2. Unary fold expressions do not accept empty parameter packs, except for the &&, ||, and comma operators. In fact, the P0036 document describes what happens when empty parameter packs are used with these operators and why it's illegal for other operators. In short, empty parameter packs result in true, false, and void(), respectively. In that sense, binary folds appear significantly superior, because you can specify the identity element for fundamental and user-defined types and for all the operators.
  3. A single-element parameter pack results in the element itself, with the element's own type. This may be OK for some types and operators, but it's very confusing for operators such as > < == != <= >= && ||. These operators generally return a boolean result, but not when the parameter pack has only one element. The type of the expression changes when the size of the parameter pack is greater than 1. For example, lte(1) returns an int, but lte(1,3) returns a bool. That's bizarre.
  4. Multiple-element parameter packs work as expected, with a twist. Consider the gt example in main. gt(3,2,0) expands to (3>2)>0, which is true>0, which is true. Similarly, lt(1,2,-1) is (1<2)<-1, which is true<-1 (i.e., 1<-1), which is false. However, for such operators (that return a boolean result), the compiler spits out copious amounts of warnings saying that "comparisons like 'X<=Y<=Z' do not have their mathematical meaning". That makes sense.
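  5. The assign function is curious too. Assigning to a variable makes sense. For example, assign(a,2,4) expands to (a=2)=4, which assigns 2 to a and later 4 to a. So there are two assignments. The fold expression itself is an lvalue of type int (it would bind to an int&), but because assign returns auto, the function returns a copy of type int. The funny thing is that if you replace a with an rvalue, it still works. That's because inside assign each parameter is a named rvalue reference, and a named reference is an lvalue expression. The assignment therefore simply modifies the temporary that the parameter is bound to.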
  6. Operator associativity has no consequence. For example, <<= and >>= are right-associative operators, but left folds still fold from left to right. I.e., nominally, a <<= b <<= c is equivalent to a <<= (b <<= c), but with a unary left fold you get (a <<= b) <<= c. If you want the former, use a unary right fold.
  7. Finally, consider the fold expressions containing pointers to members (at the end of main). A single, non-null pointer to member just converts to true in a boolean context (like any other pointer). The weird thing, though, is that there's no way to make sense of two or more pointers to members on their own. I can't think of a way in which they fold (a.k.a. compose) and return something meaningful. An object (of the same class as that of the member pointer) is required as the leftmost element in the parameter pack to dereference a list of member pointers. For example, objptrmem(p,phoneptr,extptr) is the same as p.*phoneptr.*extptr. Without p, phoneptr and extptr alone make no sense together.


Binary Folds

This example uses a user-defined Int type in a left binary fold. We'll also specify our own identity for our Int-based binary folds.
#include <iostream>
#include <iomanip>

// Minimal wrapper around an int, used to show folds over a user-defined type.
// Both conversions are explicit: an Int is never silently created from, or
// turned back into, a plain int. Default-constructed Ints hold 0.
struct Int {
  int value;

  explicit Int(int v = 0) : value{v} {}

  explicit operator int() const { return value; }
};

std::ostream& operator << (std::ostream& o, const Int& i) {
   o << i.value;
   return o;
}

// Adds two Ints. Before returning, it logs "Adding <lhs> <rhs>" to std::cout so
// the order in which a fold combines its elements shows up in the output.
Int operator+(Int const& lhs, Int const& rhs)
{
  std::cout << "Adding " << lhs << " " << rhs << "\n";
  const int sum = lhs.value + rhs.value;
  return Int(sum);
}

// Multiplies two Ints. Before returning, it logs "Multiplying <lhs> <rhs>" to
// std::cout so the order in which a fold combines its elements shows up in the output.
Int operator*(Int const& lhs, Int const& rhs)
{
  std::cout << "Multiplying " << lhs << " " << rhs << "\n";
  const int product = lhs.value * rhs.value;
  return Int(product);
}

// Sums any number of Ints with a binary left fold seeded with Int{0}, the
// identity of Int addition. addInts(a, b) evaluates ((Int{0} + a) + b), and
// addInts() returns Int{0} without calling operator+ at all.
template<typename... Ts>
auto addInts(Ts&&... ints)
{
  return (Int{0} + ... + ints);
}

// Multiplies any number of Ints with a binary left fold seeded with Int{1}, the
// identity of Int multiplication. mulInts(a, b) evaluates ((Int{1} * a) * b),
// and mulInts() returns Int{1} without calling operator* at all.
template<typename... Ts>
auto mulInts(Ts&&... ints)
{
  return (Int{1} * ... * ints);
}

int main(void)
{
  std::cout << addInts(Int{1}, Int{2}, Int{3}) << "\n"; // prints 6
  std::cout << addInts() << "\n"; // prints 0
  std::cout << mulInts(Int{1}, Int{2}, Int{3}) << "\n"; // prints 6
  std::cout << mulInts() << "\n"; // prints 1
}
Things are very much as expected in this example. For user-defined types, the operator you wish to use in a fold expression must be overloaded. Int overloads binary + and *. addInts uses Int{0} as the identity element, whereas mulInts uses Int{1}. The identity element is special because combining a value with it leaves the value unchanged. In Int addition, adding the identity element makes no difference; similarly, in Int multiplication, multiplying by the identity element makes no difference. Because the identity is also the init value of each binary fold, addInts() and mulInts() with empty packs simply return Int{0} and Int{1}.

I'll wrap with a quick theory about monoids.

Formally, (Int,+) is a monoid with Int{0} as identity, and (Int,*) is also a (different) monoid with Int{1} as identity. Any two elements of a monoid can be combined with its operation to produce a third element of the same monoid. Because that operation is associative, any number of elements can be combined in any grouping and the result is the same. Left and right folds are just two of the possible groupings.

In the following posts, we'll create more interesting monoids and see how well fold expressions can exploit their properties.

1 comment:

korax said...

Do you know, why only operators are supported in folds? const_expr functions would be nice and in many cases more expressive (name).