Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 38 additions & 24 deletions cpp/daal/src/algorithms/k_nearest_neighbors/bf_knn_impl.i
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,24 @@ public:
const size_t inBlockSize = 128;
const size_t nOuterBlocks = nTest / outBlockSize + !!(nTest % outBlockSize);

TlsMem<FPType, cpu> tlsDistances(inBlockSize * outBlockSize);
TlsMem<int, cpu> tlsIdx(outBlockSize);
TlsMem<FPType, cpu> tlsKDistances(inBlockSize * k);
TlsMem<int, cpu> tlsKIndexes(inBlockSize * k);
TlsMem<FPType, cpu> tlsVoting(nClasses);
const int nThreads = _daal_threader_get_max_threads();

TArray<FPType, cpu> tlsDistancesArr(nThreads * inBlockSize * outBlockSize);
TArray<int, cpu> tlsIdxArr(nThreads * outBlockSize);
TArray<FPType, cpu> tlsKDistancesArr(nThreads * inBlockSize * k);
TArray<int, cpu> tlsKIndexesArr(nThreads * inBlockSize * k);
TArray<FPType, cpu> tlsVotingArr(nThreads * nClasses);

FPType * const tlsDistances = tlsDistancesArr.get();
int * const tlsIdx = tlsIdxArr.get();
FPType * const tlsKDistances = tlsKDistancesArr.get();
int * const tlsKIndexes = tlsKIndexesArr.get();
FPType * const tlsVoting = tlsVotingArr.get();

if (!tlsDistances || !tlsIdx || !tlsKDistances || !tlsKIndexes || !tlsVoting)
{
return services::Status(services::ErrorMemoryAllocationFailed);
}

SafeStatus safeStat;

Expand All @@ -122,9 +135,10 @@ public:
const size_t outerEnd = outerBlock + 1 == nOuterBlocks ? nTest : outerStart + outBlockSize;
const size_t outerSize = outerEnd - outerStart;

DAAL_CHECK_STATUS_THR(computeKNearestBlock(dist.get(), outerSize, inBlockSize, outerStart, nTrain, resultsToEvaluate, resultsToCompute,
nClasses, k, voteWeights, trainLabel, trainTable, testTable, testLabelTable, indicesTable,
distancesTable, tlsDistances, tlsIdx, tlsKDistances, tlsKIndexes, tlsVoting, nOuterBlocks));
DAAL_CHECK_STATUS_THR(computeKNearestBlock(dist.get(), outerSize, inBlockSize, outBlockSize, outerStart, nTrain, resultsToEvaluate,
resultsToCompute, nClasses, k, voteWeights, trainLabel, trainTable, testTable, testLabelTable,
indicesTable, distancesTable, tlsDistances, tlsIdx, tlsKDistances, tlsKIndexes, tlsVoting,
nOuterBlocks));
});

if (resultsToEvaluate & daal::algorithms::classifier::computeClassLabels)
Expand Down Expand Up @@ -174,12 +188,12 @@ protected:
};

services::Status computeKNearestBlock(PairwiseDistances<FPType, cpu> * distancesInstance, const size_t blockSize, const size_t trainBlockSize,
const size_t startTestIdx, const size_t nTrain, DAAL_UINT64 resultsToEvaluate, DAAL_UINT64 resultsToCompute,
const size_t nClasses, const size_t k, VoteWeights voteWeights, FPType * trainLabel,
const NumericTable * trainTable, const NumericTable * testTable, NumericTable * testLabelTable,
NumericTable * indicesTable, NumericTable * distancesTable, TlsMem<FPType, cpu> & tlsDistances,
TlsMem<int, cpu> & tlsIdx, TlsMem<FPType, cpu> & tlsKDistances, TlsMem<int, cpu> & tlsKIndexes,
TlsMem<FPType, cpu> & tlsVoting, size_t nOuterBlocks)
const size_t outBlockSize, const size_t startTestIdx, const size_t nTrain, DAAL_UINT64 resultsToEvaluate,
DAAL_UINT64 resultsToCompute, const size_t nClasses, const size_t k, VoteWeights voteWeights,
FPType * trainLabel, const NumericTable * trainTable, const NumericTable * testTable,
NumericTable * testLabelTable, NumericTable * indicesTable, NumericTable * distancesTable,
FPType * tlsDistances, int * const tlsIdx, FPType * const tlsKDistances, int * const tlsKIndexes,
FPType * const tlsVoting, size_t nOuterBlocks)
{
const size_t inBlockSize = trainBlockSize;
const size_t inRows = nTrain;
Expand Down Expand Up @@ -218,11 +232,11 @@ protected:
const BruteForceTask * tls = tlsTask.local(tid);
DAAL_CHECK_MALLOC_THR(tls);

FPType * distancesBuff = tlsDistances.local();
DAAL_CHECK_MALLOC_THR(distancesBuff);
const int threadID = daal::threader_get_max_current_thread_index();

FPType * distancesBuff = tlsDistances + threadID * inBlockSize * outBlockSize;

int * idx = tlsIdx.local();
DAAL_CHECK_MALLOC_THR(idx);
int * idx = tlsIdx + threadID * outBlockSize;

FPType * maxs = tls->maxs;
HeapType * heapsLocal = tls->heapsData;
Expand All @@ -246,11 +260,11 @@ protected:
}
});

int * kIndexes = tlsKIndexes.local();
DAAL_CHECK_MALLOC(kIndexes);
const int threadID = daal::threader_get_max_current_thread_index();

int * kIndexes = tlsKIndexes + threadID * inBlockSize * k;

FPType * kDistances = tlsKDistances.local();
DAAL_CHECK_MALLOC(kDistances);
FPType * kDistances = tlsKDistances + threadID * inBlockSize * k;

TArrayScalable<HeapType, cpu> heaps(iSize);

Expand Down Expand Up @@ -317,8 +331,8 @@ protected:
DAAL_CHECK_BLOCK_STATUS(testLabelRows);
int * testLabel = testLabelRows.get();

FPType * voting = tlsVoting.local();
DAAL_CHECK_MALLOC(voting);
// Reuses the threadID obtained earlier (where the kIndexes/kDistances slices are selected); NOTE(review): confirm that declaration is still in scope at this point.
FPType * voting = tlsVoting + threadID * nClasses;

if (voteWeights == VoteWeights::voteUniform)
{
Expand Down
168 changes: 80 additions & 88 deletions ...gorithms/k_nearest_neighbors/kdtree_knn_classification_predict_dense_default_batch_impl.i
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,6 @@ Status KNNClassificationPredictKernel<algorithmFpType, defaultDense, cpu>::compu
NumericTable * y, NumericTable * indices, NumericTable * distances,
const daal::algorithms::Parameter * par)
{
Status status;

typedef GlobalNeighbors<algorithmFpType, cpu> Neighbors;
typedef Heap<Neighbors, cpu> MaxHeap;
typedef kdtree_knn_classification::internal::Stack<SearchNode<algorithmFpType>, cpu> SearchStack;
Expand Down Expand Up @@ -172,33 +170,13 @@ Status KNNClassificationPredictKernel<algorithmFpType, defaultDense, cpu>::compu
{
MaxHeap heap;
SearchStack stack;
bool initialized = false;
};
daal::tls<Local *> localTLS([&]() -> Local * {
Local * const ptr = service_scalable_calloc<Local, cpu>(1);
if (ptr)
{
if (!ptr->heap.init(heapSize))
{
status.add(services::ErrorMemoryAllocationFailed);
service_scalable_free<Local, cpu>(ptr);
return nullptr;
}
if (!ptr->stack.init(stackSize))
{
status.add(services::ErrorMemoryAllocationFailed);
ptr->heap.clear();
service_scalable_free<Local, cpu>(ptr);
return nullptr;
}
}
else
{
status.add(services::ErrorMemoryAllocationFailed);
}
return ptr;
});

DAAL_CHECK_STATUS_OK((status.ok()), status);
const int nThreads = daal::threader_get_max_threads_number();
services::internal::TArrayScalable<Local, cpu> localTLSArr(nThreads);
Local * localTLS = localTLSArr.get();
DAAL_CHECK_MALLOC(localTLS);

const auto maxThreads = threader_get_threads_number();
const size_t xColumnCount = x->getNumberOfColumns();
Expand All @@ -210,87 +188,101 @@ Status KNNClassificationPredictKernel<algorithmFpType, defaultDense, cpu>::compu
bool isHomogenSOA = checkHomogenSOA<algorithmFpType, cpu>(data, soa_arrays);

daal::threader_for(blockCount, blockCount, [&](int iBlock) {
Local * const local = localTLS.local();
if (local)
{
services::Status s;

const size_t first = iBlock * rowsPerBlock;
const size_t last = min<cpu>(static_cast<decltype(xRowCount)>(first + rowsPerBlock), xRowCount);

const algorithmFpType radius = MaxVal::get();
data_management::BlockDescriptor<algorithmFpType> xBD;
const_cast<NumericTable &>(*x).getBlockOfRows(first, last - first, readOnly, xBD);
const algorithmFpType * const dx = xBD.getBlockPtr();
const int threadID = daal::threader_get_max_current_thread_index();
Local * const local = &localTLS[threadID];
Status s;

data_management::BlockDescriptor<int> indicesBD;
data_management::BlockDescriptor<algorithmFpType> distancesBD;
if (indices)
// initialize local heap and stack for the thread
if (!local->initialized)
{
if (!local->heap.init(heapSize))
{
s = indices->getBlockOfRows(first, last - first, writeOnly, indicesBD);
DAAL_CHECK_STATUS_THR(s);
safeStat.add(services::ErrorMemoryAllocationFailed);
return;
}
if (distances)
if (!local->stack.init(stackSize))
{
s = distances->getBlockOfRows(first, last - first, writeOnly, distancesBD);
DAAL_CHECK_STATUS_THR(s);
safeStat.add(services::ErrorMemoryAllocationFailed);
local->heap.clear();
return;
}
local->initialized = true;
}

if (labels)
{
const size_t yColumnCount = y->getNumberOfColumns();
data_management::BlockDescriptor<algorithmFpType> yBD;
y->getBlockOfRows(first, last - first, writeOnly, yBD);
auto * const dy = yBD.getBlockPtr();
const size_t first = iBlock * rowsPerBlock;
const size_t last = min<cpu>(static_cast<decltype(xRowCount)>(first + rowsPerBlock), xRowCount);

for (size_t i = 0; i < last - first; ++i)
{
findNearestNeighbors(&dx[i * xColumnCount], local->heap, local->stack, k, radius, kdTreeTable, rootTreeNodeIndex, data,
isHomogenSOA, soa_arrays);
s = predict(&(dy[i * yColumnCount]), local->heap, labels, k, voteWeights, modelIndices, indicesBD, distancesBD, i, nClasses);
DAAL_CHECK_STATUS_THR(s)
}
const algorithmFpType radius = MaxVal::get();
data_management::BlockDescriptor<algorithmFpType> xBD;
const_cast<NumericTable &>(*x).getBlockOfRows(first, last - first, readOnly, xBD);
const algorithmFpType * const dx = xBD.getBlockPtr();

s |= y->releaseBlockOfRows(yBD);
DAAL_CHECK_STATUS_THR(s);
}
else
{
for (size_t i = 0; i < last - first; ++i)
{
findNearestNeighbors(&dx[i * xColumnCount], local->heap, local->stack, k, radius, kdTreeTable, rootTreeNodeIndex, data,
isHomogenSOA, soa_arrays);
s = predict(nullptr, local->heap, labels, k, voteWeights, modelIndices, indicesBD, distancesBD, i, nClasses);
DAAL_CHECK_STATUS_THR(s)
}
}
data_management::BlockDescriptor<int> indicesBD;
data_management::BlockDescriptor<algorithmFpType> distancesBD;
if (indices)
{
s = indices->getBlockOfRows(first, last - first, writeOnly, indicesBD);
DAAL_CHECK_STATUS_THR(s);
}
if (distances)
{
s = distances->getBlockOfRows(first, last - first, writeOnly, distancesBD);
DAAL_CHECK_STATUS_THR(s);
}

if (labels)
{
const size_t yColumnCount = y->getNumberOfColumns();
data_management::BlockDescriptor<algorithmFpType> yBD;
y->getBlockOfRows(first, last - first, writeOnly, yBD);
auto * const dy = yBD.getBlockPtr();

if (indices)
for (size_t i = 0; i < last - first; ++i)
{
s |= indices->releaseBlockOfRows(indicesBD);
findNearestNeighbors(&dx[i * xColumnCount], local->heap, local->stack, k, radius, kdTreeTable, rootTreeNodeIndex, data, isHomogenSOA,
soa_arrays);
s = predict(&(dy[i * yColumnCount]), local->heap, labels, k, voteWeights, modelIndices, indicesBD, distancesBD, i, nClasses);
DAAL_CHECK_STATUS_THR(s);
}

s = y->releaseBlockOfRows(yBD);
DAAL_CHECK_STATUS_THR(s);
if (distances)
}
else
{
for (size_t i = 0; i < last - first; ++i)
{
s |= distances->releaseBlockOfRows(distancesBD);
findNearestNeighbors(&dx[i * xColumnCount], local->heap, local->stack, k, radius, kdTreeTable, rootTreeNodeIndex, data, isHomogenSOA,
soa_arrays);
s = predict(nullptr, local->heap, labels, k, voteWeights, modelIndices, indicesBD, distancesBD, i, nClasses);
DAAL_CHECK_STATUS_THR(s);
}
DAAL_CHECK_STATUS_THR(s);
}

const_cast<NumericTable &>(*x).releaseBlockOfRows(xBD);
if (indices)
{
s = indices->releaseBlockOfRows(indicesBD);
DAAL_CHECK_STATUS_THR(s);
}
if (distances)
{
s = distances->releaseBlockOfRows(distancesBD);
DAAL_CHECK_STATUS_THR(s);
}
});

DAAL_CHECK_SAFE_STATUS()
const_cast<NumericTable &>(*x).releaseBlockOfRows(xBD);
});

localTLS.reduce([&](Local * ptr) -> void {
if (ptr)
for (int i = 0; i < nThreads; ++i)
{
if (localTLS[i].initialized)
{
ptr->stack.clear();
ptr->heap.clear();
service_scalable_free<Local, cpu>(ptr);
localTLS[i].heap.clear();
localTLS[i].stack.clear();
}
});
return status;
}

return safeStat.detach();
}

template <typename algorithmFpType, CpuType cpu>
Expand Down
Loading