Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 72 additions & 1 deletion src/pycarl/typed_core/rational.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,67 @@
#include <carl/numbers/conversion/cln_gmp.h>
#include <carl/numbers/numbers.h>

struct PyFraction {
py::object obj;
};

namespace pybind11 {
namespace detail {
template<>
struct type_caster<PyFraction> {
PYBIND11_TYPE_CASTER(PyFraction, const_name("fractions.Fraction"));
bool load(handle src, bool) {
auto fractions = py::module_::import("fractions");
if (!py::isinstance(src, fractions.attr("Fraction")))
return false;
value.obj = py::reinterpret_borrow<py::object>(src);
return true;
}
static handle cast(PyFraction, return_value_policy, handle) {
return py::none().release();
}
};
} // namespace detail
} // namespace pybind11

#ifdef PYCARL_USE_CLN
static cln::cl_I pyint_to_cl_I(py::int_ val) {
bool negative = PyObject_RichCompareBool(val.ptr(), py::int_(0).ptr(), Py_LT) == 1;
py::int_ absval = negative ? py::reinterpret_steal<py::int_>(PyNumber_Negative(val.ptr())) : val;
int bit_len = absval.attr("bit_length")().cast<int>();
if (bit_len == 0)
return cln::cl_I(0);
std::size_t byte_len = (static_cast<std::size_t>(bit_len) + 7) / 8;
py::bytes raw = absval.attr("to_bytes")(py::int_(byte_len), py::str("big"));
const auto* data = reinterpret_cast<const unsigned char*>(PyBytes_AS_STRING(raw.ptr()));

cln::cl_I result(0);
std::size_t i = 0;
for (; i + 4 <= byte_len; i += 4) {
uint32_t chunk = ((uint32_t)data[i] << 24) | ((uint32_t)data[i + 1] << 16) | ((uint32_t)data[i + 2] << 8) | (uint32_t)data[i + 3];
result = cln::ash(result, 32) + cln::cl_I((unsigned int)chunk);
}
for (; i < byte_len; i++) result = cln::ash(result, 8) + cln::cl_I((unsigned int)data[i]);

return negative ? -result : result;
}
#endif

[[maybe_unused]] static mpz_class pyint_to_mpz(py::int_ val) {
bool negative = PyObject_RichCompareBool(val.ptr(), py::int_(0).ptr(), Py_LT) == 1;
py::int_ absval = negative ? py::reinterpret_steal<py::int_>(PyNumber_Negative(val.ptr())) : val;
int bit_len = absval.attr("bit_length")().cast<int>();
if (bit_len == 0)
return mpz_class(0);
std::size_t byte_len = (static_cast<std::size_t>(bit_len) + 7) / 8;
py::bytes raw = absval.attr("to_bytes")(py::int_(byte_len), py::str("big"));
mpz_class result;
mpz_import(result.get_mpz_t(), static_cast<std::size_t>(PyBytes_GET_SIZE(raw.ptr())), 1, 1, 0, 0, PyBytes_AS_STRING(raw.ptr()));
if (negative)
return -result;
return result;
}

void define_cln_rational(py::module& m) {
#ifdef PYCARL_USE_CLN
py::class_<cln::cl_RA>(m, "Rational", "Class wrapping cln-rational numbers")
Expand All @@ -24,6 +85,11 @@ void define_cln_rational(py::module& m) {
return tmp;
}))
.def(py::init(&carl::convert<mpq_class, cln::cl_RA>))
.def(py::init([](PyFraction frac) {
cln::cl_I num = pyint_to_cl_I(frac.obj.attr("numerator").cast<py::int_>());
cln::cl_I den = pyint_to_cl_I(frac.obj.attr("denominator").cast<py::int_>());
return num / den;
}))

.def("__add__", [](const cln::cl_RA& lhs, const cln::cl_RA& rhs) -> cln::cl_RA { return lhs + rhs; })
.def("__add__", [](const cln::cl_RA& lhs, carl::sint rhs) -> cln::cl_RA { return lhs + carl::rationalize<cln::cl_RA>(rhs); })
Expand Down Expand Up @@ -138,6 +204,11 @@ void define_gmp_rational(py::module& m) {
#ifdef PYCARL_HAS_CLN
.def(py::init(&carl::convert<cln::cl_RA, mpq_class>))
#endif
.def(py::init([](PyFraction frac) {
mpz_class num = pyint_to_mpz(frac.obj.attr("numerator").cast<py::int_>());
mpz_class den = pyint_to_mpz(frac.obj.attr("denominator").cast<py::int_>());
return mpq_class(num, den);
}))

.def("__add__", [](const mpq_class& lhs, const mpq_class& rhs) -> mpq_class { return lhs + rhs; })
.def("__add__", [](const mpq_class& lhs, carl::sint rhs) -> mpq_class { return lhs + carl::rationalize<mpq_class>(rhs); })
Expand Down Expand Up @@ -229,4 +300,4 @@ void define_gmp_rational(py::module& m) {

py::implicitly_convertible<carl::uint, mpq_class>();
#endif
}
}
46 changes: 46 additions & 0 deletions tests/pycarl/core/test_rational.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from stormpy import pycarl
import math
import fractions
from configurations import PackageSelector


Expand Down Expand Up @@ -54,6 +55,51 @@ def test_eq(self, package):
r2 = package.Rational("1/2")
assert r3 != r2

def test_from_fraction(self, package):
# zero
assert package.Rational(fractions.Fraction(0)) == package.Rational(0)

# one and minus one
assert package.Rational(fractions.Fraction(1)) == package.Rational(1)
assert package.Rational(fractions.Fraction(-1)) == package.Rational(-1)

# small positive
r = package.Rational(fractions.Fraction(1, 2))
assert package.numerator(r) == 1
assert package.denominator(r) == 2

# small negative
assert package.Rational(fractions.Fraction(-7, 3)) == package.Rational("-7/3")

# denominator 1 (integer-valued)
assert package.Rational(fractions.Fraction(42)) == package.Rational(42)

# fractions.Fraction reduces automatically, verify we preserve the reduced form
r = package.Rational(fractions.Fraction(6, 4)) # reduces to 3/2
assert package.numerator(r) == 3
assert package.denominator(r) == 2

# fits in 32 bits
assert package.Rational(fractions.Fraction(2**31 - 1, 2**31)) == package.Rational(str(2**31 - 1) + "/" + str(2**31))

# fits in 64 bits but not 32
assert package.Rational(fractions.Fraction(2**63 - 1, 2**63)) == package.Rational(str(2**63 - 1) + "/" + str(2**63))

# just beyond 64 bits
n, d = 2**65 + 1, 2**65 + 3
r = package.Rational(fractions.Fraction(n, d))
assert package.numerator(r) == package.Integer(str(n))
assert package.denominator(r) == package.Integer(str(d))

# large (beyond 128 bits), negative numerator
n, d = -(10**40 + 7), 10**40 + 9
r = package.Rational(fractions.Fraction(n, d))
assert package.numerator(r) == package.Integer(str(n))
assert package.denominator(r) == package.Integer(str(d))

# cross-check: round-trip via float for a simple fraction
assert abs(float(package.Rational(fractions.Fraction(1, 3))) - 1 / 3) < 1e-15

def test_comparison_infinity(self, package):
r4 = package.Rational("1/2")
assert pycarl.inf > r4
Expand Down
Loading