From 04935aff82e5a0d3682a2ddd7e7d10a1dc1d8561 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Charlotte=20=F0=9F=A6=9D=20Delenk?= Date: Thu, 3 Oct 2024 14:34:52 +0200 Subject: [PATCH] replace division algorithm --- .gitignore | 3 +- src/lib.rs | 1 + src/u256/mod.rs | 346 +++++++++++++++++++++++++++++++++++++++--------- 3 files changed, 283 insertions(+), 67 deletions(-) diff --git a/.gitignore b/.gitignore index 2e1e3ed..d2f38eb 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ /target /.direnv /result -/resul-bin \ No newline at end of file +/resul-bin +.vscode/ \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index eacd9e1..9258714 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ #![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/README.md"))] #![no_std] +#[cfg(feature = "u256")] pub mod u256; diff --git a/src/u256/mod.rs b/src/u256/mod.rs index 50af961..fbf71c3 100644 --- a/src/u256/mod.rs +++ b/src/u256/mod.rs @@ -4,14 +4,30 @@ use core::{ cmp::Ordering, fmt::{self, Binary, Display, Formatter, LowerHex, Octal}, iter::{Product, Sum}, - num::{IntErrorKind, ParseIntError}, + num::IntErrorKind, ops::{ Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, - DivAssign, Mul, MulAssign, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, + DivAssign, Mul, MulAssign, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, + SubAssign, }, str::{self, FromStr}, }; +/// Creates a U256 literal from a string. +#[macro_export] +macro_rules! u256 { + ($lit:literal) => { + u256!(10, $lit) + }; + ($base:literal, $lit:literal) => {{ + const CONSTANT: $crate::u256::U256 = match $crate::u256::U256::from_str_radix($lit, $base) { + Ok(n) => n, + Err(_) => panic!(concat!("Invalid integer constant: ", $lit)), + }; + CONSTANT + }}; +} + /// The 256-bit unsigned integer type. #[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Hash)] pub struct U256(pub u128, pub u128); @@ -26,7 +42,7 @@ impl U256 { /// # use extra_math::u256::U256; /// assert_eq!(U256::MIN, 0); /// ``` - pub const MIN: Self = Self(0, 0); + pub const MIN: Self = u256!("0"); /// The largest value that can be represented by this integer type (2^256 − 1). pub const MAX: Self = Self(u128::MAX, u128::MAX); /// The size of this integer type in bits. @@ -44,8 +60,8 @@ impl U256 { /// Basic usage: /// /// ``` - /// # use extra_math::u256::U256; - /// assert_eq!(U256::from(0b01001100u128).count_ones(), 3); + /// # use extra_math::{u256, u256::U256}; + /// assert_eq!(u256!(2, "01001100").count_ones(), 3); /// ``` pub const fn count_ones(self) -> u32 { self.0.count_ones() + self.1.count_ones() @@ -72,7 +88,7 @@ impl U256 { /// /// ``` /// # use extra_math::u256::U256; - /// let n = U256::MAX >> 2; + /// let n = U256::MAX >> 2u8; /// /// assert_eq!(n.leading_zeros(), 2); /// ``` @@ -91,8 +107,8 @@ impl U256 { /// Basic usage: /// /// ``` - /// # use extra_math::u256::U256; - /// let n = U256::from(0b01001100u128); + /// # use extra_math::{u256, u256::U256}; + /// let n = u256!(2, "1001100"); /// /// assert_eq!(n.trailing_zeros(), 2); /// ``` @@ -112,7 +128,7 @@ impl U256 { /// /// ``` /// # use extra_math::u256::U256; - /// let n = !(U256::MAX >> 2); + /// let n = !(U256::MAX >> 2u32); /// /// assert_eq!(n.leading_ones(), 2); /// ``` @@ -131,8 +147,8 @@ impl U256 { /// Basic usage: /// /// ``` - /// # use extra_math::u256::U256; - /// let n = U256::from(0b1010111u128); + /// # use extra_math::{u256, u256::U256}; + /// let n = u256!(2, "1010111"); /// /// assert_eq!(n.trailing_ones(), 3); /// ``` @@ -153,9 +169,9 @@ impl U256 { /// Basic usage: /// /// ``` - /// # use extra_math::u256::U256; + /// # use extra_math::{u256, u256::U256}; /// - /// let n = U256(0x13f40000000000000000000000000000, 0x4f76); + /// let n = u256!(16, "13f4_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_4f76"); /// let m = U256::from(0x4f7613f4u32); /// /// assert_eq!(n.rotate_left(16), m); @@ -403,7 +419,7 @@ impl U256 { /// ``` /// # use extra_math::u256::U256; /// assert_eq!(U256::from(5u8).checked_mul(U256::from(1u8)), Some(U256::from(5u8))); - /// assert_eq!(U256::MAX.checked_sub(U256::from(2u8)), None) + /// assert_eq!(U256::MAX.checked_mul(U256::from(2u8)), None) /// ``` pub const fn checked_mul(self, rhs: Self) -> Option { match self.overflowing_mul(rhs) { @@ -437,50 +453,125 @@ impl U256 { } } + const fn div_rem_256_by_128_to_128_default(u: Self, mut v: u128) -> (u128, u128) { + const N_UDWORD_BITS: u32 = u128::BITS; + const B: u128 = 1 << (N_UDWORD_BITS / 2); // number base + let un1; + let un0; + let vn1; + let vn0; + let mut q1; + let mut q0: u128; + let un64; + let un21: u128; + let un10; + let mut rhat; + let s; + + s = v.leading_zeros(); + if s > 0 { + // Normalize the divisor + v <<= s; + un64 = (u.0 << s) | (u.1 >> (N_UDWORD_BITS - s)); + un10 = u.1 << s; + } else { + // Avoid undefined behavior of (u0 >> 64) + un64 = u.1; + un10 = u.0; + } + + // break divisor up into two 64 bit digits. + vn1 = v >> (N_UDWORD_BITS / 2); + vn0 = (v as u64) as u128; + + // Break right half of dividend into two digits. + un1 = un10 >> (N_UDWORD_BITS / 2); + un0 = (un10 as u64) as u128; + + // Compute the first quotient digit, q1. + q1 = un64 / vn1; + rhat = un64 - q1 * vn1; + + // q1 has at most error 2. No more than 2 iterations. + while q1 >= B || q1 * vn0 > B.wrapping_mul(rhat) + un1 { + q1 -= 1; + rhat += vn1; + if rhat >= B { + break; + } + } + + un21 = un64 + .wrapping_mul(B) + .wrapping_add(un1) + .wrapping_sub(q1.wrapping_mul(v)); + + // Compute the second quotient digit + q0 = un21 / vn1; + rhat = un21 - q0 * vn1; + + // q0 has at most error 2. No more than 2 iterations. + while q0 >= B || q0 * vn0 > B.wrapping_mul(rhat) + un0 { + q0 -= 1; + rhat += vn1; + if rhat >= B { + break; + } + } + + let r = (un21 + .wrapping_mul(B) + .wrapping_add(un0) + .wrapping_sub(q0.wrapping_mul(v))) + >> s; + let q = q1.wrapping_mul(B) + q0; + (q, r) + } + /// Checked integer divrem. Computes `(self / rhs, self % rhs)`, returning `None` if `rhs == 0` - pub const fn checked_div_rem(mut self, rhs: Self) -> Option<(Self, Self)> { + pub const fn checked_div_rem(mut self, mut rhs: Self) -> Option<(Self, Self)> { if rhs.eq(Self(0, 0)) { return None; } + // ported from https://github.com/llvm/llvm-project/blob/main/compiler-rt/lib/builtins/udivmodti4.c if rhs.gt(self) { - // divisor > numerator - return Some((Self(0, 0), rhs)); - } - if rhs.eq(Self(0, 1)) { - // divisor == 1 - return Some((self, Self(0, 0))); + return Some((Self(0, 0), self)); } - if self.lt(Self(1, 0)) && rhs.lt(Self(1, 0)) { - let div = match self.1.checked_div(rhs.1) { - Some(v) => v, - None => return None, - }; - let rem = match self.1.checked_rem(rhs.1) { - Some(v) => v, - None => return None, - }; - return Some((Self(0, div), Self(0, rem))); - } - - let mut bits = rhs.leading_zeros() - self.leading_zeros() + 1; - let mut rem = self.wrapping_shr(bits); - self = self.wrapping_shl(Self::BITS - bits); - let mut wrap = Self(0, 0); - while bits > 0 { - bits -= 1; - rem = (rem.wrapping_shl(1)).bit_or(self.wrapping_shr(Self::BITS - 1)); - self = (self.wrapping_shl(1)).bit_or(wrap); - if rhs.gt(rem) { - wrap = Self(0, 0); + if rhs.0 == 0 { + let mut remainder = Self(0, 0); + let mut quotient = Self(0, 0); + if self.0 < rhs.1 { + // The result fits in 64 bits. + (quotient.1, remainder.1) = Self::div_rem_256_by_128_to_128_default(self, rhs.1); } else { - wrap = Self(0, 1); - rem = rem.wrapping_sub(rhs); + quotient.0 = self.0 / rhs.1; + self.0 = self.0 % rhs.1; + (quotient.1, remainder.1) = Self::div_rem_256_by_128_to_128_default(self, rhs.1); } + return Some((quotient, remainder)); } - Some((self.wrapping_shl(1).bit_or(wrap), rem)) + let mut shift = (rhs.0.leading_zeros() - self.0.leading_zeros()) as i32; + rhs = rhs.wrapping_shr(shift as u32); + let mut quotient = Self(0, 0); + while shift >= 0 { + quotient.1 <<= 1; + + let carry = if self.ge(rhs) { + self = self.wrapping_sub(rhs); + 1 + } else { + 0 + }; + + quotient.1 |= carry; + rhs = rhs.wrapping_shr(1); + shift -= 1; + } + + Some((quotient, self)) } /// Checked integer division. Computes self / rhs, returning None if rhs == 0. @@ -643,7 +734,7 @@ impl U256 { /// /// ``` /// # use extra_math::u256::U256; - /// assert_eq!(U256::from(5u8).strict_rem_euclid(U256::from(2u8)), U256::from(uu8)); + /// assert_eq!(U256::from(5u8).strict_rem_euclid(U256::from(2u8)), U256::from(1u8)); /// ``` /// /// The following function panics because of division by zero: @@ -854,6 +945,7 @@ impl U256 { /// /// This following panics because of overflow: /// ```should_panic + /// # use extra_math::u256::U256; /// let _ = U256::from(0x10u8).strict_shl(257); /// ``` pub const fn strict_shl(self, rhs: u32) -> Self { @@ -910,11 +1002,12 @@ impl U256 { /// /// ``` /// # use extra_math::u256::U256; - /// assert_eq!(U256::from(0x10u8).strict_shr(4), U256::from(0x11u8)); + /// assert_eq!(U256::from(0x10u8).strict_shr(4), U256::from(0x1u8)); /// ``` /// /// This following panics because of overflow: /// ```should_panic + /// # use extra_math::u256::U256; /// let _ = U256::from(0x10u8).strict_shr(257); /// ``` pub const fn strict_shr(self, rhs: u32) -> Self { @@ -952,7 +1045,7 @@ impl U256 { /// /// ``` /// # use extra_math::u256::U256; - /// assert_eq!(U256::from(2u8).checked_pow(5), Some(U256::from(25u8))); + /// assert_eq!(U256::from(2u8).checked_pow(5), Some(U256::from(32u8))); /// assert_eq!(U256::MAX.checked_pow(2), None); /// ``` pub const fn checked_pow(self, exp: u32) -> Option { @@ -977,14 +1070,14 @@ impl U256 { /// /// ``` /// # use extra_math::u256::U256; - /// assert_eq!(U256::from(2u8).strict_pow(5), U256::from(25u8)); + /// assert_eq!(U256::from(2u8).strict_pow(5), U256::from(32u8)); /// ``` /// /// The following panics because of overflow: /// /// ```should_panic /// # use extra_math::u256::U256; - /// let _ = Self::MAX.strict_pow(2); + /// let _ = U256::MAX.strict_pow(2); /// ``` pub const fn strict_pow(self, exp: u32) -> Self { match self.checked_pow(exp) { @@ -1202,8 +1295,8 @@ impl U256 { /// # use extra_math::u256::U256; /// assert_eq!(U256::from(0u8).wrapping_neg(), U256::from(0u8)); /// assert_eq!(U256::MAX.wrapping_neg(), U256::from(1u8)); - /// assert_eq!(U256::from(13u8).wrapping_neg, !U256::from(13u8) + U256::from(1u8)); - /// assert_eq!(U256::from(42u8).wrapping_neg, !(U256::from(42u8) - U256::from(1u8))); + /// assert_eq!(U256::from(13u8).wrapping_neg(), !U256::from(13u8) + U256::from(1u8)); + /// assert_eq!(U256::from(42u8).wrapping_neg(), !(U256::from(42u8) - U256::from(1u8))); /// ``` pub const fn wrapping_neg(self) -> Self { self.overflowing_neg().0 @@ -1218,7 +1311,7 @@ impl U256 { /// /// ``` /// # use extra_math::u256::U256; - /// assert_eq!(U256::from(1u8).wrapping_shl(7), U256::from(127u8)); + /// assert_eq!(U256::from(1u8).wrapping_shl(7), U256::from(128u8)); /// assert_eq!(U256::from(1u8).wrapping_shl(256), U256::from(1u8)); /// ``` pub const fn wrapping_shl(self, n: u32) -> Self { @@ -1234,7 +1327,7 @@ impl U256 { /// /// ``` /// # use extra_math::u256::U256; - /// assert_eq!(U256::from(127u8).wrapping_shr(7), U256::from(1u8)); + /// assert_eq!(U256::from(127u8).wrapping_shr(7), U256::from(0u8)); /// assert_eq!(U256::from(127u8).wrapping_shl(256), U256::from(127u8)); /// ``` pub const fn wrapping_shr(self, n: u32) -> Self { @@ -1387,6 +1480,11 @@ impl U256 { self.0 > rhs.0 || self.1 > rhs.1 } + /// Constant version of `self >= rhs` + pub const fn ge(self, rhs: Self) -> bool { + self.gt(rhs) || self.eq(rhs) + } + /// Constant version of `self < rhs` pub const fn lt(self, rhs: Self) -> bool { self.0 < rhs.0 || self.1 < rhs.1 @@ -1417,7 +1515,6 @@ impl U256 { mul_high_high .wrapping_add(mul_high_low) .wrapping_add(mul_low_high) - .wrapping_add(mul_high_low) .wrapping_add(mul_low_low) } @@ -1434,7 +1531,7 @@ impl U256 { /// ``` /// # use extra_math::u256::U256; /// assert_eq!(U256::from(5u8).overflowing_mul(U256::from(2u8)), (U256::from(10u8), false)); - /// assert_eq!(U256::MAX.overflowing_mul(U256::from(2u8)), (U256::MAX - U256::from(1), true)); + /// assert_eq!(U256::MAX.overflowing_mul(U256::from(2u8)), (U256::MAX - U256::from(1u8), true)); /// ``` pub const fn overflowing_mul(self, rhs: Self) -> (Self, bool) { let (l, h) = self.widening_mul(rhs); @@ -1544,11 +1641,16 @@ impl U256 { /// ``` pub const fn overflowing_shl(self, mut n: u32) -> (Self, bool) { let overflow = n >= Self::BITS; + n &= Self::BITS - 1; + if n >= (Self::BITS / 2) { + return (Self(self.1.wrapping_shl(n), 0), overflow); + } + let mut hi = self.0.wrapping_shl(n); - hi |= self.1.wrapping_shr(256 - n); + hi |= self.1.wrapping_shr(127 - n); (Self(hi, self.1.wrapping_shl(n)), overflow) } @@ -1567,10 +1669,17 @@ impl U256 { /// ``` pub const fn overflowing_shr(self, mut n: u32) -> (Self, bool) { let overflow = n >= Self::BITS; + n &= Self::BITS - 1; + if n >= (Self::BITS / 2) { + return (Self(0, self.0.wrapping_shr(n)), overflow); + } + let mut lo = self.1.wrapping_shr(n); - lo |= self.0.wrapping_shl(256 - n); + + lo |= self.0.wrapping_shl(127 - n); + (Self(self.0.wrapping_shr(n), lo), overflow) } @@ -2029,7 +2138,7 @@ impl U256 { /// ``` /// # use extra_math::u256::U256; /// assert_eq!(U256::from(5u8).widening_mul(U256::from(2u8)), (U256::from(10u8), U256::from(0u8))); - /// assert_eq!(U256::MAX.widening_mul(U256::from(2u8)), (U256::MAX - U256::from(1), U256::from(1u8))); + /// assert_eq!(U256::MAX.widening_mul(U256::from(2u8)), (U256::MAX - U256::from(1u8), U256::from(1u8))); /// ``` pub const fn widening_mul(self, rhs: Self) -> (Self, Self) { let mul_high_high = Self::upcasting_mul128(self.0, rhs.0); @@ -2113,6 +2222,11 @@ impl U256 { let mut acc = Self(0, 0); let mut i = 0; while i < digits.len() { + if digits[i] == b'_' { + // Skip _ separators + i += 1; + continue; + } acc = match acc.checked_mul(Self(0, radix as u128)) { Some(n) => n, None => return Err(PosOverflow), @@ -2181,7 +2295,7 @@ impl U256 { } /// Formats the integer in a given base - fn fmt_base(self, f: &mut Formatter<'_>, base: u32, uppercase: bool) -> fmt::Result { + fn fmt_base(mut self, f: &mut Formatter<'_>, base: u32, uppercase: bool) -> fmt::Result { let char_table = if uppercase { [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', @@ -2217,12 +2331,13 @@ impl U256 { let mut curr = buf.len(); for byte in buf.iter_mut().rev() { - let (n, x) = self.checked_div_rem(Self::from(base)).unwrap_or_default(); + let (x, n) = self.checked_div_rem(Self::from(base)).unwrap_or_default(); *byte = char_table[n.1 as usize] as u8; curr -= 1; if x == Self(0, 0) { break; } + self = x; } let buf = &buf[curr..]; @@ -2636,6 +2751,54 @@ impl Sum for U256 { } } +impl Sub<&U256> for &U256 { + type Output = ::Output; + + fn sub(self, rhs: &U256) -> Self::Output { + *self - *rhs + } +} + +impl Sub<&U256> for U256 { + type Output = ::Output; + + fn sub(self, rhs: &U256) -> Self::Output { + self - *rhs + } +} + +impl Sub for &'_ U256 { + type Output = ::Output; + + fn sub(self, rhs: U256) -> Self::Output { + *self - rhs + } +} + +impl Sub for U256 { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + if cfg!(overflow_checks_stable) { + self.strict_sub(rhs) + } else { + self.wrapping_sub(rhs) + } + } +} + +impl SubAssign<&U256> for U256 { + fn sub_assign(&mut self, rhs: &U256) { + *self = *self - rhs; + } +} + +impl SubAssign for U256 { + fn sub_assign(&mut self, rhs: U256) { + *self = *self - rhs; + } +} + impl<'a> Product<&'a U256> for U256 { fn product>(iter: I) -> Self { iter.fold(Self(0, 1), |a, b| a * b) @@ -2802,7 +2965,15 @@ impl Shl for U256 { type Output = ::Output; fn shl(self, rhs: u16) -> Self::Output { - self << u32::from(rhs) + if cfg!(overflow_checks_stable) { + if rhs as u32 >= Self::BITS { + panic!("Tried to shift with overflow: {self} << {rhs}"); + } else { + self.strict_shl(rhs as u32) + } + } else { + self.wrapping_shl(rhs as u32) + } } } @@ -2818,7 +2989,15 @@ impl Shl for U256 { type Output = ::Output; fn shl(self, rhs: u8) -> Self::Output { - self << u32::from(rhs) + if cfg!(overflow_checks_stable) { + if rhs as u32 >= Self::BITS { + panic!("Tried to shift with overflow: {self} << {rhs}"); + } else { + self.strict_shl(rhs as u32) + } + } else { + self.wrapping_shl(rhs as u32) + } } } @@ -3002,7 +3181,15 @@ impl Shr for U256 { type Output = ::Output; fn shr(self, rhs: u16) -> Self::Output { - self >> U256::from(rhs) + if cfg!(overflow_checks_stable) { + if rhs as u32 >= Self::BITS { + panic!("Tried to shift with overflow: {self} << {rhs}"); + } else { + self.strict_shr(rhs as u32) + } + } else { + self.wrapping_shr(rhs as u32) + } } } @@ -3018,7 +3205,15 @@ impl Shr for U256 { type Output = ::Output; fn shr(self, rhs: u8) -> Self::Output { - self >> U256::from(rhs) + if cfg!(overflow_checks_stable) { + if rhs as u32 >= Self::BITS { + panic!("Tried to shift with overflow: {self} << {rhs}"); + } else { + self.strict_shr(rhs as u32) + } + } else { + self.wrapping_shr(rhs as u32) + } } } @@ -3101,3 +3296,22 @@ impl<'a> ShrAssign<&'a u8> for U256 { *self = *self >> rhs; } } + +#[cfg(test)] +pub mod tests { + use super::U256; + extern crate std; + + #[test] + pub fn check_number_parsing() { + let x = U256::from_str_radix( + "13f4_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_4f76", + 16, + ) + .unwrap(); + assert_eq!( + std::format!("{x}"), + "9025054806887987155511483633794827778846787745240661638085164476202701836150" + ); + } +}