@@ -482,8 +482,20 @@ class KShortest {
482482
483483template <typename real, typename real2, int MAXSHIFT, int MAXBUTWIDTH>
484484void SleefDFTXX<real, real2, MAXSHIFT, MAXBUTWIDTH>::measurementRun (real *d, const real *s, const vector<Action> &path, uint64_t niter) {
485- const int tn = getThreadNum ();
486- real *t[] = { x1[tn], x0[tn], d };
485+ auto tid = this_thread::get_id ();
486+ real *t[] = { nullptr , nullptr , d };
487+ {
488+ unique_lock lock (mtx);
489+ if (xn.count (tid) != 0 ) {
490+ auto e = xn[tid];
491+ t[0 ] = e.first ;
492+ t[1 ] = e.second ;
493+ } else {
494+ t[0 ] = (real *)Sleef_malloc (sizeof (real) * 2 * (1L << log2len));
495+ t[1 ] = (real *)Sleef_malloc (sizeof (real) * 2 * (1L << log2len));
496+ xn[tid] = pair<real *, real *>{ t[0 ], t[1 ] };
497+ }
498+ }
487499
488500 for (uint64_t i=0 ;i<niter;i++) {
489501 const real *lb = s;
@@ -1090,14 +1102,6 @@ SleefDFTXX<real, real2, MAXSHIFT, MAXBUTWIDTH>::SleefDFTXX(uint32_t n, const rea
10901102 for (int level = log2len;level >= 1 ;level--) {
10911103 perm[level] = (uint32_t *)Sleef_malloc (sizeof (uint32_t ) * ((1 << log2len) + 8 ));
10921104 }
1093-
1094- x0 = (real **)malloc (sizeof (real *) * nThread);
1095- x1 = (real **)malloc (sizeof (real *) * nThread);
1096-
1097- for (int i=0 ;i<nThread;i++) {
1098- x0[i] = (real *)Sleef_malloc (sizeof (real) * 2 * n);
1099- x1[i] = (real *)Sleef_malloc (sizeof (real) * 2 * n);
1100- }
11011105
11021106 if ((mode & SLEEF_MODE_REAL) != 0 ) {
11031107 rtCoef0 = (real *)Sleef_malloc (sizeof (real) * n);
@@ -1335,9 +1339,21 @@ void SleefDFTXX<real, real2, MAXSHIFT, MAXBUTWIDTH>::execute(const real *s0, rea
13351339
13361340 //
13371341
1338- const int tn = getThreadNum ();
1339- real *t[] = { x1[tn], x0[tn], d };
1340-
1342+ auto tid = this_thread::get_id ();
1343+ real *t[] = { nullptr , nullptr , d };
1344+ {
1345+ unique_lock lock (mtx);
1346+ if (xn.count (tid) != 0 ) {
1347+ auto e = xn[tid];
1348+ t[0 ] = e.first ;
1349+ t[1 ] = e.second ;
1350+ } else {
1351+ t[0 ] = (real *)Sleef_malloc (sizeof (real) * 2 * (1L << log2len));
1352+ t[1 ] = (real *)Sleef_malloc (sizeof (real) * 2 * (1L << log2len));
1353+ xn[tid] = pair<real *, real *>{ t[0 ], t[1 ] };
1354+ }
1355+ }
1356+
13411357 const real *lb = s;
13421358 int nb = 0 ;
13431359
0 commit comments