Skip to content

Commit 6e57d6c

Browse files
authored
Add a notebook with a visualization of the aprrox_* functions and their errors (#7974)
* Add a notebook with a visualization of the aprrox_* functions and their errors * Fix spelling error
1 parent 9f6ec17 commit 6e57d6c

File tree

2 files changed

+384
-0
lines changed

2 files changed

+384
-0
lines changed

apps/hannk/halide/common_halide.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ Halide::Expr align(const Halide::Expr &x, const Halide::Expr &n);
3939
// where N is the number of bits of the narrowed result minus one.
4040
Halide::Expr multiply_2x_high(const Halide::Expr &a, const Halide::Expr &b);
4141

42+
// For a visualization of the approx_* functions and their errors, see:
43+
// apps/hannk/halide/docs/approx_log2_and_applications.ipynb
4244
// Approximate log2(x/2^q_x)*2^q.
4345
// q must be less than 16.
4446
Halide::Expr approx_log2(int q, const Halide::Expr &x, int q_x, const Halide::Type &type = Halide::Int(32));
Lines changed: 382 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,382 @@
1+
{
2+
"nbformat": 4,
3+
"nbformat_minor": 0,
4+
"metadata": {
5+
"colab": {
6+
"provenance": []
7+
},
8+
"kernelspec": {
9+
"name": "python3",
10+
"display_name": "Python 3"
11+
}
12+
},
13+
"cells": [
14+
{
15+
"cell_type": "code",
16+
"metadata": {
17+
"id": "r1XiiUQGUjpx"
18+
},
19+
"source": [
20+
"import numpy as np\n",
21+
"import matplotlib as mpl\n",
22+
"import matplotlib.pyplot as plt\n",
23+
"\n",
24+
"# Many architectures have shifts where the right-hand-side is signed. A negative\n",
25+
"# RHS is the same as a positive shift in the other direction.\n",
26+
"def shift_right(x, y):\n",
27+
" return np.floor(x / 2**y)\n",
28+
"def shift_left(x, y):\n",
29+
" return np.floor(x * 2**y)\n",
30+
"def rounding_shift_right(x, y):\n",
31+
" return np.round(x / 2**y)\n",
32+
"def rounding_shift_left(x, y):\n",
33+
" return np.round(x * 2**y)\n",
34+
"\n",
35+
"def bitwise_and(x, y):\n",
36+
" return np.mod(x, y + 1)\n",
37+
"\n",
38+
"# This is sqrdmulh on ARM\n",
39+
"def multiply_2x_high(x, y):\n",
40+
" return rounding_shift_right(x * y, 15)\n",
41+
"\n",
42+
"def relative_error(x, y):\n",
43+
" return (x - y) / (np.maximum(x, y) + 1e-3)\n",
44+
"\n",
45+
"def plot_results(x, exact, approxs, title, logx = False, logy = False, relative = False, log2_xscale = 0, log2_yscale = 0):\n",
46+
" fig, [p1, p2] = plt.subplots(2, 1)\n",
47+
"\n",
48+
" p1.set_xlabel('x')\n",
49+
" if logx:\n",
50+
" p1.set_xscale('log')\n",
51+
" p1.set_ylabel(title)\n",
52+
" if logy:\n",
53+
" p1.set_yscale('log')\n",
54+
"\n",
55+
" xscale = 2**log2_xscale\n",
56+
" yscale = 2**log2_yscale\n",
57+
"\n",
58+
" exact = np.round(exact*yscale)/yscale\n",
59+
"\n",
60+
" p1.plot(x/xscale, exact)\n",
61+
" for approx in approxs:\n",
62+
" p1.plot(x/xscale, approx/yscale)\n",
63+
"\n",
64+
" p2.set_xlabel('x')\n",
65+
" if logx:\n",
66+
" p2.set_xscale('log')\n",
67+
"\n",
68+
" p2.set_ylabel('relative error' if relative else 'error')\n",
69+
" for approx in approxs:\n",
70+
" p2.plot(x/xscale, relative_error(approx/yscale, exact) if relative else approx/yscale - exact)\n",
71+
"\n",
72+
"def eval_poly(x, p, q):\n",
73+
" x1 = rounding_shift_left(x, 15 - q)\n",
74+
" y = p[0]\n",
75+
" xi = x1\n",
76+
" for i in p[1:]:\n",
77+
" y = y + multiply_2x_high(i, xi)\n",
78+
" xi = multiply_2x_high(xi, x1)\n",
79+
" return rounding_shift_right(y, 15 - q)\n",
80+
"\n",
81+
"points = 6\n",
82+
"degree = 3\n",
83+
"log2_poly_x = np.arange(points, 2 * points + 1) / points\n",
84+
"log2_poly_y = np.log2(log2_poly_x)\n",
85+
"log2_poly = np.polyfit(log2_poly_x - 1, log2_poly_y, degree)\n",
86+
"\n",
87+
"exp2_poly_x = np.arange(points, 2 * points + 1) / points\n",
88+
"exp2_poly_y = np.exp2(exp2_poly_x - 1) - 1\n",
89+
"exp2_poly = np.polyfit(exp2_poly_x - 1, exp2_poly_y, degree)\n",
90+
"\n",
91+
"log2_poly = log2_poly[::-1]\n",
92+
"exp2_poly = exp2_poly[::-1]\n",
93+
"\n",
94+
"print(log2_poly)\n",
95+
"print(exp2_poly)\n",
96+
"\n",
97+
"log2_poly = np.round(log2_poly * 2**15)\n",
98+
"exp2_poly = np.round(exp2_poly * 2**15)\n",
99+
"exp2_poly[0] = 0\n",
100+
"\n",
101+
"print(log2_poly)\n",
102+
"print(exp2_poly)"
103+
],
104+
"execution_count": null,
105+
"outputs": []
106+
},
107+
{
108+
"cell_type": "code",
109+
"metadata": {
110+
"id": "1xjo4hIEo_z5"
111+
},
112+
"source": [
113+
"# Approximate N*log2(x*2^q_x), where N = 2^q, and the intermediate computations are\n",
114+
"# restricted to be integers.\n",
115+
"def approx_log2(x, q, q_x = 0):\n",
116+
" # This can be computed with count_leading_zeros\n",
117+
" floor_log2_x = np.select([x > 0], [np.floor(np.log2(x))], [-1])\n",
118+
"\n",
119+
" # We've computed log2(x*2^q_x) = log2(x) + q_x. Subtract that offset now\n",
120+
" # before multiplying by the result quantization.\n",
121+
" result = shift_left(floor_log2_x - q_x, q)\n",
122+
"\n",
123+
" frac = bitwise_and(shift_right(x, floor_log2_x - q), 2**q - 1)\n",
124+
"\n",
125+
" return result + eval_poly(frac, log2_poly, q)\n",
126+
"\n",
127+
"x = np.arange(1, 10000)\n",
128+
"q = 15\n",
129+
"q_x = 2\n",
130+
"log2_x = np.log2(x / 2**q_x)\n",
131+
"approx_log2_x = approx_log2(x, q, q_x)\n",
132+
"\n",
133+
"plot_results(x, log2_x, [approx_log2_x], 'log2(x)', logx=True, log2_xscale=q_x, log2_yscale=q)"
134+
],
135+
"execution_count": null,
136+
"outputs": []
137+
},
138+
{
139+
"cell_type": "code",
140+
"metadata": {
141+
"id": "6uJN5muLsLdE"
142+
},
143+
"source": [
144+
"\n",
145+
"# Approximate 2^(x/2^q_x)*2^q\n",
146+
"def approx_exp2(x, q_x, q):\n",
147+
" int_part = shift_right(x, q_x)\n",
148+
" frac_part = x - shift_left(int_part, q_x)\n",
149+
"\n",
150+
" frac_part = eval_poly(frac_part, exp2_poly, q_x)\n",
151+
"\n",
152+
" exp_int_part = shift_left(1, int_part + q)\n",
153+
" return exp_int_part + rounding_shift_right(exp_int_part * frac_part, q_x)\n",
154+
"\n",
155+
"q_x = 10\n",
156+
"q = 15\n",
157+
"x = np.arange(-4000, 2000)\n",
158+
"approx_exp2_x = approx_exp2(x, q_x, q)\n",
159+
"exact = np.exp2(x / 2**q_x)\n",
160+
"\n",
161+
"plot_results(x, exact, [approx_exp2_x], '2^x', False, True, relative=True, log2_xscale=q_x, log2_yscale=q)\n"
162+
],
163+
"execution_count": null,
164+
"outputs": []
165+
},
166+
{
167+
"cell_type": "code",
168+
"metadata": {
169+
"id": "5BP-edzCmNBi"
170+
},
171+
"source": [
172+
"q = 15\n",
173+
"x = np.arange(10, 10000) * 10\n",
174+
"round_trip_x = approx_exp2(approx_log2(x, q), q, 0)\n",
175+
"\n",
176+
"plot_results(x, x, [round_trip_x], '2^log2(x)', logx=True, logy=True, relative=True)"
177+
],
178+
"execution_count": null,
179+
"outputs": []
180+
},
181+
{
182+
"cell_type": "code",
183+
"metadata": {
184+
"id": "nyrzI90uNH1s"
185+
},
186+
"source": [
187+
"# Approximate 2^q*sqrt(2^(x/2^q_x))\n",
188+
"def sqrt_approx_exp2(x, q_x, q):\n",
189+
" return approx_exp2(x, q_x + 1, q)\n",
190+
"\n",
191+
"q = 11\n",
192+
"q_x = 8\n",
193+
"x = np.arange(-1000, 2000)\n",
194+
"approx_exp2_x = sqrt_approx_exp2(x, q_x, q)\n",
195+
"exact = np.sqrt(np.exp2(x / 2**q_x))\n",
196+
"\n",
197+
"plot_results(x, exact, [approx_exp2_x], 'sqrt(2^x)', relative=True, log2_xscale=q_x, log2_yscale=q)\n"
198+
],
199+
"execution_count": null,
200+
"outputs": []
201+
},
202+
{
203+
"cell_type": "code",
204+
"metadata": {
205+
"id": "Kno5t4VihCTL"
206+
},
207+
"source": [
208+
"# Approximate sqrt(x) = 2^((1/2)*log2(x))\n",
209+
"def approx_sqrt(x, q):\n",
210+
" # log2(x) will never be larger than 32, for 32-bit x. So to make the result\n",
211+
" # fit in a 16-bit integer, we can make the precision 2^16/32 = 2048.\n",
212+
" q_x = 11;\n",
213+
"\n",
214+
" log2_sqrt_x = approx_log2(x, q_x - 1)\n",
215+
" return approx_exp2(log2_sqrt_x, q_x, q)\n",
216+
"\n",
217+
"q = 15\n",
218+
"x = np.arange(1, 10000)**2\n",
219+
"sqrt_x = np.sqrt(x)\n",
220+
"approx_sqrt_x = approx_sqrt(x, q)\n",
221+
"\n",
222+
"plot_results(x, sqrt_x, [approx_sqrt_x], 'sqrt(x)', log2_yscale=q, relative=True)\n"
223+
],
224+
"execution_count": null,
225+
"outputs": []
226+
},
227+
{
228+
"cell_type": "code",
229+
"metadata": {
230+
"id": "0dMecIGr92WY"
231+
},
232+
"source": [
233+
"# Approximate 2^31/sqrt(x) = 2^(-(1/2)*log2(x))\n",
234+
"def approx_reciprocal_sqrt(x):\n",
235+
" q = 15\n",
236+
" log2_sqrt_x = approx_log2(x, q - 1)\n",
237+
" return approx_exp2(-log2_sqrt_x, q, 31)\n",
238+
"\n",
239+
"x = np.arange(1, 10000)**2\n",
240+
"inv_sqrt_x = 1 / np.sqrt(x)\n",
241+
"approx_reciprocal_sqrt_x = approx_reciprocal_sqrt(x)\n",
242+
"\n",
243+
"plot_results(x, inv_sqrt_x, [approx_reciprocal_sqrt_x], '1/sqrt(x)', True, True, True, log2_yscale=31)\n"
244+
],
245+
"execution_count": null,
246+
"outputs": []
247+
},
248+
{
249+
"cell_type": "code",
250+
"metadata": {
251+
"id": "VFC9aUFcc8d7"
252+
},
253+
"source": [
254+
"# Approximate 2^32/x = 2^32*2^(-log2(x))\n",
255+
"def approx_reciprocal(x):\n",
256+
" q = 15;\n",
257+
" log2_x = approx_log2(x, q)\n",
258+
" return approx_exp2(-log2_x, q, 31)\n",
259+
"\n",
260+
"x = 1.01**np.arange(0, 2000)\n",
261+
"inv_x = 1 / x\n",
262+
"approx_inv_x = approx_reciprocal(x)\n",
263+
"# This is ~sqrt(2) times more accurate, but maybe not practical for large x.\n",
264+
"approx_inv_sqrt_x2 = approx_reciprocal_sqrt(x*x)\n",
265+
"\n",
266+
"plot_results(x, inv_x, [approx_inv_x], '1/x', True, True, log2_yscale=31, relative=True)\n",
267+
"plot_results(x, inv_x, [approx_inv_sqrt_x2], '1/x', True, True, log2_yscale=31, relative=True)\n"
268+
],
269+
"execution_count": null,
270+
"outputs": []
271+
},
272+
{
273+
"cell_type": "code",
274+
"metadata": {
275+
"id": "6BhQzLIZCcKC"
276+
},
277+
"source": [
278+
"# Approximate log2(exp2(x) + c)\n",
279+
"def approx_log2_exp2_plus_constant(x, c, q_x, q):\n",
280+
" # When x/2^q_x is large, approx_exp2 below will overflow. But when it is large\n",
281+
" # we don't need it to be very precise\n",
282+
" q_exp = 16 #np.minimum(16, 16 - np.floor(np.log2(np.maximum(x, 1))))\n",
283+
" one = 2**q_exp\n",
284+
"\n",
285+
" one_plus_exp2_x = one * c + approx_exp2(x, q_x, q_exp)\n",
286+
" # Mimic overflow of int32\n",
287+
" one_plus_exp2_x = np.mod(one_plus_exp2_x, 2**31)\n",
288+
"\n",
289+
" raw = approx_log2(one_plus_exp2_x, q, q_exp)\n",
290+
"\n",
291+
" line = rounding_shift_right(x, q_x - q)\n",
292+
"\n",
293+
" threshold = 30 - q_exp\n",
294+
" result = np.select([shift_right(x, q_x) < threshold], [raw], line)\n",
295+
" return result\n",
296+
"\n",
297+
"def approx_log2p1_exp2(x, q_x, q):\n",
298+
" return approx_log2_exp2_plus_constant(x, 1, q_x, q)\n",
299+
"\n",
300+
"def approx_log2m1_exp2(x, q_x, q):\n",
301+
" return approx_log2_exp2_plus_constant(x, -1, q_x, q)\n",
302+
"\n",
303+
"x = np.arange(-4000, 4000)*8\n",
304+
"q_x = 11\n",
305+
"q = 15\n",
306+
"\n",
307+
"exact = np.log2(np.exp2(x / 2**q_x) + 1)\n",
308+
"approx = approx_log2p1_exp2(x, q_x, q)\n",
309+
"plot_results(x, exact, [approx], 'log2(2^x + 1)', log2_xscale=q_x, log2_yscale=q)\n",
310+
"\n",
311+
"x = np.arange(1, 4000)*8\n",
312+
"exact = np.log2(np.exp2(x / 2**q_x) - 1)\n",
313+
"approx = approx_log2m1_exp2(x, q_x, q)\n",
314+
"plot_results(x, exact, [approx], 'log2(2^x - 1)', log2_xscale=q_x, log2_yscale=q)\n"
315+
],
316+
"execution_count": null,
317+
"outputs": []
318+
},
319+
{
320+
"cell_type": "code",
321+
"metadata": {
322+
"id": "G6n1u8fcUf-3"
323+
},
324+
"source": [
325+
"# Approximate logistic(x) = 1/(e^-x + 1)\n",
326+
"# = 2^log2(1/(e^-x + 1))\n",
327+
"# = 2^-log2(e^-x + 1)\n",
328+
"def approx_logistic(x, q_x, q):\n",
329+
" x2 = multiply_2x_high(x, np.round(-np.log2(np.exp(1)) * 2**14))\n",
330+
" q_exp = 11\n",
331+
" log2_d = approx_log2p1_exp2(x2, q_x - 1, q_exp)\n",
332+
" return approx_exp2(-log2_d, q_exp, q)\n",
333+
"\n",
334+
"x = np.arange(-4000, 4000)*8\n",
335+
"q_x = 11\n",
336+
"q = 15\n",
337+
"exact = 1 / (1 + np.exp(-x / 2**q_x))\n",
338+
"approx = approx_logistic(x, q_x, q)\n",
339+
"plot_results(x, exact, [approx], '1/(1 + e^-x)', log2_xscale=q_x, log2_yscale=q)"
340+
],
341+
"execution_count": null,
342+
"outputs": []
343+
},
344+
{
345+
"cell_type": "code",
346+
"metadata": {
347+
"id": "LBXXNc_8twQD"
348+
},
349+
"source": [
350+
"# Approximate tanh(x) = (e^2x - 1)/(e^2x + 1)\n",
351+
"# = 2^log2((e^2x - 1)/(e^2x + 1))\n",
352+
"# = 2^(log2(e^2x - 1) - log2(e^2x + 1))\n",
353+
"def approx_tanh(x, q_x, q):\n",
354+
" abs_x_base2 = multiply_2x_high(np.abs(x), np.round(np.log2(np.exp(1)) * 2**14))\n",
355+
" q_exp = 11\n",
356+
" log2_n = approx_log2m1_exp2(abs_x_base2, q_x - 2, q_exp)\n",
357+
" log2_d = approx_log2p1_exp2(abs_x_base2, q_x - 2, q_exp)\n",
358+
" # Saturate at int16\n",
359+
" log2_n = np.clip(log2_n, -(2**15), 2**15)\n",
360+
" log2_d = np.clip(log2_d, -(2**15), 2**15)\n",
361+
" return np.sign(x) * approx_exp2(log2_n - log2_d, q_exp, q)\n",
362+
"\n",
363+
"x = np.arange(-4000, 4000)*8\n",
364+
"q_x = 12\n",
365+
"q = 15\n",
366+
"exact = np.tanh(x / 2**q_x)\n",
367+
"approx = approx_tanh(x, q_x, q)\n",
368+
"\n",
369+
"points = 20\n",
370+
"poly_x = np.arange(0, points * 3) / points\n",
371+
"poly_y = np.tanh(poly_x)\n",
372+
"poly = np.polyfit(poly_x, poly_y, 6)\n",
373+
"approx2 = np.polyval(poly, x / 2**q_x) * 2**q\n",
374+
"\n",
375+
"\n",
376+
"plot_results(x, exact, [approx], 'tanh(x)', log2_xscale=q_x, log2_yscale=q)"
377+
],
378+
"execution_count": null,
379+
"outputs": []
380+
}
381+
]
382+
}

0 commit comments

Comments
 (0)