Skip to content

Commit 70e1200

Browse files
Vector, fmma and polynomial GCD improvements for mpn_mod (#2410)
* mpn_fmmamod functions and improved divrem_q1 and GCD for mpn_mod polynomials * Add mpn_mod_fmma and optimize length 1, 2 dot products
1 parent 5a8004c commit 70e1200

16 files changed

Lines changed: 1243 additions & 114 deletions

doc/source/mpn_extras.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,13 @@ Division and modular arithmetic with precomputed inverses
367367
The behavior is not exactly the same: `a` and `b` are assumed to
368368
be unshifted, and the output is unshifted.
369369

370+
.. function:: void flint_mpn_fmmamod_preinvn(mp_ptr r, mp_srcptr a1, mp_srcptr b1, mp_srcptr a2, mp_srcptr b2, mp_size_t n, mp_srcptr dnormed, mp_srcptr dinv, ulong norm)
371+
void flint_mpn_fmmamod_preinvn_2(mp_ptr r, mp_srcptr a1, mp_srcptr b1, mp_srcptr a2, mp_srcptr b2, mp_srcptr dnormed, mp_srcptr dinv, ulong norm)
372+
373+
Given ``dnormed`` containing a normalised integer `d 2^{norm}` with precomputed inverse ``dinv``
374+
provided by ``flint_mpn_preinvn``, computes `a_1 b_1 + a_2 b_2 \pmod{d}`. We require
375+
all operands to be reduced modulo `d`.
376+
370377
.. function:: void flint_mpn_mulmod_precond(mp_ptr rp, mp_srcptr apre, mp_srcptr b, mp_size_t n, mp_srcptr dnormed, mp_srcptr dinv, ulong norm)
371378

372379
Given ``dnormed`` containing a normalised integer `d 2^{norm}` with precomputed inverse ``dinv``
@@ -405,6 +412,10 @@ Division and modular arithmetic with precomputed inverses
405412
:func:`flint_mpn_mulmod_precond_precompute`
406413
given a modulus with `n` limbs.
407414

415+
.. function:: void flint_mpn_fmmamod_precond(mp_ptr rp, mp_srcptr a1pre, mp_srcptr b1, mp_srcptr a2pre, mp_srcptr b2, mp_size_t n, mp_srcptr dnormed, mp_srcptr dinv, ulong norm)
416+
417+
Analogous to :func:`flint_mpn_mulmod_precond`, but computes `a_1 b_1 + a_2 b_2` modulo `d`.
418+
408419
GCD
409420
--------------------------------------------------------------------------------
410421

doc/source/mpn_mod.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ Basic operations and arithmetic
137137
int mpn_mod_submul_ui(nn_ptr res, nn_srcptr x, ulong y, gr_ctx_t ctx)
138138
int mpn_mod_submul_si(nn_ptr res, nn_srcptr x, slong y, gr_ctx_t ctx)
139139
int mpn_mod_submul_fmpz(nn_ptr res, nn_srcptr x, const fmpz_t y, gr_ctx_t ctx)
140+
int mpn_mod_fmma(nn_ptr res, nn_srcptr x1, nn_srcptr y1, nn_srcptr x2, nn_srcptr y2, gr_ctx_t ctx)
140141
int mpn_mod_sqr(nn_ptr res, nn_srcptr x, gr_ctx_t ctx)
141142
int mpn_mod_inv(nn_ptr res, nn_srcptr x, gr_ctx_t ctx)
142143
int mpn_mod_div(nn_ptr res, nn_srcptr x, nn_srcptr y, gr_ctx_t ctx)
@@ -162,6 +163,7 @@ Vector functions
162163
int _mpn_mod_vec_mul_scalar(nn_ptr res, nn_srcptr x, slong len, nn_srcptr y, gr_ctx_t ctx)
163164
int _mpn_mod_scalar_mul_vec(nn_ptr res, nn_srcptr y, nn_srcptr x, slong len, gr_ctx_t ctx)
164165
int _mpn_mod_vec_addmul_scalar(nn_ptr res, nn_srcptr x, slong len, nn_srcptr y, gr_ctx_t ctx)
166+
int _mpn_mod_vec_submul_scalar(nn_ptr res, nn_srcptr x, slong len, nn_srcptr y, gr_ctx_t ctx);
165167
int _mpn_mod_vec_dot(nn_ptr res, nn_srcptr initial, int subtract, nn_srcptr vec1, nn_srcptr vec2, slong len, gr_ctx_t ctx)
166168
int _mpn_mod_vec_dot_rev(nn_ptr res, nn_srcptr initial, int subtract, nn_srcptr vec1, nn_srcptr vec2, slong len, gr_ctx_t ctx)
167169

@@ -268,6 +270,14 @@ Division
268270
Polynomial division with remainder implemented using the basecase
269271
algorithm with delayed reductions.
270272

273+
.. function:: int _mpn_mod_poly_divrem_q1_preinv1_fmma(nn_ptr Q, nn_ptr R, nn_srcptr A, slong lenA, nn_srcptr B, slong lenB, nn_srcptr invL, gr_ctx_t ctx);
274+
int _mpn_mod_poly_divrem_q1_preinv1_fmma_precond(nn_ptr Q, nn_ptr R, nn_srcptr A, slong lenA, nn_srcptr B, slong lenB, nn_srcptr invL, gr_ctx_t ctx);
275+
int _mpn_mod_poly_divrem_q1_preinv1_karatsuba_precond(nn_ptr Q, nn_ptr R, nn_srcptr A, slong lenA, nn_srcptr B, slong lenB, nn_srcptr invL, gr_ctx_t ctx);
276+
int _mpn_mod_poly_divrem_q1_preinv1(nn_ptr Q, nn_ptr R, nn_srcptr A, slong lenA, nn_srcptr B, slong lenB, nn_srcptr invL, gr_ctx_t ctx);
277+
278+
Algorithms for polynomial division in the special case where
279+
`lenA = lenB + 1`. Require `lenB \ge 2`.
280+
271281
.. function:: int _mpn_mod_poly_divrem(nn_ptr Q, nn_ptr R, nn_srcptr A, slong lenA, nn_srcptr B, slong lenB, gr_ctx_t ctx)
272282
int _mpn_mod_poly_div(nn_ptr Q, nn_srcptr A, slong lenA, nn_srcptr B, slong lenB, gr_ctx_t ctx)
273283

src/gr_poly/tune/cutoffs.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ void _nmod_poly_mul_mid_default_mpn_ctx(nn_ptr res, slong zl, slong zh, nn_srcpt
6161
#endif
6262

6363
#if 1
64-
#define INIT_CTX fmpz_t t; fmpz_init(t); fmpz_ui_pow_ui(t, 2, bits - 1); fmpz_add_ui(t, t, 1); fmpz_nextprime(t, t, 0); GR_MUST_SUCCEED(gr_ctx_init_mpn_mod(ctx, t)); fmpz_clear(t);
64+
#define INIT_CTX fmpz_t t; fmpz_init(t); fmpz_ui_pow_ui(t, 2, bits - 1); fmpz_add_ui(t, t, 1); fmpz_nextprime(t, t, 0); GR_MUST_SUCCEED(gr_ctx_init_mpn_mod(ctx, t)); mpn_mod_ctx_set_is_field(ctx, T_TRUE); fmpz_clear(t);
6565
#define RANDCOEFF(t, ctx) fmpz_mod_rand(t, state, gr_ctx_data_as_ptr(ctx));
6666
#define STEP_BITS for (bits = 80, j = 0; bits <= 1024; bits = bits + 16, j++)
6767
#endif
@@ -212,7 +212,7 @@ void _nmod_poly_mul_mid_default_mpn_ctx(nn_ptr res, slong zl, slong zh, nn_srcpt
212212
_nmod_poly_mul_mid_default_mpn_ctx(C->coeffs, 0, B->length, A->coeffs, A->length, B->coeffs, B->length, ((nmod_t *) gr_ctx_data_ptr(ctx))[0]);
213213
#endif
214214

215-
#if 1
215+
#if 0
216216
#define INFO "divexact (basecase -> bidirectional)"
217217
#define SETUP random_input(C, state, len, ctx); \
218218
random_input(B, state, len, ctx); \
@@ -221,7 +221,7 @@ void _nmod_poly_mul_mid_default_mpn_ctx(nn_ptr res, slong zl, slong zh, nn_srcpt
221221
#define CASE_B GR_IGNORE(gr_poly_divexact_bidirectional(C, A, B, ctx));
222222
#endif
223223

224-
#if 0
224+
#if 1
225225
#define INFO "gcd"
226226
#define SETUP random_input(A, state, len, ctx); \
227227
random_input(B, state, len, ctx);

src/mpn_extras.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -895,6 +895,10 @@ int flint_mpn_mulmod_want_precond(mp_size_t n, slong num);
895895
void flint_mpn_mulmod_precond_precompute(mp_ptr apre, mp_srcptr a, mp_size_t n, mp_srcptr d, mp_srcptr dinv, ulong norm);
896896
void flint_mpn_mulmod_precond(mp_ptr rp, mp_srcptr apre, mp_srcptr b, mp_size_t n, mp_srcptr d, mp_srcptr dinv, ulong norm);
897897

898+
void flint_mpn_fmmamod_preinvn(mp_ptr r, mp_srcptr a, mp_srcptr b, mp_srcptr e, mp_srcptr f, mp_size_t n, mp_srcptr d, mp_srcptr dinv, ulong norm);
899+
void flint_mpn_fmmamod_preinvn_2(mp_ptr r, mp_srcptr a, mp_srcptr b, mp_srcptr e, mp_srcptr f, mp_srcptr d, mp_srcptr dinv, ulong norm);
900+
void flint_mpn_fmmamod_precond(mp_ptr rp, mp_srcptr apre1, mp_srcptr b1, mp_srcptr apre2, mp_srcptr b2, mp_size_t n, mp_srcptr d, mp_srcptr dinv, ulong norm);
901+
898902
int flint_mpn_mulmod_2expp1_basecase(mp_ptr xp, mp_srcptr yp, mp_srcptr zp, int c, flint_bitcnt_t b, mp_ptr tp);
899903

900904
/* miscellaneous *************************************************************/

src/mpn_extras/mulmod_precond.c

Lines changed: 80 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,21 +112,23 @@ flint_mpn_mulmod_precond(mp_ptr rp, mp_srcptr apre, mp_srcptr b, mp_size_t n, mp
112112
*/
113113

114114
mp_ptr tmp;
115-
mp_limb_t cy;
115+
mp_limb_t cy, cy1, cy2;
116116
slong i, rn;
117117
TMP_INIT;
118118

119119
TMP_START;
120120
tmp = TMP_ALLOC((n + 2) * sizeof(mp_limb_t));
121121

122-
tmp[n] = mpn_mul_1(tmp, apre, n, b[0]);
123-
tmp[n + 1] = 0;
122+
cy1 = mpn_mul_1(tmp, apre, n, b[0]);
123+
cy2 = 0;
124124
for (i = 1; i < n; i++)
125125
{
126126
cy = mpn_addmul_1(tmp, apre + i * n, n, b[i]);
127-
add_ssaaaa(tmp[n + 1], tmp[n], tmp[n + 1], tmp[n], 0, cy);
127+
add_ssaaaa(cy2, cy1, cy2, cy1, 0, cy);
128128
}
129129

130+
tmp[n] = cy1;
131+
tmp[n + 1] = cy2;
130132
rn = (n + 2) - (tmp[n + 1] == 0);
131133

132134
#if 0
@@ -143,3 +145,77 @@ flint_mpn_mulmod_precond(mp_ptr rp, mp_srcptr apre, mp_srcptr b, mp_size_t n, mp
143145
TMP_END;
144146
}
145147

148+
void
149+
flint_mpn_fmmamod_precond(mp_ptr rp, mp_srcptr apre1, mp_srcptr b1, mp_srcptr apre2, mp_srcptr b2, mp_size_t n, mp_srcptr d, mp_srcptr dinv, ulong norm)
150+
{
151+
mp_ptr tmp;
152+
mp_limb_t cy, cy1, cy2;
153+
slong i, rn;
154+
TMP_INIT;
155+
156+
/* Something like this if we want a special case for n = 2 */
157+
/*
158+
if (n == 2)
159+
{
160+
ulong tmp[4];
161+
ulong ump[4];
162+
163+
FLINT_MPN_MUL_2X1(tmp[2], tmp[1], tmp[0], apre1[1], apre1[0], b1[0]);
164+
FLINT_MPN_MUL_2X1(ump[2], ump[1], ump[0], apre1[3], apre1[2], b1[1]);
165+
add_ssssaaaaaaaa(tmp[3], tmp[2], tmp[1], tmp[0], tmp[3], tmp[2], tmp[1], tmp[0], 0, ump[2], ump[1], ump[0]);
166+
FLINT_MPN_MUL_2X1(ump[2], ump[1], ump[0], apre2[3], apre2[2], b2[0]);
167+
add_ssssaaaaaaaa(tmp[3], tmp[2], tmp[1], tmp[0], tmp[3], tmp[2], tmp[1], tmp[0], 0, ump[2], ump[1], ump[0]);
168+
FLINT_MPN_MUL_2X1(ump[2], ump[1], ump[0], apre2[3], apre2[2], b2[1]);
169+
add_ssssaaaaaaaa(tmp[3], tmp[2], tmp[1], tmp[0], tmp[3], tmp[2], tmp[1], tmp[0], 0, ump[2], ump[1], ump[0]);
170+
171+
rn = (n + 2) - (tmp[n + 1] == 0);
172+
flint_mpn_mod_preinv1(tmp, rn, d, n, dinv[n - 1]);
173+
174+
if (norm)
175+
{
176+
rp[0] = (tmp[0] >> norm) | (tmp[1] << (FLINT_BITS - norm));
177+
rp[1] = (tmp[1] >> norm);
178+
}
179+
else
180+
{
181+
rp[0] = tmp[0];
182+
rp[1] = tmp[1];
183+
}
184+
185+
return;
186+
}
187+
*/
188+
189+
TMP_START;
190+
tmp = TMP_ALLOC((n + 2) * sizeof(mp_limb_t));
191+
192+
cy1 = mpn_mul_1(tmp, apre1, n, b1[0]);
193+
cy2 = 0;
194+
for (i = 1; i < n; i++)
195+
{
196+
cy = mpn_addmul_1(tmp, apre1 + i * n, n, b1[i]);
197+
add_ssaaaa(cy2, cy1, cy2, cy1, 0, cy);
198+
}
199+
for (i = 0; i < n; i++)
200+
{
201+
cy = mpn_addmul_1(tmp, apre2 + i * n, n, b2[i]);
202+
add_ssaaaa(cy2, cy1, cy2, cy1, 0, cy);
203+
}
204+
205+
tmp[n] = cy1;
206+
tmp[n + 1] = cy2;
207+
rn = (n + 2) - (tmp[n + 1] == 0);
208+
209+
#if 0
210+
flint_mpn_mod_preinvn(tmp, tmp, rn, d, n, dinv);
211+
#else
212+
flint_mpn_mod_preinv1(tmp, rn, d, n, dinv[n - 1]);
213+
#endif
214+
215+
if (norm == 0)
216+
flint_mpn_copyi(rp, tmp, n);
217+
else
218+
mpn_rshift(rp, tmp, n, norm);
219+
220+
TMP_END;
221+
}

src/mpn_extras/mulmod_preinvn.c

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,136 @@ void flint_mpn_mulmod_preinvn_2(mp_ptr r,
160160
r[1] = r1;
161161
}
162162
}
163+
164+
void flint_mpn_fmmamod_preinvn(mp_ptr r,
165+
mp_srcptr a, mp_srcptr b,
166+
mp_srcptr e, mp_srcptr f,
167+
mp_size_t n,
168+
mp_srcptr d, mp_srcptr dinv, ulong norm)
169+
{
170+
mp_ptr t, u;
171+
ulong cy;
172+
TMP_INIT;
173+
174+
TMP_START;
175+
t = TMP_ALLOC((7 * n) * sizeof(mp_limb_t));
176+
u = t + (5 * n);
177+
178+
if (a == b)
179+
flint_mpn_sqr(t, a, n);
180+
else
181+
flint_mpn_mul_n(t, a, b, n);
182+
183+
if (e == f)
184+
flint_mpn_sqr(u, e, n);
185+
else
186+
flint_mpn_mul_n(u, e, f, n);
187+
188+
if (norm)
189+
{
190+
mpn_add_n(t, t, u, 2 * n);
191+
cy = mpn_lshift(t, t, 2 * n, norm);
192+
}
193+
else
194+
{
195+
cy = mpn_add_n(t, t, u, 2 * n);
196+
}
197+
198+
if (cy != 0 || mpn_cmp(t + n, d, n) >= 0)
199+
{
200+
mpn_sub_n(t + n, t + n, d, n);
201+
}
202+
203+
flint_mpn_mul_or_mulhigh_n(t + 3 * n, t + n, dinv, n);
204+
mpn_add_n(t + 4 * n, t + 4 * n, t + n, n);
205+
206+
/* note: we rely on the fact that mul_or_mullow_n actually
207+
writes at least n + 1 limbs */
208+
flint_mpn_mul_or_mullow_n(t + 2 * n, t + 4 * n, d, n);
209+
cy = t[n] - t[3 * n] - mpn_sub_n(r, t, t + 2 * n, n);
210+
211+
while (cy > 0)
212+
cy -= mpn_sub_n(r, r, d, n);
213+
214+
if (mpn_cmp(r, d, n) >= 0)
215+
mpn_sub_n(r, r, d, n);
216+
217+
FLINT_ASSERT(mpn_cmp(r, d, n) < 0);
218+
219+
if (norm)
220+
mpn_rshift(r, r, n, norm);
221+
222+
TMP_END;
223+
}
224+
225+
void flint_mpn_fmmamod_preinvn_2(mp_ptr r,
226+
mp_srcptr a, mp_srcptr b,
227+
mp_srcptr e, mp_srcptr f,
228+
mp_srcptr d, mp_srcptr dinv, ulong norm)
229+
{
230+
mp_limb_t cy, b0, b1, r0, r1;
231+
mp_limb_t f0, f1;
232+
mp_limb_t t[10], u[4];
233+
234+
if (norm)
235+
{
236+
/* mpn_lshift(b, b, n, norm) */
237+
b0 = (b[0] << norm);
238+
b1 = (b[1] << norm) | (b[0] >> (FLINT_BITS - norm));
239+
f0 = (f[0] << norm);
240+
f1 = (f[1] << norm) | (f[0] >> (FLINT_BITS - norm));
241+
}
242+
else
243+
{
244+
b0 = b[0];
245+
b1 = b[1];
246+
f0 = f[0];
247+
f1 = f[1];
248+
}
249+
250+
/* mpn_mul_n(t, a, b, n) */
251+
FLINT_MPN_MUL_2X2(t[3], t[2], t[1], t[0], a[1], a[0], b1, b0);
252+
/* mpn_mul_n(u, e, f, n) */
253+
FLINT_MPN_MUL_2X2(u[3], u[2], u[1], u[0], e[1], e[0], f1, f0);
254+
add_sssssaaaaaaaaaa(cy, t[3], t[2], t[1], t[0],
255+
0, t[3], t[2], t[1], t[0],
256+
0, u[3], u[2], u[1], u[0]);
257+
if (cy || mpn_cmp(t + 2, d, 2) >= 0)
258+
sub_ddmmss(t[3], t[2], t[3], t[2], d[1], d[0]);
259+
260+
/* mpn_mul_n(t + 3*n, t + n, dinv, n) */
261+
FLINT_MPN_MUL_2X2(t[9], t[8], t[7], t[6], t[3], t[2], dinv[1], dinv[0]);
262+
263+
/* mpn_add_n(t + 4*n, t + 4*n, t + n, n) */
264+
add_ssaaaa(t[9], t[8], t[9], t[8], t[3], t[2]);
265+
266+
/* mpn_mul_n(t + 2*n, t + 4*n, d, n) */
267+
FLINT_MPN_MUL_3P2X2(t[6], t[5], t[4], t[9], t[8], d[1], d[0]);
268+
269+
/* cy = t[n] - t[3*n] - mpn_sub_n(r, t, t + 2*n, n) */
270+
sub_dddmmmsss(cy, r1, r0, t[2], t[1], t[0], t[6], t[5], t[4]);
271+
272+
while (cy > 0)
273+
{
274+
/* cy -= mpn_sub_n(r, r, d, n) */
275+
sub_dddmmmsss(cy, r1, r0, cy, r1, r0, 0, d[1], d[0]);
276+
}
277+
278+
if ((r1 > d[1]) || (r1 == d[1] && r0 >= d[0]))
279+
{
280+
/* mpn_sub_n(r, r, d, n) */
281+
sub_ddmmss(r1, r0, r1, r0, d[1], d[0]);
282+
}
283+
284+
if (norm)
285+
{
286+
r[0] = (r0 >> norm) | (r1 << (FLINT_BITS - norm));
287+
r[1] = (r1 >> norm);
288+
}
289+
else
290+
{
291+
r[0] = r0;
292+
r[1] = r1;
293+
}
294+
}
295+

src/mpn_extras/test/main.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
#include "t-divides.c"
1616
#include "t-divrem_preinv1.c"
1717
#include "t-divrem_preinvn.c"
18+
#include "t-fmmamod_precond.c"
19+
#include "t-fmmamod_preinvn.c"
1820
#include "t-fmms1.c"
1921
#include "t-gcd_full.c"
2022
#include "t-mod_preinvn.c"
@@ -44,6 +46,8 @@ test_struct tests[] =
4446
TEST_FUNCTION(flint_mpn_divides),
4547
TEST_FUNCTION(flint_mpn_divrem_preinv1),
4648
TEST_FUNCTION(flint_mpn_divrem_preinvn),
49+
TEST_FUNCTION(flint_mpn_fmmamod_precond),
50+
TEST_FUNCTION(flint_mpn_fmmamod_preinvn),
4751
TEST_FUNCTION(flint_mpn_fmms1),
4852
TEST_FUNCTION(flint_mpn_gcd_full),
4953
TEST_FUNCTION(flint_mpn_mod_preinvn),

0 commit comments

Comments
 (0)