Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions crypto/benches/merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use math::fields::f128::BaseElement;
use rand_utils::rand_value;
use utils::uninit_vector;
use core::mem::MaybeUninit;
use utils::{assume_init_vec, uninit_vector};
use winter_crypto::{build_merkle_nodes, concurrent, hashers::Blake3_256, Hasher};

type Blake3 = Blake3_256<BaseElement>;
Expand All @@ -20,11 +21,11 @@ pub fn merkle_tree_construction(c: &mut Criterion) {

for size in &BATCH_SIZES {
let data: Vec<Blake3Digest> = {
let mut res = unsafe { uninit_vector(*size) };
let mut res = uninit_vector(*size);
for i in 0..*size {
res[i] = Blake3::hash(&rand_value::<u128>().to_le_bytes());
res[i] = MaybeUninit::new(Blake3::hash(&rand_value::<u128>().to_le_bytes()));
}
res
unsafe { assume_init_vec(res) }
};
merkle_group.bench_with_input(BenchmarkId::new("sequential", size), &data, |b, i| {
b.iter(|| build_merkle_nodes::<Blake3>(i))
Expand Down
20 changes: 12 additions & 8 deletions crypto/src/merkle/concurrent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,21 @@ pub const MIN_CONCURRENT_LEAVES: usize = 1024;
/// results in a single vector such that root of the tree is at position 1, nodes immediately
/// under the root is at positions 2 and 3 etc.
pub fn build_merkle_nodes<H: Hasher>(leaves: &[H::Digest]) -> Vec<H::Digest> {
use core::mem::MaybeUninit;

let n = leaves.len() / 2;

// create un-initialized array to hold all intermediate nodes
let mut nodes = unsafe { utils::uninit_vector::<H::Digest>(2 * n) };
nodes[0] = H::Digest::default();
let mut nodes = utils::uninit_vector::<H::Digest>(2 * n);
nodes[0] = MaybeUninit::new(H::Digest::default());

// re-interpret leaves as an array of two leaves fused together and use it to
// build first row of internal nodes (parents of leaves)
let two_leaves = unsafe { slice::from_raw_parts(leaves.as_ptr() as *const [H::Digest; 2], n) };
nodes[n..]
.par_iter_mut()
.zip(two_leaves.par_iter())
.for_each(|(target, source)| *target = H::merge(source));
.for_each(|(target, source)| *target = MaybeUninit::new(H::merge(source)));

// calculate all other tree nodes, we can't use regular iterators here because
// access patterns are rather complicated - so, we use regular threads instead
Expand All @@ -45,19 +47,21 @@ pub fn build_merkle_nodes<H: Hasher>(leaves: &[H::Digest]) -> Vec<H::Digest> {
let num_subtrees = rayon::current_num_threads().next_power_of_two();
let batch_size = n / num_subtrees;

// re-interpret nodes as an array of two nodes fused together
// re-interpret nodes as an array of two nodes fused together; MaybeUninit<T> has the
// same layout as T, so the pointer cast is valid
let two_nodes = unsafe { slice::from_raw_parts(nodes.as_ptr() as *const [H::Digest; 2], n) };

// process each subtree in a separate thread
rayon::scope(|s| {
for i in 0..num_subtrees {
let nodes = unsafe { &mut *(&mut nodes[..] as *mut [H::Digest]) };
let nodes =
unsafe { &mut *(&mut nodes[..] as *mut [MaybeUninit<H::Digest>]) };
s.spawn(move |_| {
let mut batch_size = batch_size / 2;
let mut start_idx = n / 2 + batch_size * i;
while start_idx >= num_subtrees {
for k in (start_idx..(start_idx + batch_size)).rev() {
nodes[k] = H::merge(&two_nodes[k]);
nodes[k] = MaybeUninit::new(H::merge(&two_nodes[k]));
}
start_idx /= 2;
batch_size /= 2;
Expand All @@ -68,10 +72,10 @@ pub fn build_merkle_nodes<H: Hasher>(leaves: &[H::Digest]) -> Vec<H::Digest> {

// finish the tip of the tree
for i in (1..num_subtrees).rev() {
nodes[i] = H::merge(&two_nodes[i]);
nodes[i] = MaybeUninit::new(H::merge(&two_nodes[i]));
}

nodes
unsafe { utils::assume_init_vec(nodes) }
}

// TESTS
Expand Down
15 changes: 9 additions & 6 deletions crypto/src/merkle/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,29 +342,32 @@ impl<H: Hasher> MerkleTree<H> {
/// This function is exposed primarily for benchmarking purposes. It is not intended to be used
/// directly by the end users of the crate.
pub fn build_merkle_nodes<H: Hasher>(leaves: &[H::Digest]) -> Vec<H::Digest> {
use core::mem::MaybeUninit;

let n = leaves.len() / 2;

// create un-initialized array to hold all intermediate nodes
let mut nodes = unsafe { utils::uninit_vector::<H::Digest>(2 * n) };
nodes[0] = H::Digest::default();
let mut nodes = utils::uninit_vector::<H::Digest>(2 * n);
nodes[0] = MaybeUninit::new(H::Digest::default());

// re-interpret leaves as an array of two leaves fused together
let two_leaves = unsafe { slice::from_raw_parts(leaves.as_ptr() as *const [H::Digest; 2], n) };

// build first row of internal nodes (parents of leaves)
for (i, j) in (0..n).zip(n..nodes.len()) {
nodes[j] = H::merge(&two_leaves[i]);
nodes[j] = MaybeUninit::new(H::merge(&two_leaves[i]));
}

// re-interpret nodes as an array of two nodes fused together
// re-interpret nodes as an array of two nodes fused together; safe because all elements
// from index n onwards are initialized, and lower indices will be initialized below
let two_nodes = unsafe { slice::from_raw_parts(nodes.as_ptr() as *const [H::Digest; 2], n) };

// calculate all other tree nodes
for i in (1..n).rev() {
nodes[i] = H::merge(&two_nodes[i]);
nodes[i] = MaybeUninit::new(H::merge(&two_nodes[i]));
}

nodes
unsafe { utils::assume_init_vec(nodes) }
}

fn map_indexes(
Expand Down
5 changes: 3 additions & 2 deletions examples/src/rescue_raps/custom_trace_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

use core_utils::uninit_vector;
use core_utils::{assume_init_vec, uninit_vector};
use winterfell::{math::StarkField, matrix::ColMatrix, EvaluationFrame, Trace, TraceInfo};

// RAP TRACE TABLE
Expand Down Expand Up @@ -87,7 +87,8 @@ impl<B: StarkField> RapTraceTable<B> {
meta.len()
);

let columns = unsafe { (0..width).map(|_| uninit_vector(length)).collect() };
// SAFETY: each column is fully initialized via fill() or update_row() before being read.
let columns = (0..width).map(|_| unsafe { assume_init_vec(uninit_vector(length)) }).collect();
Self {
info: TraceInfo::new_multi_segment(width, 3, 3, length, meta),
trace: ColMatrix::new(columns),
Expand Down
7 changes: 4 additions & 3 deletions examples/src/rescue_raps/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

use core_utils::uninit_vector;
use core_utils::{assume_init_vec, uninit_vector};
use winterfell::{
crypto::MerkleTree, matrix::ColMatrix, AuxRandElements, CompositionPoly, CompositionPolyTrace,
ConstraintCompositionCoefficients, DefaultConstraintCommitment, DefaultConstraintEvaluator,
Expand Down Expand Up @@ -168,8 +168,9 @@ where
let main_trace = trace.main_segment();
let rand_elements = aux_rand_elements.rand_elements();

let mut current_row = unsafe { uninit_vector(main_trace.num_cols()) };
let mut next_row = unsafe { uninit_vector(main_trace.num_cols()) };
// SAFETY: read_row_into fully initializes each row buffer before it is read.
let mut current_row = unsafe { assume_init_vec(uninit_vector(main_trace.num_cols())) };
let mut next_row = unsafe { assume_init_vec(uninit_vector(main_trace.num_cols())) };
main_trace.read_row_into(0, &mut current_row);
let mut aux_columns = vec![vec![E::ZERO; main_trace.num_rows()]; trace.aux_trace_width()];

Expand Down
9 changes: 5 additions & 4 deletions fri/src/folding/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ use math::{
};
#[cfg(feature = "concurrent")]
use utils::iterators::*;
use utils::{iter_mut, uninit_vector};
use core::mem::MaybeUninit;
use utils::{assume_init_vec, iter_mut, uninit_vector};

// DEGREE-RESPECTING PROJECTION
// ================================================================================================
Expand Down Expand Up @@ -93,7 +94,7 @@ where
let inv_twiddles = get_inv_twiddles::<B>(N);
let len_offset = E::inv((N as u32).into());

let mut result = unsafe { uninit_vector(values.len()) };
let mut result = uninit_vector(values.len());
iter_mut!(result)
.zip(values)
.zip(inv_offsets)
Expand All @@ -111,10 +112,10 @@ where
}

// evaluate the polynomial at alpha, and save the result
*result = polynom::eval(&poly, alpha)
*result = MaybeUninit::new(polynom::eval(&poly, alpha))
});

result
unsafe { assume_init_vec(result) }
}

// POSITION FOLDING
Expand Down
10 changes: 6 additions & 4 deletions fri/src/prover/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ use crypto::{ElementHasher, Hasher, VectorCommitment};
use math::{fft, FieldElement};
#[cfg(feature = "concurrent")]
use utils::iterators::*;
use core::mem::MaybeUninit;
use utils::{
flatten_vector_elements, group_slice_elements, iter_mut, transpose_slice, uninit_vector,
assume_init_vec, flatten_vector_elements, group_slice_elements, iter_mut, transpose_slice,
uninit_vector,
};

use crate::{
Expand Down Expand Up @@ -326,11 +328,11 @@ where
H: ElementHasher<BaseField = E::BaseField>,
V: VectorCommitment<H>,
{
let mut hashed_evaluations: Vec<H::Digest> = unsafe { uninit_vector(values.len()) };
let mut hashed_evaluations: Vec<MaybeUninit<H::Digest>> = uninit_vector(values.len());
iter_mut!(hashed_evaluations, 1024).zip(values).for_each(|(e, v)| {
let digest: H::Digest = H::hash_elements(v);
*e = digest
*e = MaybeUninit::new(digest)
});

V::new(hashed_evaluations)
V::new(unsafe { assume_init_vec(hashed_evaluations) })
}
10 changes: 8 additions & 2 deletions math/src/fft/concurrent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

use alloc::vec::Vec;

use utils::{iterators::*, rayon, uninit_vector};
use core::mem::MaybeUninit;
use utils::{assume_init_vec, iterators::*, rayon, uninit_vector};

use super::fft_inputs::FftInputs;
use crate::field::{FieldElement, StarkField};
Expand All @@ -31,7 +32,7 @@ pub fn evaluate_poly_with_offset<B: StarkField, E: FieldElement<BaseField = B>>(
) -> Vec<E> {
let domain_size = p.len() * blowup_factor;
let g = B::get_root_of_unity(domain_size.ilog2());
let mut result = unsafe { uninit_vector(domain_size) };
let mut result = uninit_vector(domain_size);

result
.as_mut_slice()
Expand All @@ -40,10 +41,15 @@ pub fn evaluate_poly_with_offset<B: StarkField, E: FieldElement<BaseField = B>>(
.for_each(|(i, chunk)| {
let idx = super::permute_index(blowup_factor, i) as u64;
let offset = g.exp(idx.into()) * domain_offset;
// SAFETY: MaybeUninit<E> has the same layout as E; we fully initialize
// the chunk via clone_and_shift before reading via split_radix_fft.
let chunk = unsafe { &mut *(chunk as *mut [MaybeUninit<E>] as *mut [E]) };
clone_and_shift(p, chunk, offset);
split_radix_fft(chunk, twiddles);
});

let mut result = unsafe { assume_init_vec(result) };

permute(&mut result);
result
}
Expand Down
12 changes: 9 additions & 3 deletions math/src/fft/serial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

use alloc::vec::Vec;

use utils::uninit_vector;
use core::mem::MaybeUninit;
use utils::{assume_init_vec, uninit_vector};

use super::fft_inputs::FftInputs;
use crate::{field::StarkField, FieldElement};
Expand Down Expand Up @@ -38,16 +39,21 @@ where
{
let domain_size = p.len() * blowup_factor;
let g = B::get_root_of_unity(domain_size.ilog2());
let mut result = unsafe { uninit_vector(domain_size) };
let mut result = uninit_vector(domain_size);

result.as_mut_slice().chunks_mut(p.len()).enumerate().for_each(|(i, chunk)| {
let idx = super::permute_index(blowup_factor, i) as u64;
let offset = g.exp(idx.into()) * domain_offset;
let mut factor = E::BaseField::ONE;
for (d, c) in chunk.iter_mut().zip(p.iter()) {
*d = (*c).mul_base(factor);
*d = MaybeUninit::new((*c).mul_base(factor));
factor *= offset;
}
});

let mut result = unsafe { assume_init_vec(result) };

result.as_mut_slice().chunks_mut(p.len()).for_each(|chunk| {
chunk.fft_in_place(twiddles);
});

Expand Down
8 changes: 5 additions & 3 deletions math/src/polynom/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -662,9 +662,11 @@ where
/// assert_eq!(expected_poly, poly);
/// ```
pub fn poly_from_roots<E: FieldElement>(xs: &[E]) -> Vec<E> {
let mut result = unsafe { utils::uninit_vector(xs.len() + 1) };
fill_zero_roots(xs, &mut result);
result
let mut result = utils::uninit_vector(xs.len() + 1);
// fill_zero_roots writes all elements of result
let result_slice = unsafe { &mut *(result.as_mut_slice() as *mut [core::mem::MaybeUninit<E>] as *mut [E]) };
fill_zero_roots(xs, result_slice);
unsafe { utils::assume_init_vec(result) }
}

// HELPER FUNCTIONS
Expand Down
31 changes: 21 additions & 10 deletions math/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@

use alloc::vec::Vec;

use core::mem::MaybeUninit;

#[cfg(feature = "concurrent")]
use utils::iterators::*;
use utils::{batch_iter_mut, iter_mut, uninit_vector};
use utils::{assume_init_vec, batch_iter_mut, iter_mut, uninit_vector};

use crate::{field::FieldElement, ExtensionOf};

Expand Down Expand Up @@ -37,12 +39,15 @@ pub fn get_power_series<E>(b: E, n: usize) -> Vec<E>
where
E: FieldElement,
{
let mut result = unsafe { uninit_vector(n) };
batch_iter_mut!(&mut result, 1024, |batch: &mut [E], batch_offset: usize| {
let mut result = uninit_vector(n);
batch_iter_mut!(&mut result, 1024, |batch: &mut [MaybeUninit<E>], batch_offset: usize| {
// SAFETY: MaybeUninit<E> has the same layout as E; fill_power_series initializes
// every element of the batch.
let batch = unsafe { &mut *(batch as *mut [MaybeUninit<E>] as *mut [E]) };
let start = b.exp((batch_offset as u64).into());
fill_power_series(batch, b, start);
});
result
unsafe { assume_init_vec(result) }
}

/// Returns a vector containing successive powers of a given base offset by the specified value.
Expand Down Expand Up @@ -70,12 +75,15 @@ pub fn get_power_series_with_offset<E>(b: E, s: E, n: usize) -> Vec<E>
where
E: FieldElement,
{
let mut result = unsafe { uninit_vector(n) };
batch_iter_mut!(&mut result, 1024, |batch: &mut [E], batch_offset: usize| {
let mut result = uninit_vector(n);
batch_iter_mut!(&mut result, 1024, |batch: &mut [MaybeUninit<E>], batch_offset: usize| {
// SAFETY: MaybeUninit<E> has the same layout as E; fill_power_series initializes
// every element of the batch.
let batch = unsafe { &mut *(batch as *mut [MaybeUninit<E>] as *mut [E]) };
let start = s * b.exp((batch_offset as u64).into());
fill_power_series(batch, b, start);
});
result
unsafe { assume_init_vec(result) }
}

/// Computes element-wise sum of the provided vectors, and stores the result in the first vector.
Expand Down Expand Up @@ -170,13 +178,16 @@ pub fn batch_inversion<E>(values: &[E]) -> Vec<E>
where
E: FieldElement,
{
let mut result: Vec<E> = unsafe { uninit_vector(values.len()) };
batch_iter_mut!(&mut result, 1024, |batch: &mut [E], batch_offset: usize| {
let mut result: Vec<MaybeUninit<E>> = uninit_vector(values.len());
batch_iter_mut!(&mut result, 1024, |batch: &mut [MaybeUninit<E>], batch_offset: usize| {
// SAFETY: MaybeUninit<E> has the same layout as E; serial_batch_inversion
// initializes every element of the batch.
let batch = unsafe { &mut *(batch as *mut [MaybeUninit<E>] as *mut [E]) };
let start = batch_offset;
let end = start + batch.len();
serial_batch_inversion(&values[start..end], batch);
});
result
unsafe { assume_init_vec(result) }
}

// HELPER FUNCTIONS
Expand Down
Loading