Skip to content

Commit 459cd73

Browse files
committed
The same instance of SleefDFT can be called from multiple threads simultaneously.
1 parent 3993f71 commit 459cd73

File tree

3 files changed

+35
-40
lines changed

3 files changed

+35
-40
lines changed

src/dft/dft.cpp

Lines changed: 29 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -482,8 +482,20 @@ class KShortest {
482482

483483
template<typename real, typename real2, int MAXSHIFT, int MAXBUTWIDTH>
484484
void SleefDFTXX<real, real2, MAXSHIFT, MAXBUTWIDTH>::measurementRun(real *d, const real *s, const vector<Action> &path, uint64_t niter) {
485-
const int tn = getThreadNum();
486-
real *t[] = { x1[tn], x0[tn], d };
485+
auto tid = this_thread::get_id();
486+
real *t[] = { nullptr, nullptr, d };
487+
{
488+
unique_lock lock(mtx);
489+
if (xn.count(tid) != 0) {
490+
auto e = xn[tid];
491+
t[0] = e.first;
492+
t[1] = e.second;
493+
} else {
494+
t[0] = (real *)Sleef_malloc(sizeof(real) * 2 * (1L << log2len));
495+
t[1] = (real *)Sleef_malloc(sizeof(real) * 2 * (1L << log2len));
496+
xn[tid] = pair<real *, real *>{ t[0], t[1] };
497+
}
498+
}
487499

488500
for(uint64_t i=0;i<niter;i++) {
489501
const real *lb = s;
@@ -1090,14 +1102,6 @@ SleefDFTXX<real, real2, MAXSHIFT, MAXBUTWIDTH>::SleefDFTXX(uint32_t n, const rea
10901102
for(int level = log2len;level >= 1;level--) {
10911103
perm[level] = (uint32_t *)Sleef_malloc(sizeof(uint32_t) * ((1 << log2len) + 8));
10921104
}
1093-
1094-
x0 = (real **)malloc(sizeof(real *) * nThread);
1095-
x1 = (real **)malloc(sizeof(real *) * nThread);
1096-
1097-
for(int i=0;i<nThread;i++) {
1098-
x0[i] = (real *)Sleef_malloc(sizeof(real) * 2 * n);
1099-
x1[i] = (real *)Sleef_malloc(sizeof(real) * 2 * n);
1100-
}
11011105

11021106
if ((mode & SLEEF_MODE_REAL) != 0) {
11031107
rtCoef0 = (real *)Sleef_malloc(sizeof(real) * n);
@@ -1335,9 +1339,21 @@ void SleefDFTXX<real, real2, MAXSHIFT, MAXBUTWIDTH>::execute(const real *s0, rea
13351339

13361340
//
13371341

1338-
const int tn = getThreadNum();
1339-
real *t[] = { x1[tn], x0[tn], d };
1340-
1342+
auto tid = this_thread::get_id();
1343+
real *t[] = { nullptr, nullptr, d };
1344+
{
1345+
unique_lock lock(mtx);
1346+
if (xn.count(tid) != 0) {
1347+
auto e = xn[tid];
1348+
t[0] = e.first;
1349+
t[1] = e.second;
1350+
} else {
1351+
t[0] = (real *)Sleef_malloc(sizeof(real) * 2 * (1L << log2len));
1352+
t[1] = (real *)Sleef_malloc(sizeof(real) * 2 * (1L << log2len));
1353+
xn[tid] = pair<real *, real *>{ t[0], t[1] };
1354+
}
1355+
}
1356+
13411357
const real *lb = s;
13421358
int nb = 0;
13431359

src/dft/dftcommon.cpp

Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -185,17 +185,10 @@ void SleefDFTXX<real, real2, MAXSHIFT, MAXBUTWIDTH>::freeTables() {
185185
tbl[N] = NULL;
186186
}
187187

188-
for(int i=0;i<nThread;i++) {
189-
Sleef_free(x1[i]);
190-
x1[i] = nullptr;
191-
Sleef_free(x0[i]);
192-
x0[i] = nullptr;
188+
for(auto a : xn) {
189+
Sleef_free(a.second.first);
190+
Sleef_free(a.second.second);
193191
}
194-
195-
free(x1);
196-
x1 = nullptr;
197-
free(x0);
198-
x0 = nullptr;
199192
}
200193

201194
template<typename real, typename real2, int MAXSHIFT, int MAXBUTWIDTH>
@@ -632,12 +625,6 @@ namespace {
632625

633626
waitUntilAllIdle(lock);
634627
}
635-
636-
int getThreadNum() {
637-
auto id = this_thread::get_id();
638-
if (thIdMap.count(id) == 0) return 0;
639-
return thIdMap.at(id);
640-
}
641628
};
642629

643630
#ifdef SLEEF_ENABLE_PARALLELFOR
@@ -650,14 +637,6 @@ namespace sleef_internal {
650637
function<void(int64_t, int64_t, int64_t)> func_) {
651638
#ifdef SLEEF_ENABLE_PARALLELFOR
652639
parallelForManager.run(start_, end_, inc_, func_);
653-
#endif
654-
}
655-
656-
int getThreadNum() {
657-
#ifndef SLEEF_ENABLE_PARALLELFOR
658-
return omp_get_thread_num();
659-
#else
660-
return parallelForManager.getThreadNum();
661640
#endif
662641
}
663642
}

src/dft/dftcommon.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,13 @@ namespace sleef_internal {
8080

8181
//
8282

83+
mutex mtx;
84+
8385
real **tbl[MAXBUTWIDTH+1];
8486
real *rtCoef0, *rtCoef1;
8587
uint32_t **perm;
8688

87-
real **x0, **x1;
89+
unordered_map<thread::id, pair<real *, real *>> xn;
8890

8991
int isa = 0;
9092
int planMode = 0;
@@ -221,8 +223,6 @@ namespace sleef_internal {
221223
extern FILE *defaultVerboseFP;
222224

223225
void parallelFor(int64_t start_, int64_t end_, int64_t inc_, std::function<void(int64_t, int64_t, int64_t)> func_);
224-
225-
int getThreadNum();
226226
}
227227

228228
using namespace sleef_internal;

0 commit comments

Comments
 (0)