Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions docs/source/user_guide/flexible_tensors.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,40 @@ m = qd.tensor_mat(2, 2, qd.f32, shape=(3,))
These match the existing `qd.Vector.*` / `qd.Matrix.*` factories one-for-one;
`qd.tensor_vec` / `qd.tensor_mat` simply add the per-tensor `backend=` knob.

Subsequent releases will add a `qd.tensor_annotation(backend)` helper for
kernel argument typing, and a `layout=` keyword for per-tensor physical-memory
## Annotating kernel arguments: `qd.tensor_annotation`

Kernel parameter annotations differ between the two backends — fields use
`qd.template()` and ndarrays use `qd.types.ndarray()`. To avoid sprinkling
`if`/`else` blocks across every kernel signature, pick the annotation **once**
at module load time:

```python
import quadrants as qd

# Choose your run-wide backend in one place.
BACKEND = qd.Backend.NDARRAY
V_ANNOTATION = qd.tensor_annotation(BACKEND)

qd.init(arch=qd.x64)

@qd.kernel
def fill(x: V_ANNOTATION):
for i in range(x.shape[0]):
x[i] = i

a = qd.tensor(qd.i32, shape=(4,), backend=BACKEND)
fill(a)
```

The returned object is interchangeable with its direct equivalent:

| `backend` | `qd.tensor_annotation(backend)` returns | Equivalent to |
|---|---|---|
| `qd.Backend.FIELD` | `qd.template()` instance | `def k(x: qd.template()): ...` |
| `qd.Backend.NDARRAY` | `qd.types.ndarray()` instance | `def k(x: qd.types.ndarray()): ...` |

This mirrors the one-liner Genesis already uses to switch backends; the
helper just makes the pattern first-class.

Subsequent releases will add a `layout=` keyword for per-tensor physical-memory
layout.
1 change: 1 addition & 0 deletions python/quadrants/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __getattr__(attr):
"math",
"sparse",
"tensor",
"tensor_annotation",
"tensor_mat",
"tensor_vec",
"tools",
Expand Down
43 changes: 42 additions & 1 deletion python/quadrants/_flexible.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from enum import IntEnum

__all__ = ["Backend", "tensor", "tensor_mat", "tensor_vec"]
__all__ = ["Backend", "tensor", "tensor_annotation", "tensor_mat", "tensor_vec"]


class Backend(IntEnum):
Expand Down Expand Up @@ -150,3 +150,44 @@ def tensor_mat(n, m, dtype, shape, *, backend=Backend.FIELD, **kwargs):
if backend is Backend.NDARRAY:
return Matrix.ndarray(n, m, dtype, shape, **kwargs)
raise AssertionError(f"unhandled Backend member: {backend!r}")


def tensor_annotation(backend):
"""Return the kernel-argument annotation appropriate for ``backend``.

Mirrors the Genesis ``V_ANNOTATION = qd.types.ndarray() if use_ndarray
else qd.template`` pattern as a single first-class call. Use it once, at
module load time, to build a uniform annotation that you then attach to
every tensor kernel argument:

.. code-block:: python

V_ANNOTATION = qd.tensor_annotation(qd.Backend.FIELD)

@qd.kernel
def fill(x: V_ANNOTATION):
for i in qd.ndrange(x.shape[0]):
x[i] = 1.0

Args:
backend (Backend): The backend whose tensors will be passed to
kernels annotated with the returned object.

Returns:
An object suitable for use as a kernel-argument type annotation:

- For ``Backend.FIELD``: an instance of ``qd.template()``.
- For ``Backend.NDARRAY``: an instance of ``qd.types.ndarray()``.

Both forms are interchangeable with their direct equivalents — the
helper just hides the conditional behind one call.
"""
backend = _coerce_backend(backend)
from quadrants import types as _types # late import
from quadrants.types.annotations import template # late import

if backend is Backend.FIELD:
return template()
if backend is Backend.NDARRAY:
return _types.ndarray()
raise AssertionError(f"unhandled Backend member: {backend!r}")
59 changes: 59 additions & 0 deletions tests/python/test_flexible_annotation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Tests for ``qd.tensor_annotation`` (PR 4)."""

import pytest

import quadrants as qd

from tests import test_utils


def test_tensor_annotation_field_returns_template_instance():
ann = qd.tensor_annotation(qd.Backend.FIELD)
assert isinstance(ann, qd.types.annotations.template)


def test_tensor_annotation_ndarray_returns_ndarray_type():
ann = qd.tensor_annotation(qd.Backend.NDARRAY)
direct = qd.types.ndarray()
assert type(ann) is type(direct)


def test_tensor_annotation_invalid_backend_raises():
with pytest.raises(ValueError, match="backend="):
qd.tensor_annotation(99)


def test_tensor_annotation_int_value_accepted():
"""``qd.tensor_annotation(0)`` and ``(1)`` work via IntEnum coercion."""
ann_field = qd.tensor_annotation(0)
ann_ndarray = qd.tensor_annotation(1)
assert isinstance(ann_field, qd.types.annotations.template)
assert type(ann_ndarray) is type(qd.types.ndarray())


@test_utils.test(arch=qd.cpu)
def test_tensor_annotation_field_drives_kernel():
V_ANNOTATION = qd.tensor_annotation(qd.Backend.FIELD)

@qd.kernel
def fill(x: V_ANNOTATION):
for i in range(4):
x[i] = i + 10

a = qd.tensor(qd.i32, shape=(4,))
fill(a)
assert list(a.to_numpy()) == [10, 11, 12, 13]


@test_utils.test(arch=qd.cpu)
def test_tensor_annotation_ndarray_drives_kernel():
V_ANNOTATION = qd.tensor_annotation(qd.Backend.NDARRAY)

@qd.kernel
def fill(x: V_ANNOTATION):
for i in range(4):
x[i] = i + 100

a = qd.tensor(qd.i32, shape=(4,), backend=qd.Backend.NDARRAY)
fill(a)
assert list(a.to_numpy()) == [100, 101, 102, 103]
Loading