Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,18 @@ parallel = [
rayon = ["dep:rayon"]
asm = ["ark-ff/asm"]
tracing = ["dep:tracing"]
# Enable per-step allocation tracking in ZK WHIR.
# Run: cargo run --bin alloc_report --features alloc-track
alloc-track = []

[[bench]]
name = "expand_from_coeff"
harness = false

[[bench]]
name = "whir_zk"
harness = false

[[bench]]
name = "sumcheck"
harness = false
Expand Down
465 changes: 465 additions & 0 deletions benches/whir_zk.rs

Large diffs are not rendered by default.

79 changes: 78 additions & 1 deletion src/algebra/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pub mod polynomials;
pub mod sumcheck;
mod weights;

use ark_ff::{AdditiveGroup, Field};
use ark_ff::{AdditiveGroup, FftField, Field};
#[cfg(feature = "parallel")]
use rayon::prelude::*;

Expand Down Expand Up @@ -54,6 +54,13 @@ pub fn lift<M: Embedding>(embedding: &M, source: &[M::Source]) -> Vec<M::Target>
result
}

/// Scalar-mul add (same-field AXPY)
///
/// accumulator[i] += weight * vector[i]
///
/// Thin wrapper around [`mixed_scalar_mul_add`] using the identity embedding,
/// i.e. both slices live in the same field `F`. Length-mismatch behavior is
/// whatever `mixed_scalar_mul_add` does — NOTE(review): its body is not fully
/// visible here; confirm before relying on unequal-length slices.
pub fn scalar_mul_add<F: Field>(accumulator: &mut [F], weight: F, vector: &[F]) {
    mixed_scalar_mul_add(&embedding::Identity::new(), accumulator, weight, vector);
}

/// Mixed scalar-mul add
///
/// accumulator[i] += weight * vector[i]
Expand Down Expand Up @@ -119,3 +126,73 @@ pub fn mixed_dot<F: Field, G: Field>(

result
}

/// Project an extension field element to its base prime field component.
///
/// Returns the constant (first) coordinate of `val` with respect to the base
/// prime field.
///
/// In debug builds, panics if the element does not lie in the base prime
/// subfield, i.e. if any higher coordinate is non-zero. In release builds the
/// higher coordinates are silently dropped (the check is skipped for speed,
/// since this runs per-element in hot projection loops).
#[inline]
pub fn project_to_base<F: Field>(val: F) -> F::BasePrimeField {
    let mut coords = val.to_base_prime_field_elements();
    // Every field element has at least the constant coordinate.
    let constant = coords
        .next()
        .expect("field element must have at least one base prime field coordinate");
    // Fix: the previous version only took the first coordinate and never
    // verified subfield membership, despite documenting a panic. Enforce the
    // documented contract (in debug builds) by requiring all remaining
    // coordinates to vanish.
    debug_assert!(
        coords.all(|c| c == F::BasePrimeField::ZERO),
        "element does not lie in the base prime subfield"
    );
    constant
}

/// Project every element of an extension-field slice to the base prime field.
///
/// Panics if any element does not lie in the base prime subfield.
pub fn project_all_to_base<F: FftField>(coeffs: &[F]) -> Vec<F::BasePrimeField> {
#[cfg(feature = "parallel")]
{
coeffs.par_iter().map(|c| project_to_base(*c)).collect()
}
#[cfg(not(feature = "parallel"))]
{
coeffs.iter().map(|&c| project_to_base(c)).collect()
}
}

/// Element-wise add a base-field slice with a (possibly shorter) extension-field
/// slice projected to base field.
///
/// Computes `result[i] = base[i] + project_to_base(ext[i])` for `i < ext.len()`,
/// and copies `base[i]` unchanged for `i >= ext.len()`.
///
/// Each element of `ext_addend` must lie in the base prime subfield.
pub fn add_base_with_projection<F: FftField>(
    base: &[F::BasePrimeField],
    ext_addend: &[F],
) -> Vec<F::BasePrimeField> {
    debug_assert!(
        ext_addend.len() <= base.len(),
        "ext_addend ({}) must not exceed base ({})",
        ext_addend.len(),
        base.len(),
    );
    // Entries below `split` receive the projected addend; the rest pass through.
    let split = ext_addend.len();

    #[cfg(feature = "parallel")]
    {
        base.par_iter()
            .enumerate()
            .map(|(idx, &b)| {
                if idx < split {
                    b + project_to_base(ext_addend[idx])
                } else {
                    b
                }
            })
            .collect()
    }
    #[cfg(not(feature = "parallel"))]
    {
        // Sequential path: handle the overlapping prefix and the plain tail
        // as two chained iterators instead of branching per index.
        let head = base[..split]
            .iter()
            .zip(ext_addend)
            .map(|(&b, &e)| b + project_to_base(e));
        let tail = base[split..].iter().copied();
        head.chain(tail).collect()
    }
}
47 changes: 47 additions & 0 deletions src/algebra/polynomials/coeffs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,26 @@ impl<F: Field> CoefficientList<F> {
}
}

/// Embed an ℓ-variate polynomial into an n-variate polynomial (n ≥ ℓ)
/// by treating the extra variables as having zero contribution.
///
/// Coefficient at index `i` in the ℓ-variate maps to index `i * 2^(n-ℓ)`
/// in the n-variate, with all other coefficients set to zero.
pub fn embed_into_variables(&self, n: usize) -> Self {
    let ell = self.num_variables;
    assert!(n >= ell);

    // Stride between consecutive source coefficients in the target vector.
    let stride = 1 << (n - ell);
    let mut embedded = vec![F::ZERO; 1 << n];

    // Scatter: source coefficient i lands at target index i * stride.
    // There are exactly 2^ℓ multiples of `stride` below 2^n, matching
    // the source length, so the zip consumes every coefficient.
    embedded
        .iter_mut()
        .step_by(stride)
        .zip(self.coeffs.iter())
        .for_each(|(slot, &c)| *slot = c);

    Self::new(embedded)
}

/// Evaluates the polynomial at an arbitrary point in `F^n`.
///
/// This generalizes evaluation beyond `(0,1)^n`, allowing fractional or arbitrary field
Expand Down Expand Up @@ -137,6 +157,33 @@ impl<F: Field> CoefficientList<F> {
num_variables: self.num_variables() - folding_factor,
}
}

/// Folds the polynomial in-place along high-indexed variables.
///
/// Behaves like [`fold`](Self::fold) but reuses the existing coefficient
/// buffer instead of allocating a new one; the unused tail is released via
/// truncation.
///
/// # Safety of in-place overwrite
///
/// Output slot `i` is written only after the chunk
/// `coeffs[i*chunk .. (i+1)*chunk]` it depends on has been fully read, and
/// every later iteration reads chunks starting at or beyond `(i+1)*chunk`,
/// which is strictly greater than `i` — so no pending input is ever
/// clobbered by an earlier write.
pub fn fold_in_place(&mut self, folding_randomness: &MultilinearPoint<F>) {
    let num_folded = folding_randomness.num_variables();
    let chunk = 1 << num_folded;
    let kept = self.coeffs.len() / chunk;
    for out_idx in 0..kept {
        let start = out_idx * chunk;
        // Read the chunk into a temporary first, then store: keeps the
        // immutable borrow of `self.coeffs` disjoint from the write.
        let folded = eval_multivariate(
            &self.coeffs[start..start + chunk],
            &folding_randomness.0,
        );
        self.coeffs[out_idx] = folded;
    }
    self.coeffs.truncate(kept);
    self.num_variables -= num_folded;
}
}

/// Multivariate evaluation in coefficient form.
Expand Down
20 changes: 20 additions & 0 deletions src/algebra/polynomials/evals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,26 @@ where
self.num_variables
}

/// Folds evaluations in-place by linear interpolation at the given weight.
///
/// Each pair `(evals[2i], evals[2i+1])` is collapsed to
/// `evals[2i] + (evals[2i+1] - evals[2i]) * weight` and stored at index `i`;
/// the vector is then truncated to half its length.
///
/// Equivalent to building a new `EvaluationsList` via
/// `algebra::sumcheck::fold`, minus the allocation.
pub fn fold_in_place(&mut self, weight: F) {
    assert!(self.evals.len().is_multiple_of(2));
    let half = self.evals.len() / 2;
    for out in 0..half {
        let at_zero = self.evals[2 * out];
        let at_one = self.evals[2 * out + 1];
        // Interpolate the line through (0, at_zero) and (1, at_one) at `weight`.
        self.evals[out] = at_zero + (at_one - at_zero) * weight;
    }
    self.evals.truncate(half);
    self.num_variables -= 1;
}

pub fn to_coeffs(&self) -> crate::algebra::polynomials::coeffs::CoefficientList<F> {
let mut coeffs = self.evals.clone();
crate::algebra::ntt::inverse_wavelet_transform(&mut coeffs);
Expand Down
29 changes: 24 additions & 5 deletions src/algebra/polynomials/multilinear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,31 @@ where
acc
}

/// Computes eq(self, z) for every z ∈ {0,1}ⁿ using a butterfly expansion.
///
/// Returns a `Vec` of length `2^n` where entry `z` (in lexicographic
/// order) is `eq(self, z)`.
///
/// Runs in O(2ⁿ) time and O(2ⁿ) space.
pub fn eq_weights(&self) -> Vec<F> {
    let num_vars = self.num_variables();
    let mut table = Vec::with_capacity(1 << num_vars);
    table.push(F::ONE);
    // One doubling pass per stored coordinate. Invariant after processing k
    // coordinates: table[z] = eq(prefix, z) for all z ∈ {0,1}ᵏ, using
    //   eq(c, z||0) = eq(c', z) · (1 − cᵢ)
    //   eq(c, z||1) = eq(c', z) · cᵢ
    for &coord in self.0.iter() {
        let old_len = table.len();
        let complement = F::ONE - coord;
        table.resize(2 * old_len, F::ZERO);
        // Walk backwards so each source entry table[j] is consumed before
        // the slot pair (2j, 2j+1) that may overwrite it is written.
        for j in (0..old_len).rev() {
            let w = table[j];
            table[2 * j + 1] = w * coord;
            table[2 * j] = w * complement;
        }
    }
    table
}

pub fn coeff_weights(&self, reversed: bool) -> Vec<F> {
Expand Down
140 changes: 140 additions & 0 deletions src/algebra/weights.rs
Original file line number Diff line number Diff line change
Expand Up @@ -384,4 +384,144 @@ mod tests {
let expected = weight_list.eval_extension(&folding_randomness);
assert_eq!(weight.compute(&folding_randomness), expected);
}

#[test]
fn test_protocol() {
    // End-to-end consistency check across polynomial representations:
    // coefficient form, evaluation form, folding, and weights.
    // (Debug `println!`s removed; every computed value is now asserted.)

    // ── Step 1: the all-ones polynomial in 4 variables (16 coefficients) ──
    let coeffs = CoefficientList::new(vec![Field64::ONE; 16]);

    // ── Step 2: evaluate at the four standard unit points of {0,1}⁴ ──
    let evaluation_points: Vec<_> = (0..4)
        .map(|i| {
            let mut coords = vec![Field64::ZERO; 4];
            coords[i] = Field64::ONE;
            MultilinearPoint(coords)
        })
        .collect();
    let weights = evaluation_points
        .iter()
        .map(|point| Weights::evaluation(point.clone()))
        .collect::<Vec<_>>();
    assert_eq!(weights.len(), evaluation_points.len());
    // With every coefficient equal to 1, only the constant monomial and the
    // single monomial xᵢ survive at unit point eᵢ, so f(eᵢ) = 1 + 1 = 2
    // regardless of the coefficient-index convention.
    for point in &evaluation_points {
        assert_eq!(
            coeffs.mixed_evaluate(&Identity::new(), point),
            Field64::from(2u64),
            "all-ones polynomial must evaluate to 2 at a unit point"
        );
    }

    // ── Step 3: CoefficientList → EvaluationsList → CoefficientList ──
    let evals = EvaluationsList::from(coeffs.clone());
    let coeffs_roundtrip = evals.to_coeffs();
    assert_eq!(
        coeffs.coeffs(),
        coeffs_roundtrip.coeffs(),
        "Round-trip CoefficientList → EvaluationsList → CoefficientList must be identity"
    );
    // Both representations must agree at arbitrary points.
    for point in &evaluation_points {
        assert_eq!(
            coeffs.evaluate(point),
            evals.evaluate(point),
            "CoefficientList and EvaluationsList must agree at {:?}",
            point
        );
    }

    // ── Step 4: fold_in_place must match the allocating fold ──
    // After folding f(X₀, X₁, X₂, X₃) at (r₀, r₁) we get g(X₂, X₃) = f(X₂, X₃, r₀, r₁).
    let folding_randomness = MultilinearPoint(vec![Field64::from(3u64), Field64::from(7u64)]);
    let folded = coeffs.fold(&folding_randomness);
    let mut coeffs_mut = coeffs.clone();
    coeffs_mut.fold_in_place(&folding_randomness);
    assert_eq!(
        folded.coeffs(),
        coeffs_mut.coeffs(),
        "fold() and fold_in_place() must produce the same polynomial"
    );

    // ── Step 5: folded polynomial is consistent with full evaluation ──
    // g(a, b) must equal f(a, b, r₀, r₁) for any (a, b).
    let eval_point = MultilinearPoint(vec![Field64::from(5u64), Field64::from(11u64)]);
    let full_point = MultilinearPoint(vec![
        eval_point.0[0],
        eval_point.0[1],
        folding_randomness.0[0],
        folding_randomness.0[1],
    ]);
    assert_eq!(
        folded.evaluate(&eval_point),
        coeffs.evaluate(&full_point),
        "f.fold(r).evaluate(a) must equal f.evaluate(a || r)"
    );
}
}
Loading