-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathunroll.py
135 lines (109 loc) · 3.05 KB
/
unroll.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
def make_macro(name, body):
    """Wrap *body* in a C preprocessor ``#define`` directive.

    The macro identifier is *name* converted to upper case.  Every newline
    inside *body* gets a backslash placed in front of it, and the returned
    text ends with a single newline.
    """
    # Splitting on "\n" and re-joining with backslash-newline is the same
    # transformation as replacing each newline in place.
    continued_body = "\\\n".join(body.split("\n"))
    return "#define {} {}\n".format(name.upper(), continued_body)
def update_paren():
    """Return the C statement that replaces ``paren`` with its successor.

    The emitted statement assigns ``next_paren_bitmask(paren)`` back to
    ``paren``.  The snippet is wrapped in newlines so it can be appended
    directly to other generated code.
    """
    statement = "paren = next_paren_bitmask(paren);"
    return "\n" + statement + "\n"
def set_curr(lsize, voff, prev_start):
    """Generate the C code that ORs ``paren`` into ``curr`` for one step.

    Args:
        lsize: length added to an offset to find where a value ends.
        voff: shift amount at which the next ``paren`` would be placed.
        prev_start: offset at which the previous value was placed.

    Returns:
        A ``(code, new_voff)`` tuple.  ``code`` is the generated C text
        (possibly empty) and ``new_voff`` is the offset to use for the
        following step, i.e. the updated offset minus 32.

    Raises:
        AssertionError: if ``voff`` is 32 or more and exceeds ``lsize``.
    """
    fragments = []

    # The previous value started below 32 but ended past 64, so part of it
    # has to be OR-ed in again (presumably the bits that did not fit).
    prev_end = prev_start + lsize
    if prev_start < 32 and prev_end > 64:
        carry_template = (
            "\n"
            "// previous started at {} and ended at {}\n"
            "curr |= (paren >> {}) << {};\n"
        )
        fragments.append(
            carry_template.format(prev_start, prev_end, 32 - prev_start, 0)
        )

    if voff >= 32:
        # Nothing new is inserted on this step; only the offset moves.
        assert lsize >= voff
        advanced = voff
    else:
        insert_template = (
            "\n"
            "// goes up to {}\n"
            "curr |= paren << {};\n"
        )
        # Advance paren first, then place the new value at the offset.
        fragments.append(update_paren())
        fragments.append(insert_template.format(voff + lsize, voff))
        advanced = voff + lsize

    # Every step accounts for 32 positions, whichever branch ran.
    return ("".join(fragments), advanced - 32)
def broadcast_and_write(lsize, voff, unroll_idx):
    """Generate the C code that expands 32 bits of ``curr`` into 32 bytes.

    The emitted code broadcasts ``curr``, isolates one bit per byte with
    ``shufmask`` and ``andmask``, subtracts the resulting mask from a
    constant byte vector, stores the vector at ``cursor``, advances
    ``cursor`` by 32 and shifts ``curr`` right by 32.

    Args:
        lsize: not used by this function; kept in the signature.
        voff: selects which argument of the constant vector is ``'\\n'``
            instead of ``0x28``: the one for index ``voff - 1``, with
            arguments listed from index 31 down to 0.  No argument is
            replaced when ``voff - 1`` is outside 0..31.
        unroll_idx: step number, only written into the leading comment.

    Returns:
        The generated C snippet as a string.
    """
    newline_index = voff - 1
    byte_args = ", ".join(
        "'\\n'" if index == newline_index else "0x28"
        for index in range(31, -1, -1)
    )
    constant_vector = "_mm256_set_epi8(" + byte_args + ")"
    header_comment = "// unroll idx of {} with voff {}".format(unroll_idx, voff)

    template = (
        "\n"
        "{}\n"
        "// only need the low 32 bits of each lane set, but this is fine\n"
        "resv = _mm256_set1_epi32(curr);\n"
        "// move the byte of paren that has the bit in the corresponding\n"
        "// position in the vector to that position.\n"
        "resv = _mm256_shuffle_epi8(resv, shufmask);\n"
        "// only let the correct bit be set\n"
        "resv = _mm256_and_si256(resv, andmask);\n"
        "// set all nonzero bytes to -1\n"
        "// reuse andmask because it's a superset of resv\n"
        "resv = _mm256_cmpeq_epi8(resv, andmask);\n"
        "// combine with bytecode\n"
        "resv = _mm256_sub_epi8({}, resv);\n"
        "_mm256_store_si256((__m256i *) cursor, resv);\n"
        "cursor += 32;\n"
        "curr >>= 32;\n"
    )
    return template.format(header_comment, constant_vector)
def loop_tail(unroll):
    """Generate the C code that ends one unrolled step.

    The emitted code checks ``i >= PIPECNT || paren == FIN``.  When that
    holds it calls ``flush_buf``, resets ``cursor`` and ``i``, and returns
    from the C function if ``paren == FIN``.  After the check it
    increments ``i`` by *unroll*.

    Args:
        unroll: amount added to ``i`` in the final statement.

    Returns:
        The generated C snippet; it does not end with a newline.
    """
    flush_lines = [
        "",
        "if (i >= PIPECNT || paren == FIN) {",
        "currbuf = flush_buf(cursor - currbuf, currbuf);",
        "cursor = currbuf;",
        "i = 0;",
        "if (paren == FIN) {",
        "return;",
        "}",
        "}",
        "",
    ]
    # The empty first and last entries give the leading and trailing newline.
    flush_block = "\n".join(flush_lines)
    return flush_block + "i += {};".format(unroll)
def loop_pre():
    """Generate the C prologue placed before the unrolled loop.

    The emitted code starts with an ``ERROR("things to do");`` statement,
    declares ``i``, ``resv``, ``curr``, ``cursor`` and ``currbuf``, defines
    the ``shufmask`` and ``andmask`` constants, and initialises ``cursor``
    and ``currbuf`` from ``buf``, ``i`` to 0 and ``curr`` from ``paren``.
    ``buf`` is not declared by this snippet; presumably the surrounding C
    code provides it.
    """
    prologue_lines = [
        "",
        'ERROR("things to do");',
        "int i;",
        "__m256i resv;",
        "uint64_t curr;",
        "char *cursor, *currbuf;",
        "const __m256i shufmask = _mm256_set_epi64x(",
        "0x0303030303030303,",
        "0x0202020202020202,",
        "0x0101010101010101,",
        "0x0000000000000000);",
        "const __m256i andmask = _mm256_set1_epi64x(0x8040201008040201);",
        "cursor = buf;",
        "currbuf = buf;",
        "i = 0;",
        "curr = paren;",
        "",
    ]
    # The empty first and last entries give the leading and trailing newline.
    return "\n".join(prologue_lines)
def function_body(lsize):
    """Generate the body of the C function: prologue plus unrolled loop.

    The output is the prologue from ``loop_pre()`` followed by a
    ``do { ... } while (paren != FIN);`` loop holding *lsize* unrolled
    steps.  Each step is the code from ``set_curr``,
    ``broadcast_and_write`` and ``loop_tail(1)``, in that order.  After the
    last step the loop advances ``paren`` and reloads ``curr`` from it.

    Args:
        lsize: number of unrolled steps; also passed on to ``set_curr``
            and ``broadcast_and_write``, and used as the first offset.

    Returns:
        The generated C text as a string.
    """
    chunks = [loop_pre(), "do {"]

    # The offset starts at lsize; the "previous" offset starts at 0.
    prev_voff, voff = 0, lsize
    for step in range(lsize):
        merge_code, following_voff = set_curr(lsize, voff, prev_voff)
        chunks.append(merge_code)
        chunks.append(broadcast_and_write(lsize, voff, step))
        chunks.append(loop_tail(1))
        prev_voff, voff = voff, following_voff

    chunks.append(update_paren())
    chunks.append("\ncurr = paren;\n")
    chunks.append("} while (paren != FIN);\n")
    return "".join(chunks)
def function(lsize):
    """Generate the full C definition of ``do_batch_unrolled``.

    Wraps the text returned by ``function_body(lsize)`` in the
    ``static void do_batch_unrolled(uint64_t paren)`` signature and a
    closing brace.

    Args:
        lsize: forwarded unchanged to ``function_body``.

    Returns:
        The generated C function as a string, ending with a blank line.
    """
    signature = "\nstatic void\ndo_batch_unrolled(uint64_t paren) {\n"
    closing = "\n}\n\n"
    return signature + "{}".format(function_body(lsize)) + closing
def main():
    """Write the generated C function to ``unrolled.generated.h``.

    The file is opened for writing in the current working directory,
    replacing any existing content, and receives ``function(41)``.
    """
    with open("unrolled.generated.h", mode="w") as header:
        # Generate after opening, so the file is truncated even if
        # generation raises.
        generated = function(41)
        header.write(generated)
# Module-level entry point: there is no ``__main__`` guard, so the header is
# regenerated whenever this file is executed or imported.
main()