-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathparlay_sample_og.cpp
More file actions
151 lines (125 loc) · 4.7 KB
/
parlay_sample_og.cpp
File metadata and controls
151 lines (125 loc) · 4.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#include <algorithm>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <limits>
#include <random>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>
#ifndef SS_STOP_STAGE
#define SS_STOP_STAGE 0
#endif
#include "src/one_pass/parlay_sample_og.hpp"
/// Minimal stopwatch used to time the benchmark phases.
///
/// Measures with std::chrono::steady_clock, which is monotonic: an interval can
/// never come out negative or be distorted by a system clock adjustment.
/// (std::chrono::high_resolution_clock may be an alias of system_clock, which
/// gives no such guarantee.)
class Timer {
    std::chrono::steady_clock::time_point start;

public:
    /// Starts timing at construction.
    Timer() { reset(); }

    /// Restarts the measurement from the current instant.
    void reset() { start = std::chrono::steady_clock::now(); }

    /// Returns the number of seconds elapsed since construction or the last reset().
    double elapsed() const {
        const auto now = std::chrono::steady_clock::now();
        return std::chrono::duration<double>(now - start).count();
    }
};
// SplitMix64 core: adds the fixed increment constant to the input, then runs
// two xor-shift/multiply rounds followed by a final xor-shift.
static inline uint64_t splitmix64(uint64_t x) {
    uint64_t h = x + 0x9E3779B97F4A7C15ULL;
    h ^= h >> 30;
    h *= 0xBF58476D1CE4E5B9ULL;
    h ^= h >> 27;
    h *= 0x94D049BB133111EBULL;
    return h ^ (h >> 31);
}
/**
 * Fills the first n elements of `data` with deterministic pseudo-random values,
 * computed in parallel.
 *
 * Element i is splitmix64(i + seed) reduced modulo the effective range, so the
 * result depends only on (i, seed, K).
 *
 * @tparam T    Element type; only int32_t and int64_t are accepted (static_assert).
 * @param data  Destination sequence. If it holds fewer than n elements it is
 *              grown to n first so that every write is in bounds; it is never
 *              shrunk.
 * @param n     Number of leading elements to generate.
 * @param K     Size of the value range: generated values lie in [0, K).
 *              K == 0 is treated as 1, and K is clamped to max(T) + 1 so every
 *              value is representable in T.
 * @param seed  Offset added to the element index before hashing.
 */
template <typename T>
void fast_parallel_gen(parlay::sequence<T>& data, size_t n, uint64_t K, uint64_t seed = 42) {
    static_assert(std::is_same_v<T, int32_t> || std::is_same_v<T, int64_t>,
                  "fast_parallel_gen currently supports int32_t/int64_t");

    // Guard against K == 0, which would make the modulo below divide by zero.
    if (K == 0) K = 1;

    // The loop below writes data[0 .. n-1]; make sure those slots exist.
    if (data.size() < n) data.resize(n);

    // Keep every generated value safely convertible to T:
    //   int32_t: usable range is [0, 2^31 - 1]
    //   int64_t: usable range is [0, 2^63 - 1]
    const uint64_t max_plus_1 = static_cast<uint64_t>(std::numeric_limits<T>::max()) + 1ULL;
    const uint64_t K_eff = std::min<uint64_t>(K, max_plus_1);

    parlay::parallel_for(0, n, [&](size_t i) {
        const uint64_t z = splitmix64(static_cast<uint64_t>(i) + seed);
        const uint64_t v = z % K_eff;  // v is in [0, K_eff - 1], so the cast is lossless
        data[i] = static_cast<T>(v);
    });
}
/**
 * Runs one sort benchmark for element type T.
 *
 * Steps: generate n pseudo-random values in [0, K) with fast_parallel_gen,
 * sort them with parlay::internal::sample_sort_inplace while timing the call,
 * print time and throughput, then verify the result in parallel.
 *
 * When SS_STOP_STAGE is non-zero the function returns right after reporting
 * the timing and skips verification.
 *
 * @tparam T  Element type (the banner prints "int32" for int32_t, "int64" otherwise).
 * @param n   Number of elements to generate and sort.
 * @param K   Size of the value range passed to the generator.
 * @return 0 if verification passes or is skipped, 1 if the data is out of order.
 */
template <typename T>
int run_bench(size_t n, uint64_t K) {
    std::cout << "========================================\n";
    std::cout << "Running Benchmark (Fast SplitMix64)\n";
    std::cout << "Type: " << (std::is_same_v<T,int32_t> ? "int32" : "int64") << "\n";
    std::cout << "Data Size: " << n << " elements\n";
    std::cout << "========================================\n";

    // 1) Generate
    std::cout << "Generating random data..." << std::flush;
    Timer gen_timer;
    parlay::sequence<T> data(n);
    fast_parallel_gen<T>(data, n, K);
    std::cout << " Done. (" << gen_timer.elapsed() << "s)\n";

    // 2) Sort
    std::cout << "Sorting...\n" << std::flush;
    Timer sort_timer;
    // Order the verifier expects. NOTE: this flag is not passed to the sort call
    // below, so changing it only changes what verification checks for.
    const bool ascending = true;
    parlay::internal::sample_sort_inplace(parlay::make_slice(data));
    const double time_taken = sort_timer.elapsed();
    std::cout << "Sorting Done.\n";
    std::cout << "----------------------------------------\n";
    std::cout << "Sort Time: " << time_taken << " seconds\n";
    std::cout << "Throughput: " << (static_cast<double>(n) / time_taken / 1e6) << " M/sec\n";
    std::cout << "----------------------------------------\n";
#if SS_STOP_STAGE != 0
    return 0;
#endif

    // 3) Parallel verification
    std::cout << "Verifying correctness (Parallel)..." << std::flush;
    Timer verify_timer;

    // Strict "comes before" relation for the expected order; shared by the
    // parallel check and the failure diagnostics so they cannot disagree.
    const auto in_order = [ascending](const T& a, const T& b) {
        return ascending ? a < b : a > b;
    };

    const size_t num_workers = parlay::num_workers();
    std::vector<char> results(num_workers, 1);
    parlay::parallel_for(0, num_workers, [&](size_t i) {
        const size_t start = (i * n) / num_workers;
        const size_t end = ((i + 1) * n) / num_workers;
        // Look one element past the chunk so the pair straddling the boundary is
        // checked too. Clamping to n keeps the iterator in range for the last
        // chunk and when n == 0.
        const size_t check_end = std::min(end + 1, n);
        if (start < check_end &&
            !std::is_sorted(data.begin() + start, data.begin() + check_end, in_order)) {
            results[i] = 0;
        }
    });

    bool sorted_ok = true;
    for (char r : results) {
        if (r == 0) { sorted_ok = false; break; }
    }

    if (sorted_ok) {
        std::cout << " PASSED ✅ (" << verify_timer.elapsed() << "s)\n";
        return 0;
    }

    std::cout << " FAILED ❌\n";

    // Report the first few inversions wherever they occur in the data, not only
    // among the leading elements.
    const size_t max_reported = 10;
    const auto out_of_order = [&in_order](const T& a, const T& b) { return in_order(b, a); };
    const auto first = data.begin();
    const auto last = data.begin() + n;
    size_t reported = 0;
    for (auto it = std::adjacent_find(first, last, out_of_order);
         it != last && reported < max_reported;
         it = std::adjacent_find(it + 1, last, out_of_order), ++reported) {
        std::cout << "Error at index " << static_cast<size_t>(it - first) << ": "
                  << *it << (ascending ? " > " : " < ") << *(it + 1) << "\n";
    }
    return 1;
}
/**
 * Entry point. Accepts one optional argument: the element count n
 * (default 1,000,000,000).
 *
 * The argument must be a plain non-negative decimal integer. Anything else
 * (non-numeric text, a negative value, trailing characters such as "1e9", or a
 * value too large for unsigned long long) prints a usage message and exits
 * with status 2 instead of terminating on an uncaught exception or silently
 * running with an unintended size.
 *
 * @return the result of run_bench (0 on success, 1 on verification failure),
 *         or 2 on an invalid argument.
 */
int main(int argc, char* argv[]) {
    // Only one argument is accepted: n
    size_t n = 1000000000ULL;
    if (argc > 1) {
        const std::string arg = argv[1];
        try {
            std::size_t consumed = 0;
            const unsigned long long parsed = std::stoull(arg, &consumed);
            // stoull accepts a leading '-' (wrapping the value) and stops at the
            // first non-digit; reject both so typos do not change the workload.
            if (consumed != arg.size() || arg.find('-') != std::string::npos) {
                throw std::invalid_argument("expected a non-negative integer");
            }
            n = static_cast<size_t>(parsed);
        } catch (const std::exception& e) {
            std::cerr << "Invalid element count '" << arg << "': " << e.what() << "\n"
                      << "Usage: " << argv[0] << " [n]\n";
            return 2;
        }
    }

    // ===== Edit these two lines by hand to reconfigure the benchmark =====
    using BenchType = int64_t;  // change to int32_t or int64_t
    uint64_t K = 1000000000;    // size of the value range
    // =====================================================================

    return run_bench<BenchType>(n, K);
}