/* D. Monniaux CNRS / VERIMAG Demonstration of number theoretic transform (Fast Fourier transform in a finite field) in Z/(2^16+1)Z for buffer of length 2^16. FFT original code from Rosetta Code. */ #include #include #include #include #include #include #include "../clock.h" typedef uint32_t modint; typedef int32_t smodint; static modint invm(modint a0, modint b0) { smodint a=a0, b=b0, q; smodint x0 = 0, x1 = 1, t; if (b == 1) return 1; while (a > 1) { q = a / b; t = b; b = a % b; a = t; t = x0; x0 = x1 - q * x0; x1 = t; } if (x1 < 0) x1 += b0; assert((x1 * a0) % b0 == 1); return x1; } static inline modint mulm(modint a, modint b, modint m) { return (a * b) % m; } static inline modint addm(modint a, modint b, modint m) { return (a + b) % m; } static inline modint subm(modint a, modint b, modint m) { return (a + (m - b)) % m; } static inline modint powm_u(modint a, modint exponent, modint m) { modint x = 1; while (exponent > 0) { if (exponent % 2) x = mulm(x, a, m); exponent /= 2; a = mulm(a, a, m); } return x; } static inline modint powm(modint a, smodint exponent, modint m) { if (exponent == 0) return 1; else if (exponent > 0) return powm_u(a, exponent, m); else /* exponent < 0 */ return powm(invm(a, m), -exponent, m); } static void _fft(modint modulus, modint root_of_unit, modint buf[], modint out[], unsigned n, unsigned step) { if (step < n) { modint root_of_unit2 = mulm(root_of_unit, root_of_unit, modulus); _fft(modulus, root_of_unit2, out, buf, n, step * 2); _fft(modulus, root_of_unit2, out + step, buf + step, n, step * 2); modint exp = 1; for (unsigned i = 0; i < n; i += 2 * step) { modint t = mulm(exp, out[i + step], modulus); buf[i / 2] = addm(out[i], t, modulus); buf[(i + n)/2] = subm(out[i], t, modulus); exp = mulm(exp, root_of_unit, modulus); } } } void fft(modint modulus, modint root_of_unit, modint buf[], unsigned n) { modint *out = malloc(sizeof(modint) * n); memcpy(out, buf, sizeof(modint) * n); _fft(modulus, root_of_unit, buf, out, n, 1); free(out); } static void negate(modint MODULUS, modint buf[restrict], unsigned n) { for(unsigned i=0; i -x */ mulvecm(MODULUS, buf, LENGTH, invm(LENGTH, MODULUS)); #else negate(MODULUS, buf, LENGTH); #endif printf("compare = %d\n", memcmp(buf, save, LENGTH * sizeof(modint))); /* printf("buf[0] = %" PRIu64 "\n", buf[0]); printf("buf[1] = %" PRIu64 "\n", buf[1]); printf("buf[2] = %" PRIu64 "\n", buf[2]); printf("buf[3] = %" PRIu64 "\n", buf[3]); printf("buf[4] = %" PRIu64 "\n", buf[4]); */ free(buf); free(save); }