Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 48 additions & 21 deletions awkward-cpp/src/cpu-kernels/awkward_argsort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,48 @@
#include <algorithm>
#include <cmath>
#include <numeric>
#include <type_traits>
#include <vector>

#include "awkward/kernels.h"

// Explicit specializations must appear before implicit instantiations.
template <typename T>
bool argsort_order_ascending(T l, T r);

template <typename T>
bool argsort_order_descending(T l, T r);

// Ascending order for bool: the only pair ordered "l before r" is
// (false, true).
template <>
bool argsort_order_ascending(bool l, bool r)
{
  return !l && r;
}

// Descending order for bool: the only pair ordered "l before r" is
// (true, false).
template <>
bool argsort_order_descending(bool l, bool r)
{
  return l && !r;
}

// Strict "l comes before r" predicate used for an ascending argsort.
//
// Integral types are compared directly with `l < r`; they cannot hold NaN,
// so no floating-point conversion is performed on them.
// For all other types NaN is handled explicitly: nothing is ordered before
// a NaN on the right-hand side, and a NaN on the left-hand side is ordered
// before any non-NaN value. Two NaNs therefore compare as equivalent and
// NaNs are grouped ahead of the ordinary values.
//
// l: left-hand value
// r: right-hand value
// Returns true when l must be placed before r.
template <typename T>
bool argsort_order_ascending(T l, T r)
{
  if constexpr (std::is_integral_v<T>) {
    return l < r;
  } else {
    return !std::isnan(r) && (std::isnan(l) || l < r);
  }
}

// Strict "l comes before r" predicate used for a descending argsort.
//
// Integral types are compared directly with `l > r`; they cannot hold NaN,
// so no floating-point conversion is performed on them.
// For all other types NaN is handled explicitly: nothing is ordered before
// a NaN on the right-hand side, and a NaN on the left-hand side is ordered
// before any non-NaN value. Two NaNs therefore compare as equivalent and
// NaNs are grouped ahead of the ordinary values, exactly as in the
// ascending predicate.
//
// l: left-hand value
// r: right-hand value
// Returns true when l must be placed before r.
template <typename T>
bool argsort_order_descending(T l, T r)
{
  if constexpr (std::is_integral_v<T>) {
    return l > r;
  } else {
    return !std::isnan(r) && (std::isnan(l) || l > r);
  }
}

template <typename T>
Expand All @@ -39,15 +67,26 @@ ERROR awkward_argsort(
int64_t* segment_start = toptr + start_off;
int64_t* segment_stop = toptr + stop_off;

auto comparator = [&fromptr, ascending](int64_t i1, int64_t i2) {
if (ascending) return argsort_order_ascending<T>(fromptr[i1], fromptr[i2]);
else return argsort_order_descending<T>(fromptr[i1], fromptr[i2]);
};

if (stable) {
std::stable_sort(segment_start, segment_stop, comparator);
if (ascending) {
std::stable_sort(segment_start, segment_stop, [&fromptr](int64_t i1, int64_t i2) {
return argsort_order_ascending<T>(fromptr[i1], fromptr[i2]);
});
} else {
std::stable_sort(segment_start, segment_stop, [&fromptr](int64_t i1, int64_t i2) {
return argsort_order_descending<T>(fromptr[i1], fromptr[i2]);
});
}
} else {
std::sort(segment_start, segment_stop, comparator);
if (ascending) {
std::sort(segment_start, segment_stop, [&fromptr](int64_t i1, int64_t i2) {
return argsort_order_ascending<T>(fromptr[i1], fromptr[i2]);
});
} else {
std::sort(segment_start, segment_stop, [&fromptr](int64_t i1, int64_t i2) {
return argsort_order_descending<T>(fromptr[i1], fromptr[i2]);
});
}
}

std::transform(segment_start, segment_stop, segment_start, [start_off](int64_t j) {
Expand Down Expand Up @@ -255,15 +294,3 @@ ERROR awkward_argsort_float64(
ascending,
stable);
}

// Ascending order for bool: the only pair ordered "l before r" is
// (false, true).
// NOTE(review): an identical explicit specialization also appears earlier in
// this listing, ahead of the generic template. Only one definition can remain
// in the translation unit -- confirm this trailing copy is the one removed.
template <>
bool argsort_order_ascending(bool l, bool r)
{
  return !l && r;
}

// Descending order for bool: the only pair ordered "l before r" is
// (true, false).
// NOTE(review): an identical explicit specialization also appears earlier in
// this listing, ahead of the generic template. Only one definition can remain
// in the translation unit -- confirm this trailing copy is the one removed.
template <>
bool argsort_order_descending(bool l, bool r)
{
  return l && !r;
}
44 changes: 30 additions & 14 deletions awkward-cpp/src/cpu-kernels/awkward_sort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,46 @@
#include <algorithm>
#include <cmath>
#include <numeric>
#include <type_traits>
#include <vector>

// Explicit specializations must appear before implicit instantiations.
template <typename T>
bool sort_order_ascending(T l, T r);

template <typename T>
bool sort_order_descending(T l, T r);

// Ascending order for bool: the only pair ordered "l before r" is
// (false, true).
template <>
bool sort_order_ascending(bool l, bool r)
{
  return !l && r;
}

// Descending order for bool: the only pair ordered "l before r" is
// (true, false).
template <>
bool sort_order_descending(bool l, bool r)
{
  return l && !r;
}

// Strict "l comes before r" predicate used for an ascending sort.
//
// Integral types are compared directly with `l < r`; they cannot hold NaN,
// so no floating-point conversion is performed on them.
// For all other types NaN is handled explicitly: nothing is ordered before
// a NaN on the right-hand side, and a NaN on the left-hand side is ordered
// before any non-NaN value. Two NaNs therefore compare as equivalent and
// NaNs are grouped ahead of the ordinary values.
//
// l: left-hand value
// r: right-hand value
// Returns true when l must be placed before r.
template <typename T>
bool sort_order_ascending(T l, T r)
{
  if constexpr (std::is_integral_v<T>) {
    return l < r;
  } else {
    return !std::isnan(r) && (std::isnan(l) || l < r);
  }
}

// Strict "l comes before r" predicate used for a descending sort.
//
// Integral types are compared directly with `l > r`; they cannot hold NaN,
// so no floating-point conversion is performed on them.
// For all other types NaN is handled explicitly: nothing is ordered before
// a NaN on the right-hand side, and a NaN on the left-hand side is ordered
// before any non-NaN value. Two NaNs therefore compare as equivalent and
// NaNs are grouped ahead of the ordinary values, exactly as in the
// ascending predicate.
//
// l: left-hand value
// r: right-hand value
// Returns true when l must be placed before r.
template <typename T>
bool sort_order_descending(T l, T r)
{
  if constexpr (std::is_integral_v<T>) {
    return l > r;
  } else {
    return !std::isnan(r) && (std::isnan(l) || l > r);
  }
}

template <typename T>
Expand Down Expand Up @@ -277,15 +305,3 @@ ERROR awkward_sort_float64(
ascending,
stable);
}

// Ascending order for bool: the only pair ordered "l before r" is
// (false, true).
// NOTE(review): an identical explicit specialization also appears earlier in
// this listing, ahead of the generic template. Only one definition can remain
// in the translation unit -- confirm this trailing copy is the one removed.
template <>
bool sort_order_ascending(bool l, bool r)
{
  return !l && r;
}

// Descending order for bool: the only pair ordered "l before r" is
// (true, false).
// NOTE(review): an identical explicit specialization also appears earlier in
// this listing, ahead of the generic template. Only one definition can remain
// in the translation unit -- confirm this trailing copy is the one removed.
template <>
bool sort_order_descending(bool l, bool r)
{
  return l && !r;
}
57 changes: 57 additions & 0 deletions tests/test_4090_sort_int64_precision.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE

from __future__ import annotations

import numpy as np

import awkward as ak


def test_sort_int64_above_2_53_ascending():
    """Ascending ak.sort of integers just above 2**60 agrees with np.sort on int64."""
    offset = 2**60
    # Four distinct values that differ only in their lowest bits.
    values = [offset + k for k in (3, 1, 2, 0)]
    reference = np.sort(np.array(values, dtype=np.int64)).tolist()
    outcome = ak.sort(ak.Array([values]), ascending=True)
    assert outcome.tolist()[0] == reference


def test_argsort_int64_above_2_53_ascending():
    """Ascending ak.argsort of integers just above 2**60 agrees with a stable np.argsort on int64."""
    offset = 2**60
    # Four distinct values that differ only in their lowest bits.
    values = [offset + k for k in (3, 1, 2, 0)]
    reference = np.argsort(np.array(values, dtype=np.int64), kind="stable").tolist()
    indices = ak.argsort(ak.Array([values]), ascending=True)
    assert indices.tolist()[0] == reference


def test_sort_uint64_above_2_53_ascending():
    """Ascending ak.sort of a uint64 NumPy-backed array just above 2**60 agrees with np.sort."""
    offset = 2**60
    # Four distinct values that differ only in their lowest bits.
    values = [offset + k for k in (3, 1, 2, 0)]
    # The awkward array is built from a 2-D uint64 NumPy array holding one row.
    nested = np.array([values], dtype=np.uint64)
    outcome = ak.sort(ak.Array(nested), ascending=True)
    reference = np.sort(np.array(values, dtype=np.uint64)).tolist()
    assert outcome.tolist()[0] == reference


def test_sort_float_nan_ascending():
    """Ascending float sort with NaNs keeps the non-NaN values ordered and both NaNs present.

    Only the relative order of the non-NaN values and the number of NaNs are
    asserted; where the NaNs land in the output is not checked here.
    """
    nan = float("nan")
    sorted_row = ak.sort(ak.Array([[1.0, nan, 2.0, nan, 0.5]]), ascending=True).tolist()[0]
    finite, missing = [], []
    for item in sorted_row:
        # NaN is the only float that is not equal to itself.
        (finite if item == item else missing).append(item)
    assert finite == sorted(finite)
    assert len(missing) == 2


def test_sort_bool_ascending():
    """Ascending sort of booleans places every False ahead of every True."""
    flags = ak.Array([[True, False, True, False]])
    assert ak.sort(flags, ascending=True).tolist() == [[False, False, True, True]]
Loading