Skip to content

Accessing cuCollections from CUDA.jl --- static_map in particular #2966

@mytraya-gattu

Description

@mytraya-gattu

I have a very particular problem, whose variations others might find useful as well.
I need to look up, from GPU threads, the indices of unsigned integers in a large constant array (typical sizes between 1e7 and 4e9 elements). The elements of the array are known to be unique (and can be sorted beforehand). I think a static map from https://github.com/NVIDIA/cuCollections would be ideal for my use case. From an arbitrary GPU thread, I want to be able to call the following function: given x (which is guaranteed to be a key), return the value it corresponds to.

Describe the solution you'd like

Ideally, I would like to call a function in CUDA.jl that can build a hash table for me based on arrays of keys and values, and then access the hash table from within a kernel.

Describe alternatives you've considered

I have tried to use ChatGPT and the examples from cuCollections to piece together some C++ code that I can build into a library and call from Julia using ccall, but I keep running into issues. In particular, I do not understand how to pass the pointer to the static_map to Julia — in Julia, I would simply have returned a function like lookup(x) that gives back my value.

// hash_lookup.cu
#include <cuco/static_map.cuh>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <cuda/std/limits>
#include <cuda_runtime.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <random>
#include <stdexcept>
#include <type_traits>
#include <unordered_set>
#include <utility>
#include <vector>

// Target hash-table occupancy: the map's capacity is sized as
// ceil(num_keys / kLoadFactor), i.e. roughly two slots per key.
inline constexpr double kLoadFactor{0.5};

// ---------------- CUDA error helper ----------------
// CUDA_CHECK(call): evaluates a CUDA runtime call exactly once. If it returns
// anything other than cudaSuccess, prints the source file, line number and the
// runtime's error string to stderr, then terminates the process with status 1.
#define CUDA_CHECK(call)                                                  \
  do {                                                                    \
    const cudaError_t cuda_check_status_ = (call);                        \
    if (cuda_check_status_ == cudaSuccess) break;                         \
    fprintf(stderr, "CUDA error %s:%d: %s\n",                             \
            __FILE__, __LINE__, cudaGetErrorString(cuda_check_status_));  \
    std::exit(1);                                                         \
  } while (0)

// ---------------- map builder (from DEVICE keys): key -> index ----------------
/// Builds a cuco::static_map that maps d_keys[i] -> Value(i) for every i in
/// [0, num_keys).
///
/// @param d_keys       Device pointer to num_keys keys. Keys must be unique and
///                     must never equal Key{0}, which is used as the empty-key
///                     sentinel (this is NOT checked here).
/// @param num_keys     Number of keys to insert (0 is allowed).
/// @param load_factor  Target occupancy in (0, 1]; the map capacity is
///                     ceil(num_keys / load_factor). Defaults to kLoadFactor.
/// @return             The populated map, which owns its own device storage.
///                     Assumes map.insert() (the non-async variant) has finished
///                     reading d_keys when it returns, so d_keys may be freed
///                     afterwards — TODO confirm for the cuco version in use.
/// @throws std::invalid_argument if load_factor is outside (0, 1], or if the
///         indices 0..num_keys-1 cannot be stored in Value without wrapping or
///         reaching the empty-value sentinel (numeric_limits<Value>::max()).
///
/// Requires nvcc --extended-lambda (the pair generator below is an extended lambda).
template <class Key, class Value>
cuco::static_map<Key, Value>
build_map_from_device_keys(const Key* d_keys, std::size_t num_keys,
                           double load_factor = kLoadFactor)
{
  static_assert(std::is_unsigned_v<Value>, "Value should hold indices (e.g., uint32_t/uint64_t).");

  constexpr Key   kEmptyKey   = Key{0};  // you guaranteed 0 never appears
  constexpr Value kEmptyValue = cuda::std::numeric_limits<Value>::max();

  // Written in negated form so that NaN is rejected as well.
  if (!(load_factor > 0.0 && load_factor <= 1.0)) {
    throw std::invalid_argument("build_map_from_device_keys: load_factor must be in (0, 1]");
  }
  // static_cast<Value>(i) below would silently wrap for i > max(Value), and
  // i == max(Value) would equal the empty-value sentinel.
  if (num_keys > static_cast<std::size_t>(kEmptyValue)) {
    throw std::invalid_argument(
        "build_map_from_device_keys: num_keys exceeds the range of the Value index type");
  }

  // Keep at least one slot so an empty key set still yields a constructible map.
  const std::size_t capacity = std::max<std::size_t>(
      1, static_cast<std::size_t>(std::ceil(static_cast<double>(num_keys) / load_factor)));

  cuco::static_map<Key, Value> map{
      capacity, cuco::empty_key{kEmptyKey}, cuco::empty_value{kEmptyValue}};

  // Iterator of pairs { d_keys[i], Value(i) }.
  // The lambda is __host__ __device__ rather than __device__-only because
  // thrust::transform_iterator queries the functor's return type in host code,
  // which nvcc does not permit for __device__-only extended lambdas.
  // __host__ __device__ extended lambdas may not use init-captures, hence the
  // plain by-copy capture of d_keys. It is only ever invoked on the device.
  auto pairs = thrust::make_transform_iterator(
      thrust::counting_iterator<std::size_t>{0},
      [d_keys] __host__ __device__ (std::size_t i) {
        return cuco::pair<Key, Value>{d_keys[i], static_cast<Value>(i)};
      });

  map.insert(pairs, pairs + num_keys);
  return map;  // map owns its device memory
}

// ---------------- device lookup + kernel (pass device-view by arg) ----------------
/// Device-side lookup: returns the value mapped to `k` in the map behind `ref`.
/// Callers are expected to query only keys that were inserted. If `k` is
/// absent, find() yields end(); instead of dereferencing that iterator
/// (undefined behaviour), the map's empty-value sentinel is returned.
template <class Ref, class Key>
__device__ __forceinline__
auto map_lookup(Ref ref, Key k) -> decltype(ref.find(k)->second) {
  auto const found = ref.find(k);
  return (found != ref.end()) ? found->second : ref.empty_value_sentinel();
}

/// For every i in [0, N), writes out[i] = map_lookup(ref, query_keys[i]).
///
/// Launch: any 1-D grid / 1-D block with no dynamic shared memory. A
/// grid-stride loop covers all N queries, so the grid does not have to be
/// sized to N (an oversized grid is also fine). All index arithmetic is done
/// in std::size_t: the previous `blockIdx.x * blockDim.x` was evaluated in
/// 32-bit unsigned and wrapped for indices >= 2^32.
///
/// @param ref         Device ref of the map, passed by value.
/// @param query_keys  Device array of N keys to look up.
/// @param out         Device array of N results.
/// @param N           Number of queries.
template <class Ref, class Key, class Value>
__global__ void lookup_kernel(Ref ref, const Key* query_keys, Value* out, std::size_t N)
{
  const std::size_t stride = static_cast<std::size_t>(gridDim.x) * blockDim.x;
  for (std::size_t i = static_cast<std::size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < N; i += stride) {
    out[i] = map_lookup(ref, query_keys[i]);
  }
}

// ---- detect how to get the device view across cuco versions (C++17) ----
// has_ref_method<M>::value is true exactly when `std::declval<M&>().ref()`, a
// call with NO arguments, is well-formed. NOTE(review): get_ref calls
// `ref(cuco::find)` with an argument, so this trait does not probe that exact
// call; it is not used by get_ref.
template <class M, class Enable = void>
struct has_ref_method : std::bool_constant<false> {};

template <class M>
struct has_ref_method<M, std::void_t<decltype(std::declval<M&>().ref())>>
    : std::bool_constant<true> {};

/// Returns a device reference to map `m` that supports the `find` operation.
/// The result is a small handle intended to be passed by value into kernels
/// (main passes it to lookup_kernel). This uses the newer cuco API
/// `ref(cuco::find)`; the `get_device_view()` accessor of older cuco releases
/// is not supported here.
template <class M>
auto get_ref(M& m) {
  auto find_ref = m.ref(cuco::find);
  return find_ref;
}

// ---------------- main test ----------------
// Draws `count` distinct keys uniformly from [1, 0xFFFFFFFE]. Key 0 is the map's
// empty-key sentinel, so it is excluded. Keys are kept in first-seen order and
// then shuffled with the same generator.
static std::vector<uint32_t> make_unique_keys(std::size_t count, std::mt19937_64& gen)
{
  std::uniform_int_distribution<uint32_t> draw(1u, 0xFFFFFFFEu);
  std::unordered_set<uint32_t> seen;
  seen.reserve(count * 2);
  std::vector<uint32_t> keys;
  keys.reserve(count);
  while (keys.size() != count) {
    const uint32_t candidate = draw(gen);
    const bool fresh = seen.insert(candidate).second;
    if (fresh) {
      keys.push_back(candidate);
    }
  }
  std::shuffle(keys.begin(), keys.end(), gen);
  return keys;
}

// Counts positions where got[i] != i and reports at most the first 10 of them
// on stderr.
template <class Value>
static std::size_t count_index_mismatches(const std::vector<Value>& got)
{
  std::size_t bad = 0;
  for (std::size_t idx = 0; idx < got.size(); ++idx) {
    if (got[idx] == static_cast<Value>(idx)) continue;
    if (bad < 10) {
      std::fprintf(stderr, "mismatch at i=%zu: got %u, want %u\n",
                   idx, (unsigned)got[idx], (unsigned)idx);
    }
    ++bad;
  }
  return bad;
}

// End-to-end check: build key -> index map on the GPU, query it with every
// inserted key (in insertion order) and expect map[key[i]] == i.
// Exits with 0 on success, 1 on any mismatch.
int main() {
  using Key   = uint32_t;
  using Value = uint32_t;

  constexpr std::size_t N = 10'000;
  constexpr std::size_t key_bytes = N * sizeof(Key);
  constexpr std::size_t out_bytes = N * sizeof(Value);

  // Deterministic, duplicate-free key set that never contains the 0 sentinel.
  std::mt19937_64 rng{1234567};
  const std::vector<uint32_t> h_keys = make_unique_keys(N, rng);

  // Stage the keys on the device.
  Key* d_keys = nullptr;
  CUDA_CHECK(cudaMalloc(&d_keys, key_bytes));
  CUDA_CHECK(cudaMemcpy(d_keys, h_keys.data(), key_bytes, cudaMemcpyHostToDevice));

  // Build the map, then obtain a find-capable device handle.
  auto map = build_map_from_device_keys<Key, Value>(d_keys, N);
  auto ref = get_ref(map);

  Value* d_out = nullptr;
  CUDA_CHECK(cudaMalloc(&d_out, out_bytes));

  // One thread per query.
  const dim3 block(256);
  const dim3 grid(static_cast<unsigned>((N + block.x - 1) / block.x));
  lookup_kernel<<<grid, block>>>(ref, d_keys, d_out, N);
  CUDA_CHECK(cudaPeekAtLastError());
  CUDA_CHECK(cudaDeviceSynchronize());

  std::vector<Value> h_out(N);
  CUDA_CHECK(cudaMemcpy(h_out.data(), d_out, out_bytes, cudaMemcpyDeviceToHost));

  const std::size_t mismatches = count_index_mismatches(h_out);
  if (mismatches != 0) {
    std::printf("FAIL: %zu mismatches out of %zu.\n", mismatches, N);
  } else {
    std::printf("OK: %zu lookups matched indices.\n", N);
  }

  CUDA_CHECK(cudaFree(d_out));
  CUDA_CHECK(cudaFree(d_keys));
  return mismatches == 0 ? 0 : 1;
}

Metadata

Metadata

Assignees

No one assigned

    Labels

    enhancement (New feature or request)

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions