1+ {
2+ "nbformat" : 4 ,
3+ "nbformat_minor" : 0 ,
4+ "metadata" : {
5+ "colab" : {
6+ "provenance" : []
7+ },
8+ "kernelspec" : {
9+ "name" : " python3" ,
10+ "display_name" : " Python 3"
11+ }
12+ },
13+ "cells" : [
14+ {
15+ "cell_type" : " code" ,
16+ "metadata" : {
17+ "id" : " r1XiiUQGUjpx"
18+ },
19+ "source" : [
20+ " import numpy as np\n " ,
21+ " import matplotlib as mpl\n " ,
22+ " import matplotlib.pyplot as plt\n " ,
23+ " \n " ,
24+ " # Many architectures have shifts where the right-hand-side is signed. A negative\n " ,
25+ " # RHS is the same as a positive shift in the other direction.\n " ,
26+ " def shift_right(x, y):\n " ,
27+ " return np.floor(x / 2**y)\n " ,
28+ " def shift_left(x, y):\n " ,
29+ " return np.floor(x * 2**y)\n " ,
30+ " def rounding_shift_right(x, y):\n " ,
31+ " return np.round(x / 2**y)\n " ,
32+ " def rounding_shift_left(x, y):\n " ,
33+ " return np.round(x * 2**y)\n " ,
34+ " \n " ,
35+ " def bitwise_and(x, y):\n " ,
36+ " return np.mod(x, y + 1)\n " ,
37+ " \n " ,
38+ " # This is sqrdmulh on ARM\n " ,
39+ " def multiply_2x_high(x, y):\n " ,
40+ " return rounding_shift_right(x * y, 15)\n " ,
41+ " \n " ,
42+ " def relative_error(x, y):\n " ,
43+ " return (x - y) / (np.maximum(x, y) + 1e-3)\n " ,
44+ " \n " ,
45+ " def plot_results(x, exact, approxs, title, logx = False, logy = False, relative = False, log2_xscale = 0, log2_yscale = 0):\n " ,
46+ " fig, [p1, p2] = plt.subplots(2, 1)\n " ,
47+ " \n " ,
48+ " p1.set_xlabel('x')\n " ,
49+ " if logx:\n " ,
50+ " p1.set_xscale('log')\n " ,
51+ " p1.set_ylabel(title)\n " ,
52+ " if logy:\n " ,
53+ " p1.set_yscale('log')\n " ,
54+ " \n " ,
55+ " xscale = 2**log2_xscale\n " ,
56+ " yscale = 2**log2_yscale\n " ,
57+ " \n " ,
58+ " exact = np.round(exact*yscale)/yscale\n " ,
59+ " \n " ,
60+ " p1.plot(x/xscale, exact)\n " ,
61+ " for approx in approxs:\n " ,
62+ " p1.plot(x/xscale, approx/yscale)\n " ,
63+ " \n " ,
64+ " p2.set_xlabel('x')\n " ,
65+ " if logx:\n " ,
66+ " p2.set_xscale('log')\n " ,
67+ " \n " ,
68+ " p2.set_ylabel('relative error' if relative else 'error')\n " ,
69+ " for approx in approxs:\n " ,
70+ " p2.plot(x/xscale, relative_error(approx/yscale, exact) if relative else approx/yscale - exact)\n " ,
71+ " \n " ,
72+ " def eval_poly(x, p, q):\n " ,
73+ " x1 = rounding_shift_left(x, 15 - q)\n " ,
74+ " y = p[0]\n " ,
75+ " xi = x1\n " ,
76+ " for i in p[1:]:\n " ,
77+ " y = y + multiply_2x_high(i, xi)\n " ,
78+ " xi = multiply_2x_high(xi, x1)\n " ,
79+ " return rounding_shift_right(y, 15 - q)\n " ,
80+ " \n " ,
81+ " points = 6\n " ,
82+ " degree = 3\n " ,
83+ " log2_poly_x = np.arange(points, 2 * points + 1) / points\n " ,
84+ " log2_poly_y = np.log2(log2_poly_x)\n " ,
85+ " log2_poly = np.polyfit(log2_poly_x - 1, log2_poly_y, degree)\n " ,
86+ " \n " ,
87+ " exp2_poly_x = np.arange(points, 2 * points + 1) / points\n " ,
88+ " exp2_poly_y = np.exp2(exp2_poly_x - 1) - 1\n " ,
89+ " exp2_poly = np.polyfit(exp2_poly_x - 1, exp2_poly_y, degree)\n " ,
90+ " \n " ,
91+ " log2_poly = log2_poly[::-1]\n " ,
92+ " exp2_poly = exp2_poly[::-1]\n " ,
93+ " \n " ,
94+ " print(log2_poly)\n " ,
95+ " print(exp2_poly)\n " ,
96+ " \n " ,
97+ " log2_poly = np.round(log2_poly * 2**15)\n " ,
98+ " exp2_poly = np.round(exp2_poly * 2**15)\n " ,
99+ " exp2_poly[0] = 0\n " ,
100+ " \n " ,
101+ " print(log2_poly)\n " ,
102+ " print(exp2_poly)"
103+ ],
104+ "execution_count" : null ,
105+ "outputs" : []
106+ },
107+ {
108+ "cell_type" : " code" ,
109+ "metadata" : {
110+ "id" : " 1xjo4hIEo_z5"
111+ },
112+ "source" : [
113+ " # Approximate N*log2(x*2^q_x), where N = 2^q, and the intermediate computations are\n " ,
114+ " # restricted to be integers.\n " ,
115+ " def approx_log2(x, q, q_x = 0):\n " ,
116+ " # This can be computed with count_leading_zeros\n " ,
117+ " floor_log2_x = np.select([x > 0], [np.floor(np.log2(x))], [-1])\n " ,
118+ " \n " ,
119+ " # We've computed log2(x*2^q_x) = log2(x) + q_x. Subtract that offset now\n " ,
120+ " # before multiplying by the result quantization.\n " ,
121+ " result = shift_left(floor_log2_x - q_x, q)\n " ,
122+ " \n " ,
123+ " frac = bitwise_and(shift_right(x, floor_log2_x - q), 2**q - 1)\n " ,
124+ " \n " ,
125+ " return result + eval_poly(frac, log2_poly, q)\n " ,
126+ " \n " ,
127+ " x = np.arange(1, 10000)\n " ,
128+ " q = 15\n " ,
129+ " q_x = 2\n " ,
130+ " log2_x = np.log2(x / 2**q_x)\n " ,
131+ " approx_log2_x = approx_log2(x, q, q_x)\n " ,
132+ " \n " ,
133+ " plot_results(x, log2_x, [approx_log2_x], 'log2(x)', logx=True, log2_xscale=q_x, log2_yscale=q)"
134+ ],
135+ "execution_count" : null ,
136+ "outputs" : []
137+ },
138+ {
139+ "cell_type" : " code" ,
140+ "metadata" : {
141+ "id" : " 6uJN5muLsLdE"
142+ },
143+ "source" : [
144+ " \n " ,
145+ " # Approximate 2^(x/2^q_x)*2^q\n " ,
146+ " def approx_exp2(x, q_x, q):\n " ,
147+ " int_part = shift_right(x, q_x)\n " ,
148+ " frac_part = x - shift_left(int_part, q_x)\n " ,
149+ " \n " ,
150+ " frac_part = eval_poly(frac_part, exp2_poly, q_x)\n " ,
151+ " \n " ,
152+ " exp_int_part = shift_left(1, int_part + q)\n " ,
153+ " return exp_int_part + rounding_shift_right(exp_int_part * frac_part, q_x)\n " ,
154+ " \n " ,
155+ " q_x = 10\n " ,
156+ " q = 15\n " ,
157+ " x = np.arange(-4000, 2000)\n " ,
158+ " approx_exp2_x = approx_exp2(x, q_x, q)\n " ,
159+ " exact = np.exp2(x / 2**q_x)\n " ,
160+ " \n " ,
161+ " plot_results(x, exact, [approx_exp2_x], '2^x', False, True, relative=True, log2_xscale=q_x, log2_yscale=q)\n "
162+ ],
163+ "execution_count" : null ,
164+ "outputs" : []
165+ },
166+ {
167+ "cell_type" : " code" ,
168+ "metadata" : {
169+ "id" : " 5BP-edzCmNBi"
170+ },
171+ "source" : [
172+ " q = 15\n " ,
173+ " x = np.arange(10, 10000) * 10\n " ,
174+ " round_trip_x = approx_exp2(approx_log2(x, q), q, 0)\n " ,
175+ " \n " ,
176+ " plot_results(x, x, [round_trip_x], '2^log2(x)', logx=True, logy=True, relative=True)"
177+ ],
178+ "execution_count" : null ,
179+ "outputs" : []
180+ },
181+ {
182+ "cell_type" : " code" ,
183+ "metadata" : {
184+ "id" : " nyrzI90uNH1s"
185+ },
186+ "source" : [
187+ " # Approximate 2^q*sqrt(2^(x/2^q_x))\n " ,
188+ " def sqrt_approx_exp2(x, q_x, q):\n " ,
189+ " return approx_exp2(x, q_x + 1, q)\n " ,
190+ " \n " ,
191+ " q = 11\n " ,
192+ " q_x = 8\n " ,
193+ " x = np.arange(-1000, 2000)\n " ,
194+ " approx_exp2_x = sqrt_approx_exp2(x, q_x, q)\n " ,
195+ " exact = np.sqrt(np.exp2(x / 2**q_x))\n " ,
196+ " \n " ,
197+ " plot_results(x, exact, [approx_exp2_x], 'sqrt(2^x)', relative=True, log2_xscale=q_x, log2_yscale=q)\n "
198+ ],
199+ "execution_count" : null ,
200+ "outputs" : []
201+ },
202+ {
203+ "cell_type" : " code" ,
204+ "metadata" : {
205+ "id" : " Kno5t4VihCTL"
206+ },
207+ "source" : [
208+ " # Approximate sqrt(x) = 2^((1/2)*log2(x))\n " ,
209+ " def approx_sqrt(x, q):\n " ,
210+ " # log2(x) will never be larger than 32, for 32-bit x. So to make the result\n " ,
211+ " # fit in a 16-bit integer, we can make the precision 2^16/32 = 2048.\n " ,
212+ " q_x = 11;\n " ,
213+ " \n " ,
214+ " log2_sqrt_x = approx_log2(x, q_x - 1)\n " ,
215+ " return approx_exp2(log2_sqrt_x, q_x, q)\n " ,
216+ " \n " ,
217+ " q = 15\n " ,
218+ " x = np.arange(1, 10000)**2\n " ,
219+ " sqrt_x = np.sqrt(x)\n " ,
220+ " approx_sqrt_x = approx_sqrt(x, q)\n " ,
221+ " \n " ,
222+ " plot_results(x, sqrt_x, [approx_sqrt_x], 'sqrt(x)', log2_yscale=q, relative=True)\n "
223+ ],
224+ "execution_count" : null ,
225+ "outputs" : []
226+ },
227+ {
228+ "cell_type" : " code" ,
229+ "metadata" : {
230+ "id" : " 0dMecIGr92WY"
231+ },
232+ "source" : [
233+ " # Approximate 2^31/sqrt(x) = 2^(-(1/2)*log2(x))\n " ,
234+ " def approx_reciprocal_sqrt(x):\n " ,
235+ " q = 15\n " ,
236+ " log2_sqrt_x = approx_log2(x, q - 1)\n " ,
237+ " return approx_exp2(-log2_sqrt_x, q, 31)\n " ,
238+ " \n " ,
239+ " x = np.arange(1, 10000)**2\n " ,
240+ " inv_sqrt_x = 1 / np.sqrt(x)\n " ,
241+ " approx_reciprocal_sqrt_x = approx_reciprocal_sqrt(x)\n " ,
242+ " \n " ,
243+ " plot_results(x, inv_sqrt_x, [approx_reciprocal_sqrt_x], '1/sqrt(x)', True, True, True, log2_yscale=31)\n "
244+ ],
245+ "execution_count" : null ,
246+ "outputs" : []
247+ },
248+ {
249+ "cell_type" : " code" ,
250+ "metadata" : {
251+ "id" : " VFC9aUFcc8d7"
252+ },
253+ "source" : [
254+ " # Approximate 2^32/x = 2^32*2^(-log2(x))\n " ,
255+ " def approx_reciprocal(x):\n " ,
256+ " q = 15;\n " ,
257+ " log2_x = approx_log2(x, q)\n " ,
258+ " return approx_exp2(-log2_x, q, 31)\n " ,
259+ " \n " ,
260+ " x = 1.01**np.arange(0, 2000)\n " ,
261+ " inv_x = 1 / x\n " ,
262+ " approx_inv_x = approx_reciprocal(x)\n " ,
263+ " # This is ~sqrt(2) times more accurate, but maybe not practical for large x.\n " ,
264+ " approx_inv_sqrt_x2 = approx_reciprocal_sqrt(x*x)\n " ,
265+ " \n " ,
266+ " plot_results(x, inv_x, [approx_inv_x], '1/x', True, True, log2_yscale=31, relative=True)\n " ,
267+ " plot_results(x, inv_x, [approx_inv_sqrt_x2], '1/x', True, True, log2_yscale=31, relative=True)\n "
268+ ],
269+ "execution_count" : null ,
270+ "outputs" : []
271+ },
272+ {
273+ "cell_type" : " code" ,
274+ "metadata" : {
275+ "id" : " 6BhQzLIZCcKC"
276+ },
277+ "source" : [
278+ " # Approximate log2(exp2(x) + c)\n " ,
279+ " def approx_log2_exp2_plus_constant(x, c, q_x, q):\n " ,
280+ " # When x/2^q_x is large, approx_exp2 below will overflow. But when it is large\n " ,
281+ " # we don't need it to be very precise\n " ,
282+ " q_exp = 16 #np.minimum(16, 16 - np.floor(np.log2(np.maximum(x, 1))))\n " ,
283+ " one = 2**q_exp\n " ,
284+ " \n " ,
285+ " one_plus_exp2_x = one * c + approx_exp2(x, q_x, q_exp)\n " ,
286+ " # Mimic overflow of int32\n " ,
287+ " one_plus_exp2_x = np.mod(one_plus_exp2_x, 2**31)\n " ,
288+ " \n " ,
289+ " raw = approx_log2(one_plus_exp2_x, q, q_exp)\n " ,
290+ " \n " ,
291+ " line = rounding_shift_right(x, q_x - q)\n " ,
292+ " \n " ,
293+ " threshold = 30 - q_exp\n " ,
294+ " result = np.select([shift_right(x, q_x) < threshold], [raw], line)\n " ,
295+ " return result\n " ,
296+ " \n " ,
297+ " def approx_log2p1_exp2(x, q_x, q):\n " ,
298+ " return approx_log2_exp2_plus_constant(x, 1, q_x, q)\n " ,
299+ " \n " ,
300+ " def approx_log2m1_exp2(x, q_x, q):\n " ,
301+ " return approx_log2_exp2_plus_constant(x, -1, q_x, q)\n " ,
302+ " \n " ,
303+ " x = np.arange(-4000, 4000)*8\n " ,
304+ " q_x = 11\n " ,
305+ " q = 15\n " ,
306+ " \n " ,
307+ " exact = np.log2(np.exp2(x / 2**q_x) + 1)\n " ,
308+ " approx = approx_log2p1_exp2(x, q_x, q)\n " ,
309+ " plot_results(x, exact, [approx], 'log2(2^x + 1)', log2_xscale=q_x, log2_yscale=q)\n " ,
310+ " \n " ,
311+ " x = np.arange(1, 4000)*8\n " ,
312+ " exact = np.log2(np.exp2(x / 2**q_x) - 1)\n " ,
313+ " approx = approx_log2m1_exp2(x, q_x, q)\n " ,
314+ " plot_results(x, exact, [approx], 'log2(2^x - 1)', log2_xscale=q_x, log2_yscale=q)\n "
315+ ],
316+ "execution_count" : null ,
317+ "outputs" : []
318+ },
319+ {
320+ "cell_type" : " code" ,
321+ "metadata" : {
322+ "id" : " G6n1u8fcUf-3"
323+ },
324+ "source" : [
325+ " # Approximate logistic(x) = 1/(e^-x + 1)\n " ,
326+ " # = 2^log2(1/(e^-x + 1))\n " ,
327+ " # = 2^-log2(e^-x + 1)\n " ,
328+ " def approx_logistic(x, q_x, q):\n " ,
329+ " x2 = multiply_2x_high(x, np.round(-np.log2(np.exp(1)) * 2**14))\n " ,
330+ " q_exp = 11\n " ,
331+ " log2_d = approx_log2p1_exp2(x2, q_x - 1, q_exp)\n " ,
332+ " return approx_exp2(-log2_d, q_exp, q)\n " ,
333+ " \n " ,
334+ " x = np.arange(-4000, 4000)*8\n " ,
335+ " q_x = 11\n " ,
336+ " q = 15\n " ,
337+ " exact = 1 / (1 + np.exp(-x / 2**q_x))\n " ,
338+ " approx = approx_logistic(x, q_x, q)\n " ,
339+ " plot_results(x, exact, [approx], '1/(1 + e^-x)', log2_xscale=q_x, log2_yscale=q)"
340+ ],
341+ "execution_count" : null ,
342+ "outputs" : []
343+ },
344+ {
345+ "cell_type" : " code" ,
346+ "metadata" : {
347+ "id" : " LBXXNc_8twQD"
348+ },
349+ "source" : [
350+ " # Approximate tanh(x) = (e^2x - 1)/(e^2x + 1)\n " ,
351+ " # = 2^log2((e^2x - 1)/(e^2x + 1))\n " ,
352+ " # = 2^(log2(e^2x - 1) - log2(e^2x + 1))\n " ,
353+ " def approx_tanh(x, q_x, q):\n " ,
354+ " abs_x_base2 = multiply_2x_high(np.abs(x), np.round(np.log2(np.exp(1)) * 2**14))\n " ,
355+ " q_exp = 11\n " ,
356+ " log2_n = approx_log2m1_exp2(abs_x_base2, q_x - 2, q_exp)\n " ,
357+ " log2_d = approx_log2p1_exp2(abs_x_base2, q_x - 2, q_exp)\n " ,
358+ " # Saturate at int16\n " ,
359+ " log2_n = np.clip(log2_n, -(2**15), 2**15)\n " ,
360+ " log2_d = np.clip(log2_d, -(2**15), 2**15)\n " ,
361+ " return np.sign(x) * approx_exp2(log2_n - log2_d, q_exp, q)\n " ,
362+ " \n " ,
363+ " x = np.arange(-4000, 4000)*8\n " ,
364+ " q_x = 12\n " ,
365+ " q = 15\n " ,
366+ " exact = np.tanh(x / 2**q_x)\n " ,
367+ " approx = approx_tanh(x, q_x, q)\n " ,
368+ " \n " ,
369+ " points = 20\n " ,
370+ " poly_x = np.arange(0, points * 3) / points\n " ,
371+ " poly_y = np.tanh(poly_x)\n " ,
372+ " poly = np.polyfit(poly_x, poly_y, 6)\n " ,
373+ " approx2 = np.polyval(poly, x / 2**q_x) * 2**q\n " ,
374+ " \n " ,
375+ " \n " ,
376+ " plot_results(x, exact, [approx], 'tanh(x)', log2_xscale=q_x, log2_yscale=q)"
377+ ],
378+ "execution_count" : null ,
379+ "outputs" : []
380+ }
381+ ]
382+ }
0 commit comments