Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 97 additions & 154 deletions src/qseek/ext/delay_sum.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
typedef struct {
int32_t *shifts;
float *weights;
npy_bool masked;
int masked;
int trace_group;
} Node;

typedef struct {
Expand All @@ -42,137 +43,98 @@ static inline npy_intp imax(npy_intp a, npy_intp b) { return a > b ? a : b; }
static inline npy_intp imin(npy_intp a, npy_intp b) { return a < b ? a : b; }

// Function to check NumPy array dtype
static inline int check_array_dtype(PyArrayObject *arr, int expected_type) {
static inline int check_array(PyObject *arr, int expected_type) {
if (!PyArray_Check(arr)) {
PyErr_SetString(PyExc_TypeError, "Input must be a NumPy array");
return 0;
}
if (PyArray_TYPE(arr) != expected_type) {
if (PyArray_TYPE((PyArrayObject *)arr) != expected_type) {
PyErr_Format(PyExc_TypeError, "Input array must be of type %s",
expected_type == NPY_FLOAT32 ? "float32" : "unknown");
return 0;
}
if (!PyArray_ISCONTIGUOUS(arr)) {
if (!PyArray_ISCONTIGUOUS((PyArrayObject *)arr)) {
PyErr_SetString(PyExc_ValueError, "Input array must be contiguous");
return 0;
}
return 1;
}

// Prepare function equivalent to Mojo's prepare
static PyObject *prepare(PyObject *traces, PyObject *offsets, PyObject *shifts,
PyObject *weights, PyObject *node_mask,
Trace **traces_list, Node **nodes_list,
int32_t *min_shift, int32_t *max_shift) {
Py_ssize_t n_traces = PyList_Size(traces);
PyArrayObject *shifts_arr =
(PyArrayObject *)PyArray_ContiguousFromObject(shifts, NPY_INT32, 2, 2);
PyArrayObject *weights_arr =
(PyArrayObject *)PyArray_ContiguousFromObject(weights, NPY_FLOAT32, 2, 2);
PyArrayObject *offsets_arr =
(PyArrayObject *)PyArray_ContiguousFromObject(offsets, NPY_INT32, 1, 1);
PyArrayObject *node_mask_arr = NULL;

if (!shifts_arr || !weights_arr || !offsets_arr) {
Py_XDECREF(shifts_arr);
Py_XDECREF(weights_arr);
Py_XDECREF(offsets_arr);
return NULL;
}

if (n_traces == 0) {
PyErr_SetString(PyExc_ValueError, "Input traces must be a non-empty list");
goto cleanup;
}

npy_intp *shifts_shape = PyArray_SHAPE(shifts_arr);
npy_intp n_nodes = shifts_shape[0];

if (node_mask == Py_None) {
node_mask = PyArray_ZEROS(1, &n_nodes, NPY_BOOL, 0);
if (!node_mask) {
PyErr_SetString(PyExc_MemoryError, "Failed to allocate node activation");
goto cleanup;
}
} else {
if (!check_array_dtype((PyArrayObject *)node_mask, NPY_BOOL)) {
goto cleanup;
}
if (PyArray_SHAPE((PyArrayObject *)node_mask)[0] != n_nodes) {
PyErr_SetString(PyExc_ValueError,
"Node mask must have the same number of elements as "
"nodes in shifts array");
goto cleanup;
}
}
node_mask_arr = (PyArrayObject *)node_mask;

if (n_nodes == 0) {
PyErr_SetString(PyExc_ValueError,
"Number of nodes must be greater than zero");
goto cleanup;
static int prepare(PyObject *nodes, PyObject *traces, PyObject *offsets,
Node **nodes_list, Trace **traces_list, int32_t *min_shift,
int32_t *max_shift) {
if (!PyList_Check(nodes) || !PyList_Check(traces)) {
PyErr_SetString(PyExc_TypeError, "nodes and traces must be lists");
return 0;
}

if (!check_array_dtype(weights_arr, NPY_FLOAT32) ||
!check_array_dtype(offsets_arr, NPY_INT32) ||
!check_array_dtype(shifts_arr, NPY_INT32) ||
!check_array_dtype(node_mask_arr, NPY_BOOL)) {
goto cleanup;
}
Py_ssize_t n_traces = PyList_Size(traces);
Py_ssize_t n_nodes = PyList_Size(nodes);

if (shifts_shape[0] != PyArray_SHAPE(weights_arr)[0] ||
shifts_shape[1] != PyArray_SHAPE(weights_arr)[1]) {
PyErr_SetString(PyExc_ValueError,
"Shifts and weights must have the same shape");
goto cleanup;
}
if (n_traces != PyArray_SHAPE(offsets_arr)[0]) {
if (!check_array(offsets, NPY_INT32) ||
PyArray_SHAPE((PyArrayObject *)offsets)[0] != n_traces) {
PyErr_SetString(PyExc_ValueError,
"Number of arrays must match number of offsets");
goto cleanup;
}
if (shifts_shape[1] != n_traces) {
PyErr_SetString(PyExc_ValueError,
"Shifts must have the same number of columns as traces");
goto cleanup;
}
if (n_nodes != PyArray_SHAPE(node_mask_arr)[0]) {
PyErr_SetString(PyExc_ValueError,
"Number of nodes must match number of activation flags");
goto cleanup;
return 0;
}

int32_t *offsets_data = (int32_t *)PyArray_DATA(offsets_arr);
int32_t *shifts_data = (int32_t *)PyArray_DATA(shifts_arr);
float *weights_data = (float *)PyArray_DATA(weights_arr);
npy_bool *node_mask_data = (npy_bool *)PyArray_DATA(node_mask_arr);
int32_t *offsets_data = (int32_t *)PyArray_DATA((PyArrayObject *)offsets);

*traces_list = (Trace *)malloc(n_traces * sizeof(Trace));
*nodes_list = (Node *)malloc(n_nodes * sizeof(Node));
if (!*traces_list || !*nodes_list) {
PyErr_SetString(PyExc_MemoryError, "Failed to allocate memory");
goto cleanup;
return 0;
}

for (npy_intp i = 0; i < n_traces; i++) {
PyArrayObject *trace = (PyArrayObject *)PyArray_ContiguousFromObject(
PyList_GetItem(traces, i), NPY_FLOAT32, 1, 1);
if (!trace)
goto cleanup_traces;
if (!check_array_dtype(trace, NPY_FLOAT32)) {
PyObject *trace = PyList_GET_ITEM(traces, i);
if (!check_array(trace, NPY_FLOAT32)) {
Py_DECREF(trace);
goto cleanup_traces;
free(*traces_list);
free(*nodes_list);
return 0;
}
(*traces_list)[i].data = (float *)PyArray_DATA(trace);
(*traces_list)[i].size = PyArray_SIZE(trace);
(*traces_list)[i].data = (float *)PyArray_DATA((PyArrayObject *)trace);
(*traces_list)[i].size = PyArray_SIZE((PyArrayObject *)trace);
(*traces_list)[i].offset = offsets_data[i];
Py_DECREF(trace); // We keep the data pointer, but release the array object
}

for (npy_intp i = 0; i < n_nodes; i++) {
(*nodes_list)[i].shifts = shifts_data + i * n_traces;
(*nodes_list)[i].weights = weights_data + i * n_traces;
(*nodes_list)[i].masked = node_mask_data[i];
PyObject *node_tuple = PyList_GET_ITEM(nodes, i);
if (!PyTuple_Check(node_tuple) || PyTuple_Size(node_tuple) < 3) {
PyErr_SetString(
PyExc_TypeError,
"Each node must be a tuple of (shifts, weights, masked, ...)");
free(*nodes_list);
free(*traces_list);
return 0;
}
PyObject *shifts_arr = (PyObject *)PyTuple_GET_ITEM(node_tuple, 0);
PyObject *weights_arr = (PyObject *)PyTuple_GET_ITEM(node_tuple, 1);
PyObject *masked_obj = (PyObject *)PyTuple_GET_ITEM(node_tuple, 2);
PyObject *trace_group_obj = (PyObject *)PyTuple_GET_ITEM(node_tuple, 3);

if (!check_array(shifts_arr, NPY_INT32) ||
!check_array(weights_arr, NPY_FLOAT32)) {
free(*nodes_list);
free(*traces_list);
return 0;
}
if (PyArray_NDIM((PyArrayObject *)shifts_arr) != 1 ||
PyArray_NDIM((PyArrayObject *)weights_arr) != 1 ||
PyArray_SIZE((PyArrayObject *)shifts_arr) != n_traces ||
PyArray_SIZE((PyArrayObject *)weights_arr) != n_traces) {
PyErr_SetString(PyExc_ValueError, "Shifts and weights must be 1D arrays");
free(*nodes_list);
free(*traces_list);
return 0;
}

(*nodes_list)[i].shifts = PyArray_DATA((PyArrayObject *)shifts_arr);
(*nodes_list)[i].weights = PyArray_DATA((PyArrayObject *)weights_arr);
(*nodes_list)[i].masked = PyObject_IsTrue(masked_obj);
// (*nodes_list)[i].trace_group = (int32_t)PyLong_AsLong(trace_group_obj);
(*nodes_list)[i].trace_group = 0;
}

*min_shift = INT32_MAX;
Expand All @@ -186,49 +148,33 @@ static PyObject *prepare(PyObject *traces, PyObject *offsets, PyObject *shifts,
*max_shift = (*max_shift > idx_end) ? *max_shift : idx_end;
}
}

Py_DECREF(shifts_arr);
Py_DECREF(weights_arr);
Py_DECREF(offsets_arr);
return traces;

cleanup_traces:
free(*traces_list);
free(*nodes_list);
cleanup:
Py_XDECREF(shifts_arr);
Py_XDECREF(weights_arr);
Py_XDECREF(offsets_arr);
return NULL;
return 1;
}

static PyObject *delay_sum(PyObject *self, PyObject *args, PyObject *kwargs) {
PyObject *traces, *offsets, *shifts, *weights, *stack, *node_mask,
*shift_range;
PyObject *traces, *offsets, *nodes, *stack, *shift_range;
stack = Py_None; // Default to None if not provided
node_mask = Py_None;
shift_range = Py_None;
int n_threads = 1;

static char *kwlist[] = {"traces", "offsets", "shifts",
"weights", "node_mask", "stack",
static char *kwlist[] = {"traces", "offsets", "nodes", "stack",
"shift_range", "n_threads", NULL};
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OOOO|OOOi", kwlist, &traces,
&offsets, &shifts, &weights, &node_mask,
&stack, &shift_range, &n_threads)) {
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OOO|OOi", kwlist, &traces,
&offsets, &nodes, &stack, &shift_range,
&n_threads)) {
return NULL;
}

Trace *traces_list;
Node *nodes_list;
int32_t min_shift, max_shift;
if (!prepare(traces, offsets, shifts, weights, node_mask, &traces_list,
&nodes_list, &min_shift, &max_shift)) {
if (!prepare(nodes, traces, offsets, &nodes_list, &traces_list, &min_shift,
&max_shift)) {
return NULL;
}

npy_intp n_traces = PyList_Size(traces);
npy_intp n_nodes = PyArray_SHAPE((PyArrayObject *)shifts)[0];
npy_intp n_nodes = PyList_Size(nodes);
npy_intp stack_size = max_shift - min_shift;
if (shift_range != Py_None) {
if (!PyTuple_Check(shift_range) || PyTuple_Size(shift_range) != 2 ||
Expand All @@ -254,7 +200,7 @@ static PyObject *delay_sum(PyObject *self, PyObject *args, PyObject *kwargs) {
}

if (stack != Py_None) {
if (!check_array_dtype((PyArrayObject *)stack, NPY_FLOAT32)) {
if (!check_array(stack, NPY_FLOAT32)) {
free(traces_list);
free(nodes_list);
return NULL;
Expand Down Expand Up @@ -339,23 +285,21 @@ static PyObject *delay_sum(PyObject *self, PyObject *args, PyObject *kwargs) {
// stack_and_reduce function
static PyObject *delay_sum_reduce(PyObject *self, PyObject *args,
PyObject *kwargs) {
PyObject *traces, *offsets, *shifts, *weights, *node_mask, *node_stack_max,
*node_stack_max_idx, *shift_range;
node_mask = Py_None;
PyObject *traces, *offsets, *nodes, *node_stack_max, *node_stack_max_idx,
*shift_range;
node_stack_max = Py_None;
node_stack_max_idx = Py_None;
shift_range = Py_None;

int n_threads = 1;

static char *kwlist[] = {
"traces", "offsets", "shifts", "weights",
"node_mask", "shift_range", "node_stack_max", "node_stack_max_idx",
"n_threads", NULL};
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OOOO|OOOOi", kwlist, &traces,
&offsets, &shifts, &weights, &node_mask,
&shift_range, &node_stack_max,
&node_stack_max_idx, &n_threads)) {
static char *kwlist[] = {"traces", "offsets",
"nodes", "shift_range",
"node_stack_max", "node_stack_max_idx",
"n_threads", NULL};
if (!PyArg_ParseTupleAndKeywords(
args, kwargs, "OOO|OOOi", kwlist, &traces, &offsets, &nodes,
&shift_range, &node_stack_max, &node_stack_max_idx, &n_threads)) {
return NULL;
}

Expand All @@ -369,12 +313,12 @@ static PyObject *delay_sum_reduce(PyObject *self, PyObject *args,
Trace *traces_list;
Node *nodes_list;
int32_t min_shift, max_shift;
if (!prepare(traces, offsets, shifts, weights, node_mask, &traces_list,
&nodes_list, &min_shift, &max_shift))
if (!prepare(nodes, traces, offsets, &nodes_list, &traces_list, &min_shift,
&max_shift))
return NULL;

npy_intp n_traces = PyList_Size(traces);
npy_intp n_nodes = PyArray_SHAPE((PyArrayObject *)shifts)[0];
npy_intp n_nodes = PyList_Size(nodes);
npy_intp stack_size = max_shift - min_shift;

if (shift_range != Py_None) {
Expand All @@ -401,8 +345,8 @@ static PyObject *delay_sum_reduce(PyObject *self, PyObject *args,
}

if (node_stack_max != Py_None && node_stack_max_idx != Py_None) {
if (!check_array_dtype((PyArrayObject *)node_stack_max, NPY_FLOAT32) ||
!check_array_dtype((PyArrayObject *)node_stack_max_idx, NPY_INT32) ||
if (!check_array(node_stack_max, NPY_FLOAT32) ||
!check_array(node_stack_max_idx, NPY_INT32) ||
PyArray_NDIM((PyArrayObject *)node_stack_max) != 1 ||
PyArray_NDIM((PyArrayObject *)node_stack_max_idx) != 1 ||
PyArray_SHAPE((PyArrayObject *)node_stack_max_idx)[0] != stack_size ||
Expand Down Expand Up @@ -506,12 +450,13 @@ static PyObject *delay_sum_reduce(PyObject *self, PyObject *args,
// for (; i < tile_size - (tile_size % LANE_WIDTH); i += LANE_WIDTH) {
// npy_intp res_idx = tile_start_idx + i;
// simde__m256 stack_vec = simde_mm256_loadu_ps(&tile_node_stack[i]);
// simde__m256 max_vec = simde_mm256_loadu_ps(&stack_max_data[res_idx]);
// simde__m256i max_mask = (simde__m256i)simde_mm256_cmp_ps(
// simde__m256 max_vec =
// simde_mm256_loadu_ps(&stack_max_data[res_idx]); simde__m256i
// max_mask = (simde__m256i)simde_mm256_cmp_ps(
// stack_vec, max_vec, SIMDE_CMP_GT_OQ);
// simde_mm256_maskstore_ps(&stack_max_data[res_idx], max_mask,
// stack_vec); simde_mm256_maskstore_epi32(&stack_max_idx_data[res_idx],
// max_mask,
// stack_vec);
// simde_mm256_maskstore_epi32(&stack_max_idx_data[res_idx], max_mask,
// node_vec);
// }
for (; i < tile_size; i++) {
Expand All @@ -536,29 +481,27 @@ static PyObject *delay_sum_reduce(PyObject *self, PyObject *args,
// stack_snapshot function
static PyObject *delay_sum_snapshot(PyObject *self, PyObject *args,
PyObject *kwargs) {
PyObject *traces, *offsets, *shifts, *weights, *node_mask, *shift_range;
node_mask = Py_None;
PyObject *traces, *offsets, *nodes, *shift_range;
shift_range = Py_None;
int32_t index;

static char *kwlist[] = {"traces", "offsets", "shifts", "weights",
"index", "shift_range", "node_mask", NULL};
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OOOOi|OO", kwlist, &traces,
&offsets, &shifts, &weights, &index,
&shift_range, &node_mask)) {
static char *kwlist[] = {"traces", "offsets", "nodes",
"index", "shift_range", NULL};
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OOOi|O", kwlist, &traces,
&offsets, &nodes, &index, &shift_range)) {
return NULL;
}

Trace *traces_list;
Node *nodes_list;
int32_t min_shift, max_shift;
if (!prepare(traces, offsets, shifts, weights, node_mask, &traces_list,
&nodes_list, &min_shift, &max_shift)) {
if (!prepare(nodes, traces, offsets, &nodes_list, &traces_list, &min_shift,
&max_shift)) {
return NULL;
}

npy_intp n_traces = PyList_Size(traces);
npy_intp n_nodes = PyArray_SHAPE((PyArrayObject *)shifts)[0];
npy_intp n_nodes = PyList_Size(nodes);
npy_intp stack_size = max_shift - min_shift;

if (shift_range != Py_None) {
Expand Down
Loading
Loading