Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ rand_distr = "0.5.1"
dashmap = "6.1.0"
keccak-asm = { version = "0.1.4" }
walkdir = "2"
num-integer = "0.1.46"

[dev-dependencies]
proptest = "1.0.0"
Expand Down
4 changes: 2 additions & 2 deletions src/bgg/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::poly::{Poly, PolyMatrix};
use rayon::prelude::*;
use std::ops::{Add, Mul, Sub};

#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq)]
pub struct BggEncoding<M: PolyMatrix> {
pub vector: M,
pub pubkey: BggPublicKey<M>,
Expand Down Expand Up @@ -168,11 +168,11 @@ impl<M: PolyMatrix> Evaluable for BggEncoding<M> {

#[cfg(test)]
mod tests {
use super::*;
use crate::{
bgg::{
circuit::PolyCircuit,
sampler::{BGGEncodingSampler, BGGPublicKeySampler},
BggEncoding,
},
poly::{
dcrt::{
Expand Down
327 changes: 327 additions & 0 deletions src/fhe/bfv.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,327 @@
use num_bigint::BigUint;
use num_integer::Integer;
use std::ops::{Add, Mul};

use crate::poly::{
dcrt::{DCRTPoly, DCRTPolyParams, DCRTPolyUniformSampler, FinRingElem},
sampler::{DistType, PolyUniformSampler},
Poly, PolyParams,
};

fn delta(params_q: &DCRTPolyParams, params_t: &DCRTPolyParams) -> BigUint {
params_q.modulus().div_floor(&params_t.modulus())
}

pub struct Bfv {
params_t: DCRTPolyParams,
params_q: DCRTPolyParams,
delta: BigUint,
sigma: f64,
rlk0: DCRTPoly,
rlk1: DCRTPoly,
}

#[derive(Debug, Clone)]
pub struct BfvCipher {
c_1: DCRTPoly,
c_2: DCRTPoly,
}

impl Bfv {
pub fn keygen(
params_t: DCRTPolyParams,
params_q: DCRTPolyParams,
sigma: f64,
) -> (Self, DCRTPoly) {
let sampler = DCRTPolyUniformSampler::new();
let sk_t = sampler.sample_uniform(&params_t, 1, 1, DistType::BitDist).entry(0, 0);
let delta = delta(&params_q, &params_t);
let a = sampler.sample_uniform(&params_q, 1, 1, DistType::FinRingDist).entry(0, 0);
let e = sampler.sample_uniform(&params_q, 1, 1, DistType::GaussDist { sigma }).entry(0, 0);
let one_elem = FinRingElem::new(BigUint::from(1_u8), params_q.modulus());
let sk_q = sk_t.scalar_mul(&params_q, one_elem);
let rlk0 = -(&a * &sk_q) + &sk_q + e;
let rlk1 = a;
(Self { params_t, params_q, delta, sigma, rlk0, rlk1 }, sk_t)
}

pub fn encrypt_ske(&self, m_t: DCRTPoly, sk_t: DCRTPoly) -> BfvCipher {
let delta_elem = FinRingElem::new(self.delta.clone(), self.params_q.modulus());
let m_q = m_t.scalar_mul(&self.params_q, delta_elem);
let sampler_uniform = DCRTPolyUniformSampler::new();
let a =
sampler_uniform.sample_uniform(&self.params_q, 1, 1, DistType::FinRingDist).entry(0, 0);
let e = sampler_uniform
.sample_uniform(&self.params_q, 1, 1, DistType::GaussDist { sigma: self.sigma })
.entry(0, 0);
/*
this is hacky way to modular switch sk on mod t to mod q where t < q. I've introduced
using scaler mul because existing modular switch doesn't support t < q case.
*/
let one_elem = FinRingElem::new(BigUint::from(1_u8), self.params_q.modulus());
let sk_q = sk_t.scalar_mul(&self.params_q, one_elem);
let c_1 = &sk_q * &a + m_q + &e;
let c_2 = -a;

BfvCipher { c_1, c_2 }
}

pub fn decrypt(&self, ct: BfvCipher, sk_t: DCRTPoly) -> DCRTPoly {
/*
again, modular switch on sk from t to q
*/
let one_elem = FinRingElem::new(BigUint::from(1_u8), self.params_q.modulus());
let sk_q = sk_t.scalar_mul(&self.params_q, one_elem);
let ct = ct.c_1 + ct.c_2 * sk_q;
ct.scale_and_round(&self.params_t)
}

pub fn mul(&self, lhs: &BfvCipher, rhs: &BfvCipher) -> BfvCipher {
let d0 = &lhs.c_1 * &rhs.c_1;
let d1 = &lhs.c_1 * &rhs.c_2 + &lhs.c_2 * &rhs.c_1;
let d2 = &lhs.c_2 * &rhs.c_2;

let c0 = &d0 + &d2 * &self.rlk0;
let c1 = &d1 + &d2 * &self.rlk1;

let c0_t = c0.scale_and_round(&self.params_t);
let c1_t = c1.scale_and_round(&self.params_t);

let delta_elem = FinRingElem::new(self.delta.clone(), self.params_q.modulus());
let c0_q = c0_t.scalar_mul(&self.params_q, delta_elem.clone());
let c1_q = c1_t.scalar_mul(&self.params_q, delta_elem);

BfvCipher { c_1: c0_q, c_2: c1_q }
}
}

impl BfvCipher {
pub fn decompose_base(&self, params_q: &DCRTPolyParams) -> Vec<DCRTPoly> {
let mut c1_decomposed = self.c_1.decompose_base(params_q);
let c2_decomposed = self.c_2.decompose_base(params_q);
c1_decomposed.extend(c2_decomposed);
c1_decomposed
}
}

impl Add for BfvCipher {
type Output = Self;

fn add(self, rhs: Self) -> Self::Output {
let c_1 = self.c_1 + rhs.c_1;
let c_2 = self.c_2 + rhs.c_2;
Self { c_1, c_2 }
}
}

impl Mul for &BfvCipher {
type Output = BfvCipher;
fn mul(self, _: Self) -> Self::Output {
unimplemented!("use bfv.mul(&ct1,&ct2) instead");
}
}

#[cfg(test)]
mod tests {
use keccak_asm::Keccak256;

use super::*;
use crate::{
bgg::{
circuit::PolyCircuit,
sampler::{BGGEncodingSampler, BGGPublicKeySampler},
},
poly::{dcrt::DCRTPolyHashSampler, Poly},
utils::{create_bit_random_poly, create_random_poly},
};

#[test]
fn test_bfv_add() {
/*
Create parameter t and q for testing where they share same ring dimension but with different modulus.
Paintext modulus t should be smaller than Ciphertext modulus q.
*/
let params_t = DCRTPolyParams::new(4, 2, 17, 1);
let params_q = DCRTPolyParams::new(4, 4, 21, 7);

let (bfv, sk) = Bfv::keygen(params_t.clone(), params_q.clone(), 3.2);

let m_a = create_random_poly(&params_t);
let m_b = create_random_poly(&params_t);
let m_add = &m_a + &m_b;
let ct_a = bfv.encrypt_ske(m_a, sk.clone());
let ct_b = bfv.encrypt_ske(m_b, sk.clone());

/* Homomorphic */
let ct_add = ct_a + ct_b;

/* Decryption */
let dec = bfv.decrypt(ct_add, sk);
assert_eq!(m_add, dec);
}

#[test]
fn test_bfv_bgg_add() {
let params_t = DCRTPolyParams::new(4, 2, 17, 1);
let params_q = DCRTPolyParams::new(4, 8, 51, 17);

let (bfv, sk) = Bfv::keygen(params_t.clone(), params_q.clone(), 3.2);
let m_1 = create_random_poly(&params_t);
let ct_1 = bfv.encrypt_ske(m_1, sk.clone());
let m_2 = create_random_poly(&params_t);
let ct_2 = bfv.encrypt_ske(m_2, sk);
let ct_add = ct_1.clone() + ct_2.clone();
let mut plaintexts_sum = vec![ct_1.c_1, ct_1.c_2];
let single_ct_plaintext_len = plaintexts_sum.len();
let plaintexts_2 = vec![ct_2.c_1, ct_2.c_2];
plaintexts_sum.extend(plaintexts_2);
println!("plaintexts_sum {}", plaintexts_sum.len());

/* BGG */
// initiate BGG encoding for (ct_1 || ct_2)
let key: [u8; 32] = rand::random();
let d = (2 * single_ct_plaintext_len) + 1;
let bgg_pubkey_sampler =
BGGPublicKeySampler::<_, DCRTPolyHashSampler<Keccak256>>::new(key, d);
let uniform_sampler = DCRTPolyUniformSampler::new();
let tag: u64 = rand::random();
let tag_bytes = tag.to_le_bytes();
let reveal_plaintexts = vec![true; d];
let pubkeys = bgg_pubkey_sampler.sample(&params_q, &tag_bytes, &reveal_plaintexts);
let secrets = vec![create_bit_random_poly(&params_q); d];
let bgg_encoding_sampler =
BGGEncodingSampler::new(&params_q, &secrets, uniform_sampler, 3.2);
let encodings = bgg_encoding_sampler.sample(&params_q, &pubkeys, &plaintexts_sum);
println!("encodings length {}", encodings.len());
/* Circuit */
let mut circuit = PolyCircuit::new();
let inputs = circuit.input(d - 1);
let output_c_1: usize = circuit.add_gate(inputs[0], inputs[2]);
let output_c_2: usize = circuit.add_gate(inputs[1], inputs[3]);
circuit.output(vec![output_c_1, output_c_2]);
let result = circuit.eval(&params_q, &encodings[0], &encodings[1..]);
println!("result length {}", result.len());
for r_i in result.clone() {
println!("result from circuit eval: {:?}", r_i.plaintext.unwrap().coeffs());
}

/* Expected */
// (ct_1 + ct_2)
let add_plaintext = vec![ct_add.c_1, ct_add.c_2];
println!("add_plaintext length {}", add_plaintext.len());
// sample BGG encoding for decomposition
let d = single_ct_plaintext_len + 1;
let uniform_sampler = DCRTPolyUniformSampler::new();
let bgg_pubkey_sampler =
BGGPublicKeySampler::<_, DCRTPolyHashSampler<Keccak256>>::new(key, d);
let reveal_plaintexts = vec![true; d];
let pubkeys = bgg_pubkey_sampler.sample(&params_q, &tag_bytes, &reveal_plaintexts);
let secrets = vec![create_bit_random_poly(&params_q); d];
let bgg_encoding_sampler =
BGGEncodingSampler::new(&params_q, &secrets, uniform_sampler, 3.2);
let expected_result = bgg_encoding_sampler.sample(&params_q, &pubkeys, &add_plaintext);
println!("expected_result length {}", expected_result.len());

for e_i in expected_result.clone() {
println!("result from ct_add {:?}", e_i.plaintext.unwrap().coeffs());
}
assert_eq!(result[0].plaintext, expected_result[1].plaintext);
assert_eq!(result[1].plaintext, expected_result[2].plaintext);
}

#[test]
fn test_bfv_bgg_mul() {
let params_t = DCRTPolyParams::new(4, 2, 17, 1);
let params_q = DCRTPolyParams::new(4, 4, 19, 2);

let (bfv, sk) = Bfv::keygen(params_t.clone(), params_q.clone(), 3.2);
let m_1 = create_random_poly(&params_t);
let ct_1 = bfv.encrypt_ske(m_1, sk.clone());
let m_2 = create_random_poly(&params_t);
let ct_2 = bfv.encrypt_ske(m_2, sk);
let ct_mul = bfv.mul(&ct_1, &ct_2);
let mut plaintexts_sum = vec![ct_1.c_1, ct_1.c_2];
let single_ct_plaintext_len = plaintexts_sum.len();
let plaintexts_2 = vec![ct_2.c_1, ct_2.c_2];
plaintexts_sum.extend(plaintexts_2);
println!("plaintexts_sum {}", plaintexts_sum.len());

/* BGG */
// initiate BGG encoding for (ct_1 || ct_2)
let key: [u8; 32] = rand::random();
let d = (2 * single_ct_plaintext_len) + 1;
let bgg_pubkey_sampler =
BGGPublicKeySampler::<_, DCRTPolyHashSampler<Keccak256>>::new(key, d);
let uniform_sampler = DCRTPolyUniformSampler::new();
let tag: u64 = rand::random();
let tag_bytes = tag.to_le_bytes();
let reveal_plaintexts = vec![true; d];
let pubkeys = bgg_pubkey_sampler.sample(&params_q, &tag_bytes, &reveal_plaintexts);
let secrets = vec![create_bit_random_poly(&params_q); d];
let bgg_encoding_sampler =
BGGEncodingSampler::new(&params_q, &secrets, uniform_sampler, 3.2);
let encodings = bgg_encoding_sampler.sample(&params_q, &pubkeys, &plaintexts_sum);
println!("encodings length {}", encodings.len());
/* Circuit */
let mut circuit = PolyCircuit::new();
let inputs = circuit.input(d - 1);
let output_c_1: usize = circuit.mul_gate(inputs[0], inputs[2]);
let output_c_2: usize = circuit.mul_gate(inputs[1], inputs[3]);
circuit.output(vec![output_c_1, output_c_2]);
let result = circuit.eval(&params_q, &encodings[0], &encodings[1..]);
println!("result length {}", result.len());
for r_i in result.clone() {
println!("result from circuit eval: {:?}", r_i.plaintext.unwrap().coeffs());
}

/* Expected */
// (ct_1 * ct_2)
let add_plaintext = vec![ct_mul.c_1, ct_mul.c_2];
println!("add_plaintext length {}", add_plaintext.len());
// sample BGG encoding for decomposition
let d = single_ct_plaintext_len + 1;
let uniform_sampler = DCRTPolyUniformSampler::new();
let bgg_pubkey_sampler =
BGGPublicKeySampler::<_, DCRTPolyHashSampler<Keccak256>>::new(key, d);
let reveal_plaintexts = vec![true; d];
let pubkeys = bgg_pubkey_sampler.sample(&params_q, &tag_bytes, &reveal_plaintexts);
let secrets = vec![create_bit_random_poly(&params_q); d];
let bgg_encoding_sampler =
BGGEncodingSampler::new(&params_q, &secrets, uniform_sampler, 3.2);
let expected_result = bgg_encoding_sampler.sample(&params_q, &pubkeys, &add_plaintext);
println!("expected_result length {}", expected_result.len());

for e_i in expected_result.clone() {
println!("result from ct_mul {:?}", e_i.plaintext.unwrap().coeffs());
}
//todo: err
// assert_eq!(result[0].plaintext, expected_result[1].plaintext);
// assert_eq!(result[1].plaintext, expected_result[2].plaintext);
}

#[test]
fn test_bfv_mul() {
let params_t = DCRTPolyParams::new(4, 2, 17, 1);
let params_q = DCRTPolyParams::new(4, 8, 51, 17);

let (bfv, sk) = Bfv::keygen(params_t.clone(), params_q.clone(), 0.0);

let m_a = create_random_poly(&params_t);
println!("m_a = {:?}", m_a.coeffs());
let m_b = create_random_poly(&params_t);
println!("m_b = {:?}", m_b.coeffs());
let m_prod = &m_a * &m_b;
println!("m_prod = {:?}", m_prod.coeffs());
let ct_a = bfv.encrypt_ske(m_a, sk.clone());
let ct_b = bfv.encrypt_ske(m_b, sk.clone());

/* Homomorphic */
let ct_prod = bfv.mul(&ct_a, &ct_b);

/* Decryption */
let dec = bfv.decrypt(ct_prod, sk);
println!("dec = {:?}", dec.coeffs());
// todo: error
// assert_eq!(m_prod, dec);
}
}
1 change: 1 addition & 0 deletions src/fhe/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod bfv;
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#![allow(clippy::too_many_arguments)]

pub mod bgg;
pub mod fhe;
pub mod io;
pub mod poly;
pub mod test_utils;
Expand Down
13 changes: 13 additions & 0 deletions src/poly/dcrt/element.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,19 @@ impl FinRingElem {
((&self.value * new_modulus.as_ref()) / self.modulus.as_ref()) % new_modulus.as_ref();
Self { value, modulus: self.modulus.clone() }
}

pub fn modulus_switch_round(&self, new_modulus: Arc<BigUint>) -> Self {
let q = self.modulus.as_ref();
let t = new_modulus.as_ref();

// scaled = (value * t + ⌊q/2⌋) / q
let mut scaled = &self.value * t;
scaled += q >> 1;
scaled /= q;
scaled %= t;

Self { value: scaled, modulus: new_modulus }
}
}

impl PolyElem for FinRingElem {
Expand Down
Loading
Loading