-
Notifications
You must be signed in to change notification settings - Fork 73
/
Copy pathfold.cxx
115 lines (92 loc) · 3.96 KB
/
fold.cxx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#include <vector>
#include <algorithm>
#include <cstdio>
#include <cmath>
// Returns x! as an int, computed with a unary left fold of operator* over the
// pack @range(1:x+1), i.e. the integers 1 through x inclusive (the end bound
// of @range is exclusive, hence the x+1).
// NOTE(review): the result type is int, so x > 12 overflows (13! exceeds
// INT_MAX), which is undefined behavior.
// NOTE(review): for x == 0 the pack is empty. In standard C++ a unary fold
// over an empty pack is only defined for &&, || and the comma operator, so
// confirm what Circle yields here before relying on fact(0) == 1.
inline int fact(int x) {
// Use a fold expression to compute factorials. This evaluates the product
// of integers from 1 to x, inclusive.
return (... * @range(1:x+1));
}
// Demonstrates fold expressions and list comprehensions over pack
// expressions built from vector slices (v[:]) and @range.
// Each result is printed to stdout with printf. Takes no arguments and
// returns nothing.
inline void func() {
std::vector<int> v { 4, 2, 2, 2, 5, 1, 1, 9, 8, 7, 1, 7, 4, 1 };
// (... || pack) is a short-circuit fold on operator||.
bool has_five = (... || (5 == v[:]));
printf("has_five = %s\n", has_five ? "true" : "false");
bool has_three = (... || (3 == v[:]));
printf("has_three = %s\n", has_three ? "true" : "false");
// Count the 1s: each (1 == v[:]) comparison is converted to int (0 or 1)
// and the results are summed with a fold on operator+.
int num_ones = (... + (int)(1 == v[:]));
printf("has %d ones\n", num_ones);
// Find the max element by left-folding the qualified name std::max over
// the elements.
int max_element = (... std::max v[:]);
printf("max element = %d\n", max_element);
// Find the min element by folding an unqualified min. The using-declaration
// below brings std::min into scope, the same idiom as `using std::swap;`.
using std::min;
int min_element = (... min v[:]);
printf("min element = %d\n", min_element);
// Find the biggest absolute difference between consecutive elements
// (v[i] - v[i+1]).
// NOTE(review): v[:] has 14 elements and v[1:] has 13; this relies on the
// expansion stopping at the shorter slice. Confirm against Circle's slice
// semantics.
int max_diff = (... std::max (abs(v[:] - v[1:])));
printf("max difference = %d\n", max_diff);
// Compute the Taylor series coefficients for sine. i is the current index,
// so pow(-1, i/2) (integer division) alternates between +1 and -1 over the
// odd indices i = 1, 3, 5, 7, 9. Each coefficient is divided by i!.
// The if clause in the for-expression filters out the even indices, whose
// terms are zero for sine, and keeps the odd powers. This compacts the
// vector to 5 elements out of 10 terms.
int terms = 10;
std::vector series = [for i : terms if 1 & i => pow(-1, i/2) / fact(i)...];
printf("series:\n");
printf(" %f\n", series[:])...;
// Compute x raised to each odd power. @range(1:terms:2) generates the odd
// integers 1, 3, ..., 9 (end exclusive, step 2). These are the same indices
// the if clause kept above, so powers lines up element-for-element with
// series.
double x = .3;
std::vector powers = [pow(x, @range(1:terms:2))...];
printf("powers:\n");
printf(" %f\n", powers[:])...;
// Evaluate the series to approximate sine. This is a simple dot
// product between the coefficient and the powers vectors.
double sinx = (... + (series[:] * powers[:]));
printf("sin(%f) == %f\n", x, sinx);
}
// Circle @meta statement: invokes func() as part of compilation.
// NOTE(review): per Circle's @meta semantics (confirm against the Circle
// docs), the printf output above is produced while the file is being
// compiled, in addition to the runtime output from main.
@meta func();
int main() {
std::vector<int> v { 4, 2, 2, 2, 5, 1, 1, 9, 8, 7, 1, 7, 4, 1 };
// (... || pack) is a short-circuit fold on operator||.
bool has_five = (... || (5 == v[:]));
printf("has_five = %s\n", has_five ? "true" : "false");
bool has_three = (... || (3 == v[:]));
printf("has_three = %s\n", has_three ? "true" : "false");
// Reduce the number of 1s.
int num_ones = (... + (int)(1 == v[:]));
printf("has %d ones\n", num_ones);
// Find the max element using qualified lookup for std::max.
int max_element = (... std::max v[:]);
printf("max element = %d\n", max_element);
// Find the min element using the ADL trick. This uses unqualified lookup
// for min.
using std::min;
int min_element = (... min v[:]);
printf("min element = %d\n", min_element);
// Find the biggest difference between consecutive elements.
int max_diff = (... std::max (abs(v[:] - v[1:])));
printf("max difference = %d\n", max_diff);
// Compute the Taylor series for sign. s is the current index, so
// pow(-1, s) alternates between +1 and -1.
// The if clause in the for-expression filters out the even elements,
// where are zero for sine, and leaves the odd powers. This compacts the
// vector to 5 elements out of 10 terms.
int terms = 10;
std::vector series = [for i : terms if 1 & i => pow(-1, i/2) / fact(i)...];
printf("series:\n");
printf(" %f\n", series[:])...;
// Compute x raised to each odd power. Use @range to generate all odd
// integers from 1 to terms, and raise x by that.
double x = .3;
std::vector powers = [pow(x, @range(1:terms:2))...];
printf("powers:\n");
printf(" %f\n", powers[:])...;
// Evaluate the series to approximate sine. This is a simple dot
// product between the coefficient and the powers vectors.
double sinx = (... + (series[:] * powers[:]));
printf("sin(%f) == %f\n", x, sinx);
}