Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/sframe/map.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,14 @@ class map : private vector<std::optional<std::pair<K, V>>, N>
return pos->value().second;
}

void erase(const K& key)
{
auto pos = find(key);
if (pos != this->end()) {
pos->reset();
}
}

template<typename F>
void erase_if_key(F&& f)
{
Expand Down
4 changes: 4 additions & 0 deletions include/sframe/sframe.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ class Context
virtual ~Context();

Result<void> add_key(KeyID kid, KeyUsage usage, input_bytes key);
void remove_key(KeyID kid);

Result<output_bytes> protect(KeyID key_id,
output_bytes ciphertext,
Expand All @@ -134,6 +135,8 @@ class Context
CipherSuite suite;
map<KeyID, KeyRecord, SFRAME_MAX_KEYS> keys;

Result<void> require_key(KeyID key_id) const;

Result<output_bytes> protect_inner(const Header& header,
output_bytes ciphertext,
input_bytes plaintext,
Expand All @@ -160,6 +163,7 @@ class MLSContext : protected Context
Result<void> add_epoch(EpochID epoch_id,
input_bytes sframe_epoch_secret,
size_t sender_bits);
void remove_epoch(EpochID epoch_id);
void purge_before(EpochID keeper);

Result<output_bytes> protect(EpochID epoch_id,
Expand Down
30 changes: 30 additions & 0 deletions src/sframe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ Context::Context(CipherSuite suite_in)

Context::~Context() = default;

void
Context::remove_key(KeyID key_id)
{
keys.erase(key_id);
}

Result<void>
Context::add_key(KeyID key_id, KeyUsage usage, input_bytes base_key)
{
Expand Down Expand Up @@ -118,12 +124,23 @@ form_aad(const Header& header, input_bytes metadata)
return aad;
}

Result<void>
Context::require_key(KeyID key_id) const
{
if (!keys.contains(key_id)) {
return SFrameError(SFrameErrorType::invalid_parameter_error,
"Unknown key ID");
}
return Result<void>::ok();
}

Result<output_bytes>
Context::protect(KeyID key_id,
output_bytes ciphertext,
input_bytes plaintext,
input_bytes metadata)
{
SFRAME_VOID_OR_RETURN(require_key(key_id));
auto& key_record = keys.at(key_id);
const auto counter = key_record.counter;
key_record.counter += 1;
Expand Down Expand Up @@ -166,6 +183,7 @@ Context::protect_inner(const Header& header,
"Ciphertext too small for cipher overhead");
}

SFRAME_VOID_OR_RETURN(require_key(header.key_id));
const auto& key_and_salt = keys.at(header.key_id);

SFRAME_VALUE_OR_RETURN(aad, form_aad(header, metadata));
Expand All @@ -190,6 +208,7 @@ Context::unprotect_inner(const Header& header,
"Plaintext too small for decrypted value");
}

SFRAME_VOID_OR_RETURN(require_key(header.key_id));
const auto& key_and_salt = keys.at(header.key_id);

SFRAME_VALUE_OR_RETURN(aad, form_aad(header, metadata));
Expand Down Expand Up @@ -326,6 +345,17 @@ MLSContext::EpochKeys::base_key(CipherSuite ciphersuite,
ciphersuite, sframe_epoch_secret, enc_sender_id, hash_size);
}

void
MLSContext::remove_epoch(EpochID epoch_id)
{
purge_epoch(epoch_id);

const auto idx = epoch_id & epoch_mask;
if (idx < epoch_cache.size()) {
epoch_cache[idx].reset();
}
}

void
MLSContext::purge_epoch(EpochID epoch_id)
{
Expand Down
113 changes: 113 additions & 0 deletions test/sframe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,3 +231,116 @@ TEST_CASE("MLS Failure after Purge")
const auto dec_ab_2 = member_b.unprotect(pt_out, enc_ab_2, metadata).unwrap();
CHECK(plaintext == to_bytes(dec_ab_2));
}

TEST_CASE("SFrame Context Remove Key")
{
const auto suite = CipherSuite::AES_GCM_128_SHA256;
const auto kid = KeyID(0x07);
const auto key = from_hex("000102030405060708090a0b0c0d0e0f");
const auto plaintext = from_hex("00010203");
const auto metadata = bytes{};

auto pt_out = bytes(plaintext.size());
auto ct_out = bytes(plaintext.size() + Context::max_overhead);

auto sender = Context(suite);
auto receiver = Context(suite);
sender.add_key(kid, KeyUsage::protect, key).unwrap();
receiver.add_key(kid, KeyUsage::unprotect, key).unwrap();

// Protect and unprotect succeed before removal
auto encrypted =
to_bytes(sender.protect(kid, ct_out, plaintext, metadata).unwrap());
auto decrypted =
to_bytes(receiver.unprotect(pt_out, encrypted, metadata).unwrap());
CHECK(decrypted == plaintext);

// Remove sender key and verify protect fails
sender.remove_key(kid);
CHECK(sender.protect(kid, ct_out, plaintext, metadata).error().type() ==
SFrameErrorType::invalid_parameter_error);

// Remove receiver key and verify unprotect fails
receiver.remove_key(kid);
CHECK(receiver.unprotect(pt_out, encrypted, metadata).error().type() ==
SFrameErrorType::invalid_parameter_error);

// Re-add keys and verify round-trip works again
sender.add_key(kid, KeyUsage::protect, key).unwrap();
receiver.add_key(kid, KeyUsage::unprotect, key).unwrap();

encrypted =
to_bytes(sender.protect(kid, ct_out, plaintext, metadata).unwrap());
decrypted =
to_bytes(receiver.unprotect(pt_out, encrypted, metadata).unwrap());
CHECK(decrypted == plaintext);
}

TEST_CASE("SFrame Context Remove Key - Nonexistent Key")
{
const auto suite = CipherSuite::AES_GCM_128_SHA256;

auto ctx = Context(suite);

// Removing a key that was never added should not throw
CHECK_NOTHROW(ctx.remove_key(KeyID(0x99)));
}

TEST_CASE("MLS Remove Epoch")
{
const auto suite = CipherSuite::AES_GCM_128_SHA256;
const auto epoch_bits = 2;
const auto metadata = from_hex("00010203");
const auto plaintext = from_hex("04050607");
const auto sender_id = MLSContext::SenderID(0xA0A0A0A0);
const auto sframe_epoch_secret_1 = bytes(32, 1);
const auto sframe_epoch_secret_2 = bytes(32, 2);

auto pt_out = bytes(plaintext.size());
auto ct_out = bytes(plaintext.size() + Context::max_overhead);

auto member_a = MLSContext(suite, epoch_bits);
auto member_b = MLSContext(suite, epoch_bits);

// Install epoch 1 and verify round-trip
const auto epoch_id_1 = MLSContext::EpochID(1);
member_a.add_epoch(epoch_id_1, sframe_epoch_secret_1);
member_b.add_epoch(epoch_id_1, sframe_epoch_secret_1);

auto enc =
member_a.protect(epoch_id_1, sender_id, ct_out, plaintext, metadata)
.unwrap();
auto enc_data = to_bytes(enc);
auto dec = to_bytes(member_b.unprotect(pt_out, enc_data, metadata).unwrap());
CHECK(plaintext == dec);

// Install epoch 2
const auto epoch_id_2 = MLSContext::EpochID(2);
member_a.add_epoch(epoch_id_2, sframe_epoch_secret_2);
member_b.add_epoch(epoch_id_2, sframe_epoch_secret_2);

// Remove only epoch 1 (not purge_before) and verify it fails
member_a.remove_epoch(epoch_id_1);
member_b.remove_epoch(epoch_id_1);

CHECK(member_a.protect(epoch_id_1, sender_id, ct_out, plaintext, metadata)
.error()
.type() == SFrameErrorType::invalid_parameter_error);
CHECK(member_b.unprotect(pt_out, enc_data, metadata).error().type() ==
SFrameErrorType::invalid_parameter_error);

// Epoch 2 should still work
enc = member_a.protect(epoch_id_2, sender_id, ct_out, plaintext, metadata)
.unwrap();
dec = to_bytes(member_b.unprotect(pt_out, enc, metadata).unwrap());
CHECK(plaintext == dec);

// Re-add epoch 1 with the same secret and verify it works again
member_a.add_epoch(epoch_id_1, sframe_epoch_secret_1);
member_b.add_epoch(epoch_id_1, sframe_epoch_secret_1);

enc = member_a.protect(epoch_id_1, sender_id, ct_out, plaintext, metadata)
.unwrap();
dec = to_bytes(member_b.unprotect(pt_out, enc, metadata).unwrap());
CHECK(plaintext == dec);
}
Loading