WIP: rewrite #7

Draft
darkkirb wants to merge 3 commits from rewrite into main
3 changed files with 283 additions and 67 deletions
Showing only changes of commit 04935aff82 - Show all commits

1
.gitignore vendored
View file

@ -2,3 +2,4 @@
/.direnv
/result
/resul-bin
.vscode/

View file

@ -1,4 +1,5 @@
#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/README.md"))]
#![no_std]
#[cfg(feature = "u256")]
pub mod u256;

View file

@ -4,14 +4,30 @@ use core::{
cmp::Ordering,
fmt::{self, Binary, Display, Formatter, LowerHex, Octal},
iter::{Product, Sum},
num::{IntErrorKind, ParseIntError},
num::IntErrorKind,
ops::{
Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div,
DivAssign, Mul, MulAssign, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign,
DivAssign, Mul, MulAssign, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub,
SubAssign,
},
str::{self, FromStr},
};
/// Creates a U256 literal from a string.
#[macro_export]
macro_rules! u256 {
($lit:literal) => {
u256!(10, $lit)
};
($base:literal, $lit:literal) => {{
const CONSTANT: $crate::u256::U256 = match $crate::u256::U256::from_str_radix($lit, $base) {
Ok(n) => n,
Err(_) => panic!(concat!("Invalid integer constant: ", $lit)),
};
CONSTANT
}};
}
/// The 256-bit unsigned integer type.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Hash)]
pub struct U256(pub u128, pub u128);
@ -26,7 +42,7 @@ impl U256 {
/// # use extra_math::u256::U256;
/// assert_eq!(U256::MIN, 0);
/// ```
pub const MIN: Self = Self(0, 0);
pub const MIN: Self = u256!("0");
/// The largest value that can be represented by this integer type (2^256 1).
pub const MAX: Self = Self(u128::MAX, u128::MAX);
/// The size of this integer type in bits.
@ -44,8 +60,8 @@ impl U256 {
/// Basic usage:
///
/// ```
/// # use extra_math::u256::U256;
/// assert_eq!(U256::from(0b01001100u128).count_ones(), 3);
/// # use extra_math::{u256, u256::U256};
/// assert_eq!(u256!(2, "01001100").count_ones(), 3);
/// ```
pub const fn count_ones(self) -> u32 {
self.0.count_ones() + self.1.count_ones()
@ -72,7 +88,7 @@ impl U256 {
///
/// ```
/// # use extra_math::u256::U256;
/// let n = U256::MAX >> 2;
/// let n = U256::MAX >> 2u8;
///
/// assert_eq!(n.leading_zeros(), 2);
/// ```
@ -91,8 +107,8 @@ impl U256 {
/// Basic usage:
///
/// ```
/// # use extra_math::u256::U256;
/// let n = U256::from(0b01001100u128);
/// # use extra_math::{u256, u256::U256};
/// let n = u256!(2, "1001100");
///
/// assert_eq!(n.trailing_zeros(), 2);
/// ```
@ -112,7 +128,7 @@ impl U256 {
///
/// ```
/// # use extra_math::u256::U256;
/// let n = !(U256::MAX >> 2);
/// let n = !(U256::MAX >> 2u32);
///
/// assert_eq!(n.leading_ones(), 2);
/// ```
@ -131,8 +147,8 @@ impl U256 {
/// Basic usage:
///
/// ```
/// # use extra_math::u256::U256;
/// let n = U256::from(0b1010111u128);
/// # use extra_math::{u256, u256::U256};
/// let n = u256!(2, "1010111");
///
/// assert_eq!(n.trailing_ones(), 3);
/// ```
@ -153,9 +169,9 @@ impl U256 {
/// Basic usage:
///
/// ```
/// # use extra_math::u256::U256;
/// # use extra_math::{u256, u256::U256};
///
/// let n = U256(0x13f40000000000000000000000000000, 0x4f76);
/// let n = u256!(16, "13f4_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_4f76");
/// let m = U256::from(0x4f7613f4u32);
///
/// assert_eq!(n.rotate_left(16), m);
@ -403,7 +419,7 @@ impl U256 {
/// ```
/// # use extra_math::u256::U256;
/// assert_eq!(U256::from(5u8).checked_mul(U256::from(1u8)), Some(U256::from(5u8)));
/// assert_eq!(U256::MAX.checked_sub(U256::from(2u8)), None)
/// assert_eq!(U256::MAX.checked_mul(U256::from(2u8)), None)
/// ```
pub const fn checked_mul(self, rhs: Self) -> Option<Self> {
match self.overflowing_mul(rhs) {
@ -437,50 +453,125 @@ impl U256 {
}
}
const fn div_rem_256_by_128_to_128_default(u: Self, mut v: u128) -> (u128, u128) {
const N_UDWORD_BITS: u32 = u128::BITS;
const B: u128 = 1 << (N_UDWORD_BITS / 2); // number base
let un1;
let un0;
let vn1;
let vn0;
let mut q1;
let mut q0: u128;
let un64;
let un21: u128;
let un10;
let mut rhat;
let s;
s = v.leading_zeros();
if s > 0 {
// Normalize the divisor
v <<= s;
un64 = (u.0 << s) | (u.1 >> (N_UDWORD_BITS - s));
un10 = u.1 << s;
} else {
// Avoid undefined behavior of (u0 >> 64)
un64 = u.1;
un10 = u.0;
}
// break divisor up into two 64 bit digits.
vn1 = v >> (N_UDWORD_BITS / 2);
vn0 = (v as u64) as u128;
// Break right half of dividend into two digits.
un1 = un10 >> (N_UDWORD_BITS / 2);
un0 = (un10 as u64) as u128;
// Compute the first quotient digit, q1.
q1 = un64 / vn1;
rhat = un64 - q1 * vn1;
// q1 has at most error 2. No more than 2 iterations.
while q1 >= B || q1 * vn0 > B.wrapping_mul(rhat) + un1 {
q1 -= 1;
rhat += vn1;
if rhat >= B {
break;
}
}
un21 = un64
.wrapping_mul(B)
.wrapping_add(un1)
.wrapping_sub(q1.wrapping_mul(v));
// Compute the second quotient digit
q0 = un21 / vn1;
rhat = un21 - q0 * vn1;
// q0 has at most error 2. No more than 2 iterations.
while q0 >= B || q0 * vn0 > B.wrapping_mul(rhat) + un0 {
q0 -= 1;
rhat += vn1;
if rhat >= B {
break;
}
}
let r = (un21
.wrapping_mul(B)
.wrapping_add(un0)
.wrapping_sub(q0.wrapping_mul(v)))
>> s;
let q = q1.wrapping_mul(B) + q0;
(q, r)
}
/// Checked integer divrem. Computes `(self / rhs, self % rhs)`, returning `None` if `rhs == 0`
pub const fn checked_div_rem(mut self, rhs: Self) -> Option<(Self, Self)> {
pub const fn checked_div_rem(mut self, mut rhs: Self) -> Option<(Self, Self)> {
if rhs.eq(Self(0, 0)) {
return None;
}
// ported from https://github.com/llvm/llvm-project/blob/main/compiler-rt/lib/builtins/udivmodti4.c
if rhs.gt(self) {
// divisor > numerator
return Some((Self(0, 0), rhs));
}
if rhs.eq(Self(0, 1)) {
// divisor == 1
return Some((self, Self(0, 0)));
return Some((Self(0, 0), self));
}
if self.lt(Self(1, 0)) && rhs.lt(Self(1, 0)) {
let div = match self.1.checked_div(rhs.1) {
Some(v) => v,
None => return None,
};
let rem = match self.1.checked_rem(rhs.1) {
Some(v) => v,
None => return None,
};
return Some((Self(0, div), Self(0, rem)));
}
let mut bits = rhs.leading_zeros() - self.leading_zeros() + 1;
let mut rem = self.wrapping_shr(bits);
self = self.wrapping_shl(Self::BITS - bits);
let mut wrap = Self(0, 0);
while bits > 0 {
bits -= 1;
rem = (rem.wrapping_shl(1)).bit_or(self.wrapping_shr(Self::BITS - 1));
self = (self.wrapping_shl(1)).bit_or(wrap);
if rhs.gt(rem) {
wrap = Self(0, 0);
if rhs.0 == 0 {
let mut remainder = Self(0, 0);
let mut quotient = Self(0, 0);
if self.0 < rhs.1 {
// The result fits in 64 bits.
(quotient.1, remainder.1) = Self::div_rem_256_by_128_to_128_default(self, rhs.1);
} else {
wrap = Self(0, 1);
rem = rem.wrapping_sub(rhs);
quotient.0 = self.0 / rhs.1;
self.0 = self.0 % rhs.1;
(quotient.1, remainder.1) = Self::div_rem_256_by_128_to_128_default(self, rhs.1);
}
return Some((quotient, remainder));
}
Some((self.wrapping_shl(1).bit_or(wrap), rem))
let mut shift = (rhs.0.leading_zeros() - self.0.leading_zeros()) as i32;
rhs = rhs.wrapping_shr(shift as u32);
let mut quotient = Self(0, 0);
while shift >= 0 {
quotient.1 <<= 1;
let carry = if self.ge(rhs) {
self = self.wrapping_sub(rhs);
1
} else {
0
};
quotient.1 |= carry;
rhs = rhs.wrapping_shr(1);
shift -= 1;
}
Some((quotient, self))
}
/// Checked integer division. Computes self / rhs, returning None if rhs == 0.
@ -643,7 +734,7 @@ impl U256 {
///
/// ```
/// # use extra_math::u256::U256;
/// assert_eq!(U256::from(5u8).strict_rem_euclid(U256::from(2u8)), U256::from(uu8));
/// assert_eq!(U256::from(5u8).strict_rem_euclid(U256::from(2u8)), U256::from(1u8));
/// ```
///
/// The following function panics because of division by zero:
@ -854,6 +945,7 @@ impl U256 {
///
/// This following panics because of overflow:
/// ```should_panic
/// # use extra_math::u256::U256;
/// let _ = U256::from(0x10u8).strict_shl(257);
/// ```
pub const fn strict_shl(self, rhs: u32) -> Self {
@ -910,11 +1002,12 @@ impl U256 {
///
/// ```
/// # use extra_math::u256::U256;
/// assert_eq!(U256::from(0x10u8).strict_shr(4), U256::from(0x11u8));
/// assert_eq!(U256::from(0x10u8).strict_shr(4), U256::from(0x1u8));
/// ```
///
/// This following panics because of overflow:
/// ```should_panic
/// # use extra_math::u256::U256;
/// let _ = U256::from(0x10u8).strict_shr(257);
/// ```
pub const fn strict_shr(self, rhs: u32) -> Self {
@ -952,7 +1045,7 @@ impl U256 {
///
/// ```
/// # use extra_math::u256::U256;
/// assert_eq!(U256::from(2u8).checked_pow(5), Some(U256::from(25u8)));
/// assert_eq!(U256::from(2u8).checked_pow(5), Some(U256::from(32u8)));
/// assert_eq!(U256::MAX.checked_pow(2), None);
/// ```
pub const fn checked_pow(self, exp: u32) -> Option<Self> {
@ -977,14 +1070,14 @@ impl U256 {
///
/// ```
/// # use extra_math::u256::U256;
/// assert_eq!(U256::from(2u8).strict_pow(5), U256::from(25u8));
/// assert_eq!(U256::from(2u8).strict_pow(5), U256::from(32u8));
/// ```
///
/// The following panics because of overflow:
///
/// ```should_panic
/// # use extra_math::u256::U256;
/// let _ = Self::MAX.strict_pow(2);
/// let _ = U256::MAX.strict_pow(2);
/// ```
pub const fn strict_pow(self, exp: u32) -> Self {
match self.checked_pow(exp) {
@ -1202,8 +1295,8 @@ impl U256 {
/// # use extra_math::u256::U256;
/// assert_eq!(U256::from(0u8).wrapping_neg(), U256::from(0u8));
/// assert_eq!(U256::MAX.wrapping_neg(), U256::from(1u8));
/// assert_eq!(U256::from(13u8).wrapping_neg, !U256::from(13u8) + U256::from(1u8));
/// assert_eq!(U256::from(42u8).wrapping_neg, !(U256::from(42u8) - U256::from(1u8)));
/// assert_eq!(U256::from(13u8).wrapping_neg(), !U256::from(13u8) + U256::from(1u8));
/// assert_eq!(U256::from(42u8).wrapping_neg(), !(U256::from(42u8) - U256::from(1u8)));
/// ```
pub const fn wrapping_neg(self) -> Self {
self.overflowing_neg().0
@ -1218,7 +1311,7 @@ impl U256 {
///
/// ```
/// # use extra_math::u256::U256;
/// assert_eq!(U256::from(1u8).wrapping_shl(7), U256::from(127u8));
/// assert_eq!(U256::from(1u8).wrapping_shl(7), U256::from(128u8));
/// assert_eq!(U256::from(1u8).wrapping_shl(256), U256::from(1u8));
/// ```
pub const fn wrapping_shl(self, n: u32) -> Self {
@ -1234,7 +1327,7 @@ impl U256 {
///
/// ```
/// # use extra_math::u256::U256;
/// assert_eq!(U256::from(127u8).wrapping_shr(7), U256::from(1u8));
/// assert_eq!(U256::from(127u8).wrapping_shr(7), U256::from(0u8));
/// assert_eq!(U256::from(127u8).wrapping_shl(256), U256::from(127u8));
/// ```
pub const fn wrapping_shr(self, n: u32) -> Self {
@ -1387,6 +1480,11 @@ impl U256 {
self.0 > rhs.0 || self.1 > rhs.1
}
/// Constant version of `self >= rhs`
pub const fn ge(self, rhs: Self) -> bool {
self.gt(rhs) || self.eq(rhs)
}
/// Constant version of `self < rhs`
pub const fn lt(self, rhs: Self) -> bool {
self.0 < rhs.0 || self.1 < rhs.1
@ -1417,7 +1515,6 @@ impl U256 {
mul_high_high
.wrapping_add(mul_high_low)
.wrapping_add(mul_low_high)
.wrapping_add(mul_high_low)
.wrapping_add(mul_low_low)
}
@ -1434,7 +1531,7 @@ impl U256 {
/// ```
/// # use extra_math::u256::U256;
/// assert_eq!(U256::from(5u8).overflowing_mul(U256::from(2u8)), (U256::from(10u8), false));
/// assert_eq!(U256::MAX.overflowing_mul(U256::from(2u8)), (U256::MAX - U256::from(1), true));
/// assert_eq!(U256::MAX.overflowing_mul(U256::from(2u8)), (U256::MAX - U256::from(1u8), true));
/// ```
pub const fn overflowing_mul(self, rhs: Self) -> (Self, bool) {
let (l, h) = self.widening_mul(rhs);
@ -1544,11 +1641,16 @@ impl U256 {
/// ```
pub const fn overflowing_shl(self, mut n: u32) -> (Self, bool) {
let overflow = n >= Self::BITS;
n &= Self::BITS - 1;
if n >= (Self::BITS / 2) {
return (Self(self.1.wrapping_shl(n), 0), overflow);
}
let mut hi = self.0.wrapping_shl(n);
hi |= self.1.wrapping_shr(256 - n);
hi |= self.1.wrapping_shr(127 - n);
(Self(hi, self.1.wrapping_shl(n)), overflow)
}
@ -1567,10 +1669,17 @@ impl U256 {
/// ```
pub const fn overflowing_shr(self, mut n: u32) -> (Self, bool) {
let overflow = n >= Self::BITS;
n &= Self::BITS - 1;
if n >= (Self::BITS / 2) {
return (Self(0, self.0.wrapping_shr(n)), overflow);
}
let mut lo = self.1.wrapping_shr(n);
lo |= self.0.wrapping_shl(256 - n);
lo |= self.0.wrapping_shl(127 - n);
(Self(self.0.wrapping_shr(n), lo), overflow)
}
@ -2029,7 +2138,7 @@ impl U256 {
/// ```
/// # use extra_math::u256::U256;
/// assert_eq!(U256::from(5u8).widening_mul(U256::from(2u8)), (U256::from(10u8), U256::from(0u8)));
/// assert_eq!(U256::MAX.widening_mul(U256::from(2u8)), (U256::MAX - U256::from(1), U256::from(1u8)));
/// assert_eq!(U256::MAX.widening_mul(U256::from(2u8)), (U256::MAX - U256::from(1u8), U256::from(1u8)));
/// ```
pub const fn widening_mul(self, rhs: Self) -> (Self, Self) {
let mul_high_high = Self::upcasting_mul128(self.0, rhs.0);
@ -2113,6 +2222,11 @@ impl U256 {
let mut acc = Self(0, 0);
let mut i = 0;
while i < digits.len() {
if digits[i] == b'_' {
// Skip _ separators
i += 1;
continue;
}
acc = match acc.checked_mul(Self(0, radix as u128)) {
Some(n) => n,
None => return Err(PosOverflow),
@ -2181,7 +2295,7 @@ impl U256 {
}
/// Formats the integer in a given base
fn fmt_base(self, f: &mut Formatter<'_>, base: u32, uppercase: bool) -> fmt::Result {
fn fmt_base(mut self, f: &mut Formatter<'_>, base: u32, uppercase: bool) -> fmt::Result {
let char_table = if uppercase {
[
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F',
@ -2217,12 +2331,13 @@ impl U256 {
let mut curr = buf.len();
for byte in buf.iter_mut().rev() {
let (n, x) = self.checked_div_rem(Self::from(base)).unwrap_or_default();
let (x, n) = self.checked_div_rem(Self::from(base)).unwrap_or_default();
*byte = char_table[n.1 as usize] as u8;
curr -= 1;
if x == Self(0, 0) {
break;
}
self = x;
}
let buf = &buf[curr..];
@ -2636,6 +2751,54 @@ impl Sum for U256 {
}
}
impl Sub<&U256> for &U256 {
type Output = <U256 as Sub>::Output;
fn sub(self, rhs: &U256) -> Self::Output {
*self - *rhs
}
}
impl Sub<&U256> for U256 {
type Output = <U256 as Sub>::Output;
fn sub(self, rhs: &U256) -> Self::Output {
self - *rhs
}
}
impl Sub<U256> for &'_ U256 {
type Output = <U256 as Sub>::Output;
fn sub(self, rhs: U256) -> Self::Output {
*self - rhs
}
}
impl Sub for U256 {
type Output = Self;
fn sub(self, rhs: Self) -> Self::Output {
if cfg!(overflow_checks_stable) {
self.strict_sub(rhs)
} else {
self.wrapping_sub(rhs)
}
}
}
impl SubAssign<&U256> for U256 {
fn sub_assign(&mut self, rhs: &U256) {
*self = *self - rhs;
}
}
impl SubAssign for U256 {
fn sub_assign(&mut self, rhs: U256) {
*self = *self - rhs;
}
}
impl<'a> Product<&'a U256> for U256 {
fn product<I: Iterator<Item = &'a U256>>(iter: I) -> Self {
iter.fold(Self(0, 1), |a, b| a * b)
@ -2802,7 +2965,15 @@ impl Shl<u16> for U256 {
type Output = <Self as Shl>::Output;
fn shl(self, rhs: u16) -> Self::Output {
self << u32::from(rhs)
if cfg!(overflow_checks_stable) {
if rhs as u32 >= Self::BITS {
panic!("Tried to shift with overflow: {self} << {rhs}");
} else {
self.strict_shl(rhs as u32)
}
} else {
self.wrapping_shl(rhs as u32)
}
}
}
@ -2818,7 +2989,15 @@ impl Shl<u8> for U256 {
type Output = <Self as Shl>::Output;
fn shl(self, rhs: u8) -> Self::Output {
self << u32::from(rhs)
if cfg!(overflow_checks_stable) {
if rhs as u32 >= Self::BITS {
panic!("Tried to shift with overflow: {self} << {rhs}");
} else {
self.strict_shl(rhs as u32)
}
} else {
self.wrapping_shl(rhs as u32)
}
}
}
@ -3002,7 +3181,15 @@ impl Shr<u16> for U256 {
type Output = <Self as Shr>::Output;
fn shr(self, rhs: u16) -> Self::Output {
self >> U256::from(rhs)
if cfg!(overflow_checks_stable) {
if rhs as u32 >= Self::BITS {
panic!("Tried to shift with overflow: {self} << {rhs}");
} else {
self.strict_shr(rhs as u32)
}
} else {
self.wrapping_shr(rhs as u32)
}
}
}
@ -3018,7 +3205,15 @@ impl Shr<u8> for U256 {
type Output = <Self as Shr>::Output;
fn shr(self, rhs: u8) -> Self::Output {
self >> U256::from(rhs)
if cfg!(overflow_checks_stable) {
if rhs as u32 >= Self::BITS {
panic!("Tried to shift with overflow: {self} << {rhs}");
} else {
self.strict_shr(rhs as u32)
}
} else {
self.wrapping_shr(rhs as u32)
}
}
}
@ -3101,3 +3296,22 @@ impl<'a> ShrAssign<&'a u8> for U256 {
*self = *self >> rhs;
}
}
#[cfg(test)]
pub mod tests {
use super::U256;
extern crate std;
#[test]
pub fn check_number_parsing() {
let x = U256::from_str_radix(
"13f4_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_4f76",
16,
)
.unwrap();
assert_eq!(
std::format!("{x}"),
"9025054806887987155511483633794827778846787745240661638085164476202701836150"
);
}
}