Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ bytemuck_derive = { default-features = false, optional = true, version = "1.7.1"
bytes = { default-features = false, optional = true, version = "1.0" }
diesel = { default-features = false, optional = true, version = "2.2.3" }
ndarray = { default-features = false, optional = true, version = "0.15.6" }
num-traits = { default-features = false, features = ["i128"], version = "0.2" }
num-traits = { default-features = false, features = ["i128"], version = "0.2.18" }
postgres-types = { default-features = false, optional = true, version = "0.2" }
proptest = { default-features = false, optional = true, features = ["std"], version = "1.0" }
rand = { default-features = false, optional = true, version = "0.8" }
Expand Down
25 changes: 24 additions & 1 deletion src/arithmetic_impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

use crate::{decimal::CalculationResult, ops, Decimal};
use core::ops::{Add, Div, Mul, Rem, Sub};
use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedRem, CheckedSub, Inv};
use num_traits::{
CheckedAdd, CheckedDiv, CheckedMul, CheckedRem, CheckedSub, Inv, SaturatingAdd, SaturatingMul, SaturatingSub,
};

// Macros and `Decimal` implementations

Expand Down Expand Up @@ -198,6 +200,27 @@ impl CheckedRem for Decimal {
}
}

impl SaturatingAdd for Decimal {
#[inline]
fn saturating_add(&self, v: &Decimal) -> Decimal {
Decimal::saturating_add(*self, *v)
}
}

impl SaturatingSub for Decimal {
#[inline]
fn saturating_sub(&self, v: &Decimal) -> Decimal {
Decimal::saturating_sub(*self, *v)
}
}

impl SaturatingMul for Decimal {
#[inline]
fn saturating_mul(&self, v: &Decimal) -> Decimal {
Decimal::saturating_mul(*self, *v)
}
}

impl Inv for Decimal {
type Output = Self;

Expand Down
26 changes: 25 additions & 1 deletion src/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use diesel::{deserialize::FromSqlRow, expression::AsExpression, sql_types::Numer
#[allow(unused_imports)] // It's not actually dead code below, but the compiler thinks it is.
#[cfg(not(feature = "std"))]
use num_traits::float::FloatCore;
use num_traits::{FromPrimitive, Num, One, Signed, ToPrimitive, Zero};
use num_traits::{Bounded, ConstOne, ConstZero, FromPrimitive, Num, One, Signed, ToPrimitive, Zero};
#[cfg(feature = "rkyv")]
use rkyv::{Archive, Deserialize, Serialize};
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
Expand Down Expand Up @@ -2041,6 +2041,14 @@ impl One for Decimal {
}
}

impl ConstZero for Decimal {
const ZERO: Self = Decimal::ZERO;
}

impl ConstOne for Decimal {
const ONE: Self = Decimal::ONE;
}

impl Signed for Decimal {
fn abs(&self) -> Self {
self.abs()
Expand Down Expand Up @@ -2083,6 +2091,22 @@ impl Num for Decimal {
}
}

impl Bounded for Decimal {
fn min_value() -> Self {
Decimal::MIN
}

fn max_value() -> Self {
Decimal::MAX
}
}

impl num_traits::NumCast for Decimal {
fn from<T: ToPrimitive>(n: T) -> Option<Self> {
n.to_f64().and_then(<Decimal as FromPrimitive>::from_f64)

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The other contracts look good, but the implementation of this one I feel is problematic. For example:

        use num_traits::FromPrimitive;
        let n: u64 = 9_007_199_254_740_993;
        assert_eq!(
            <Decimal as n>::from(n),
            Some(Decimal::from_u64(n).expect("Parsed successfully"))
        );

Will fail with:

assertion `left == right` failed
  left: Some(9007199254740992)
 right: Some(9007199254740993)

This feels tricky to justify as a rounding error, given it's an integer.

Another one:

        use num_traits::FromPrimitive;
        let n: i128 = 79_228_162_514_264_337_593_543_950_000; // close to Decimal::MAX
        assert_eq!(
            <Decimal as NumCast>::from(n),
            Some(Decimal::from_i128(n).expect("Parsed successfully"))
        );

Will fail with:

assertion `left == right` failed
  left: None
 right: Some(79228162514264337593543950000)

My concern here is that you get subtle differences depending on how you parse the number, which I would consider as unexpected - given it's target is a Decimal. I think the latter example is validated somewhat in the docs for NumCast:

If the source value cannot be represented by the target type, then None is returned.

I actually think this interface is a really tough one to work with and not really made for numbers outside of the ToPrimitive set. The only solution I can think of that kind of meets in the middle is to use f64 but if there is no fract then fallback to wide integer arithmetic. It wouldn't be the "performant" path, but may reduce some confusion over conversion mistmatches.

e.g. (untested, but something like this)

  impl num_traits::NumCast for Decimal {
      fn from<T: ToPrimitive>(n: T) -> Option<Self> {
          let f = n.to_f64()?;
          if f.is_finite() && f.fract() == 0.0 {
              if let Some(i) = n.to_i128() { return Decimal::from_i128(i); }
              if let Some(u) = n.to_u128() { return Decimal::from_u128(u); }
          }
          Decimal::from_f64(f)
      }
  }

I think the other thing we need here is more test coverage just to make sure we're covering alternate angles.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good point.

I think we can drop the second u128 fallback. It should always return None since the extrema of i128 are already outside the bounds of what Decimal can represent.

Either way, I think that this resolves your round-trip concern. I'll update it and include a test suite to prove to myself that it actually does what I expect.

}
}

impl FromStr for Decimal {
type Err = Error;

Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ pub mod prelude {
pub use crate::maths::MathematicalOps;
pub use crate::{Decimal, RoundingStrategy};
pub use core::str::FromStr;
pub use num_traits::{FromPrimitive, One, Signed, ToPrimitive, Zero};
pub use num_traits::{Bounded, FromPrimitive, One, Signed, ToPrimitive, Zero};
}

#[cfg(feature = "macros")]
Expand Down
98 changes: 98 additions & 0 deletions tests/decimal_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4882,3 +4882,101 @@ mod issues {
assert_result(29, a, b);
}
}

mod num_traits_impls {
use core::str::FromStr;
use num_traits::{Bounded, ConstOne, ConstZero, NumCast, SaturatingAdd, SaturatingMul, SaturatingSub};
use rust_decimal::Decimal;

#[test]
fn bounded_min_max() {
assert_eq!(<Decimal as Bounded>::min_value(), Decimal::MIN);
assert_eq!(<Decimal as Bounded>::max_value(), Decimal::MAX);
}

#[test]
fn const_zero_and_one_usable_in_const_context() {
const Z: Decimal = <Decimal as ConstZero>::ZERO;
const O: Decimal = <Decimal as ConstOne>::ONE;
assert_eq!(Z, Decimal::ZERO);
assert_eq!(O, Decimal::ONE);
}

#[test]
fn bounded_in_prelude() {
use rust_decimal::prelude::*;
let _: Decimal = Bounded::min_value();
let _: Decimal = Bounded::max_value();
}

#[test]
fn num_cast_from_integers() {
assert_eq!(<Decimal as NumCast>::from(0i8), Some(Decimal::ZERO));
assert_eq!(<Decimal as NumCast>::from(-1i8), Some(Decimal::NEGATIVE_ONE));
assert_eq!(<Decimal as NumCast>::from(1i32), Some(Decimal::ONE));
assert_eq!(<Decimal as NumCast>::from(42u64), Decimal::from_str("42").ok());
assert_eq!(<Decimal as NumCast>::from(100i128), Decimal::from_str("100").ok());
}

#[test]
fn num_cast_from_floats() {
assert_eq!(<Decimal as NumCast>::from(0.0f32), Some(Decimal::ZERO));
assert_eq!(<Decimal as NumCast>::from(1.5f64), Decimal::from_str("1.5").ok());
assert_eq!(<Decimal as NumCast>::from(-2.25f64), Decimal::from_str("-2.25").ok());
}

#[test]
fn num_cast_rejects_non_finite_and_out_of_range() {
assert_eq!(<Decimal as NumCast>::from(f64::NAN), None);
assert_eq!(<Decimal as NumCast>::from(f64::INFINITY), None);
assert_eq!(<Decimal as NumCast>::from(f64::NEG_INFINITY), None);
// Decimal::MAX is ~7.9e28; 1e30 sits above it.
assert_eq!(<Decimal as NumCast>::from(1.0e30f64), None);
assert_eq!(<Decimal as NumCast>::from(-1.0e30f64), None);
}

#[test]
fn saturating_add_normal_and_overflow() {
let a = Decimal::from_str("1.5").unwrap();
let b = Decimal::from_str("2.5").unwrap();
assert_eq!(SaturatingAdd::saturating_add(&a, &b), Decimal::from_str("4.0").unwrap());
assert_eq!(
SaturatingAdd::saturating_add(&Decimal::MAX, &Decimal::ONE),
Decimal::MAX
);
assert_eq!(
SaturatingAdd::saturating_add(&Decimal::MIN, &Decimal::NEGATIVE_ONE),
Decimal::MIN
);
}

#[test]
fn saturating_sub_normal_and_overflow() {
let a = Decimal::from_str("5").unwrap();
let b = Decimal::from_str("2").unwrap();
assert_eq!(SaturatingSub::saturating_sub(&a, &b), Decimal::from_str("3").unwrap());
assert_eq!(
SaturatingSub::saturating_sub(&Decimal::MIN, &Decimal::ONE),
Decimal::MIN
);
assert_eq!(
SaturatingSub::saturating_sub(&Decimal::MAX, &Decimal::NEGATIVE_ONE),
Decimal::MAX
);
}

#[test]
fn saturating_mul_normal_and_overflow() {
let a = Decimal::from_str("3").unwrap();
let b = Decimal::from_str("4").unwrap();
assert_eq!(SaturatingMul::saturating_mul(&a, &b), Decimal::from_str("12").unwrap());
assert_eq!(
SaturatingMul::saturating_mul(&Decimal::MAX, &Decimal::TWO),
Decimal::MAX
);
assert_eq!(
SaturatingMul::saturating_mul(&Decimal::MIN, &Decimal::TWO),
Decimal::MIN
);
}
}