Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 197 additions & 16 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ mod py;

pub type Rank = u32;

// Every distinct 2-byte sequence (256 byte values × 256). Indexed by (byte_a << 8) | byte_b.
const PAIR_TABLE_SIZE: usize = 256 * 256;

use std::collections::BinaryHeap;

#[derive(Eq, PartialEq, Clone, Copy)]
Expand Down Expand Up @@ -44,7 +47,11 @@ struct State {
cur_rank: Rank,
}

fn _byte_pair_merge_large(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<Rank> {
fn _byte_pair_merge_large(
ranks: &HashMap<Vec<u8>, Rank>,
piece: &[u8],
pair_table: Option<&[Rank; PAIR_TABLE_SIZE]>,
) -> Vec<Rank> {
let mut state = Vec::with_capacity(piece.len());
state.push(State {
prev: usize::MAX,
Expand All @@ -56,7 +63,16 @@ fn _byte_pair_merge_large(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<R

let mut heap = BinaryHeap::with_capacity(piece.len());
for i in 0..piece.len() - 1 {
if let Some(&rank) = ranks.get(&piece[i..i + 2]) {
// When `pair_table` is Some, look up the 2-byte pair rank via flat array index instead
// of the hashmap. Same semantics: `Rank::MAX` in the table means "not in vocab" so skip.
let rank_opt = match pair_table {
Some(table) => {
let r = table[((piece[i] as u16) << 8 | piece[i + 1] as u16) as usize];
(r != Rank::MAX).then_some(r)
}
None => ranks.get(&piece[i..i + 2]).copied(),
};
if let Some(rank) = rank_opt {
heap.push(Merge { start: i, rank });
state[i].next_rank = rank;
}
Expand Down Expand Up @@ -137,17 +153,28 @@ fn _byte_pair_merge_large(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<R
result
}

fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
fn _byte_pair_merge(
ranks: &HashMap<Vec<u8>, Rank>,
piece: &[u8],
pair_table: Option<&[Rank; PAIR_TABLE_SIZE]>,
) -> Vec<(usize, Rank)> {
// This is a vector of (start, rank).
// The rank is of the pair starting at position start.
let mut parts = Vec::with_capacity(piece.len() + 1);

// Note that we hash bytes when indexing into `ranks`, not token pairs. As long as we train BPE
// the way we currently do, this is equivalent. An easy way to break this would be to decouple
// merge priority from token index or to prevent specific token merges.
//
// When `pair_table` is Some, the initial pair scan uses a flat `PAIR_TABLE_SIZE`-entry
// array indexed by the 2-byte pair instead of the hashmap. Subsequent merges (3+ byte
// sequences) always go through the hashmap, since 3+ byte keys don't fit in a u16.
let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX);
for i in 0..piece.len() - 1 {
let rank = *ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX);
let rank = match pair_table {
Some(table) => table[((piece[i] as u16) << 8 | piece[i + 1] as u16) as usize],
None => *ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX),
};
if rank < min_rank.0 {
min_rank = (rank, i);
}
Expand Down Expand Up @@ -202,17 +229,39 @@ pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Ran
return vec![ranks[piece]];
}
if piece_len < 100 {
return _byte_pair_merge(ranks, piece)
return _byte_pair_merge(ranks, piece, None)
.windows(2)
.map(|part| ranks[&piece[part[0].0..part[1].0]])
.collect();
}
_byte_pair_merge_large(ranks, piece, None)
}

/// Like [`byte_pair_encode`] but uses a precomputed 2-byte pair lookup table
/// for the initial pair scan. Used internally by `CoreBPE` methods to skip the
/// hashmap lookup for the hot initial scan.
fn byte_pair_encode_with_table(
piece: &[u8],
ranks: &HashMap<Vec<u8>, Rank>,
pair_table: &[Rank; PAIR_TABLE_SIZE],
) -> Vec<Rank> {
let piece_len = piece.len();

if piece_len == 1 {
return vec![ranks[piece]];
}
if piece_len < 100 {
return _byte_pair_merge(ranks, piece, Some(pair_table))
.windows(2)
.map(|part| ranks[&piece[part[0].0..part[1].0]])
.collect();
}
_byte_pair_merge_large(ranks, piece)
_byte_pair_merge_large(ranks, piece, Some(pair_table))
}

pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<&'a [u8]> {
assert!(piece.len() > 1);
_byte_pair_merge(ranks, piece)
_byte_pair_merge(ranks, piece, None)
.windows(2)
.map(|part| &piece[part[0].0..part[1].0])
.collect()
Expand Down Expand Up @@ -325,6 +374,10 @@ pub struct CoreBPE {
regex_tls: Vec<Regex>,
special_regex_tls: Vec<Regex>,
sorted_token_bytes: Vec<Vec<u8>>,
/// Precomputed 2-byte pair to rank lookup table (~256 KB, built once at
/// construction). Used by encoding methods to skip the hashmap lookup for
/// the hot initial adjacent-pair scan inside `_byte_pair_merge`.
pair_table: Box<[Rank; PAIR_TABLE_SIZE]>,
}

impl CoreBPE {
Expand Down Expand Up @@ -366,7 +419,11 @@ impl CoreBPE {
let piece = mat.unwrap().as_str().as_bytes();
match self.encoder.get(piece) {
Some(token) => ret.push(*token),
None => ret.extend(&byte_pair_encode(piece, &self.encoder)),
None => ret.extend(&byte_pair_encode_with_table(
piece,
&self.encoder,
&self.pair_table,
)),
}
}
ret
Expand Down Expand Up @@ -418,7 +475,7 @@ impl CoreBPE {
ret.push(*token);
continue;
}
let tokens = byte_pair_encode(piece, &self.encoder);
let tokens = byte_pair_encode_with_table(piece, &self.encoder, &self.pair_table);
last_piece_token_len = tokens.len();
ret.extend(&tokens);
}
Expand Down Expand Up @@ -550,11 +607,12 @@ impl CoreBPE {
// would be a regex split before the UTF-8 truncation point.
// Probably niche enough that no one will ever notice (after all, people didn't
// notice all the big holes in the previous unstable token implementation)
Err(_) => byte_pair_encode(&possibility, &self.encoder),
// Something like the following is intriguing but incorrect:
// Err(e) => self.encode_ordinary(unsafe {
// std::str::from_utf8_unchecked(&possibility[..e.valid_up_to()])
// }),
Err(_) => {
byte_pair_encode_with_table(&possibility, &self.encoder, &self.pair_table)
} // Something like the following is intriguing but incorrect:
// Err(e) => self.encode_ordinary(unsafe {
// std::str::from_utf8_unchecked(&possibility[..e.valid_up_to()])
// }),
};
let mut seq = Vec::new();
let mut seq_len = 0;
Expand Down Expand Up @@ -583,13 +641,15 @@ impl CoreBPE {
if unstable_bytes.len() - last_decoded.1 > 0
&& last_decoded.0.is_some_and(|c| c.is_whitespace())
{
let mut reencoded = byte_pair_encode(
let mut reencoded = byte_pair_encode_with_table(
&unstable_bytes[..unstable_bytes.len() - last_decoded.1],
&self.encoder,
&self.pair_table,
);
reencoded.extend(byte_pair_encode(
reencoded.extend(byte_pair_encode_with_table(
&unstable_bytes[unstable_bytes.len() - last_decoded.1..],
&self.encoder,
&self.pair_table,
));
completions.insert(reencoded);
}
Expand Down Expand Up @@ -649,6 +709,15 @@ impl CoreBPE {
let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
sorted_token_bytes.sort();

// Build the 2-byte pair lookup table (~256 KB, sub-millisecond one-time cost).
let mut pair_table: Box<[Rank; PAIR_TABLE_SIZE]> = Box::new([Rank::MAX; PAIR_TABLE_SIZE]);
for (key, &rank) in &encoder {
if key.len() == 2 {
let idx = ((key[0] as u16) << 8 | key[1] as u16) as usize;
pair_table[idx] = rank;
}
}

Ok(Self {
encoder,
special_tokens_encoder,
Expand All @@ -659,6 +728,7 @@ impl CoreBPE {
.map(|_| special_regex.clone())
.collect(),
sorted_token_bytes,
pair_table,
})
}

Expand Down Expand Up @@ -700,3 +770,114 @@ mod tests {
assert_eq!(res, vec![b"ab", b"ab"]);
}
}

/// Tests that the precomputed 2-byte pair lookup table produces output
/// byte-identical to the vanilla hashmap-lookup path. Covers both the
/// linear `_byte_pair_merge` (pieces < 100 bytes) and the heap-based
/// `_byte_pair_merge_large` (pieces >= 100 bytes) code paths.
#[cfg(test)]
mod pair_table_equivalence {
use super::*;

/// Build a small synthetic encoder: all ASCII a-z and 0-9 as single-byte
/// tokens, plus a handful of common 2-byte and 3-byte merges. Enough to
/// exercise real merge dynamics without depending on a downloaded vocab.
fn synthetic_encoder() -> HashMap<Vec<u8>, Rank> {
let mut encoder = HashMap::default();
let mut rank: Rank = 0;
for b in b'a'..=b'z' {
encoder.insert(vec![b], rank);
rank += 1;
}
for b in b'0'..=b'9' {
encoder.insert(vec![b], rank);
rank += 1;
}
for pair in ["th", "he", "in", "er", "an", "re", "on", "at", "es", "or"] {
encoder.insert(pair.as_bytes().to_vec(), rank);
rank += 1;
}
for triple in ["the", "and", "ing", "ion", "for"] {
encoder.insert(triple.as_bytes().to_vec(), rank);
rank += 1;
}
encoder
}

fn build_pair_table(encoder: &HashMap<Vec<u8>, Rank>) -> Box<[Rank; PAIR_TABLE_SIZE]> {
let mut pair_table: Box<[Rank; PAIR_TABLE_SIZE]> = Box::new([Rank::MAX; PAIR_TABLE_SIZE]);
for (key, &rank) in encoder {
if key.len() == 2 {
let idx = ((key[0] as u16) << 8 | key[1] as u16) as usize;
pair_table[idx] = rank;
}
}
pair_table
}

fn check_equivalence(piece: &[u8]) {
let encoder = synthetic_encoder();
let pair_table = build_pair_table(&encoder);
let vanilla = byte_pair_encode(piece, &encoder);
let patched = byte_pair_encode_with_table(piece, &encoder, &pair_table);
assert_eq!(
vanilla,
patched,
"vanilla vs pair-table diverged on piece of length {}",
piece.len(),
);
}

/// Generate an alphabetic piece of the requested length by cycling a..z.
fn alpha_piece(n: usize) -> Vec<u8> {
(0..n).map(|i| b'a' + ((i % 26) as u8)).collect()
}

#[test]
fn equivalence_length_1_direct_lookup() {
// Single byte takes the early-return path; pair table never consulted.
check_equivalence(b"a");
check_equivalence(b"z");
check_equivalence(b"0");
}

#[test]
fn equivalence_short_pieces_linear_path() {
// Pieces < 100 bytes go through `_byte_pair_merge` (linear scan).
check_equivalence(b"th");
check_equivalence(b"the");
check_equivalence(b"that");
check_equivalence(b"information");
check_equivalence(b"theandingionfor");
}

#[test]
fn equivalence_just_under_100b_cutoff() {
check_equivalence(&alpha_piece(50));
check_equivalence(&alpha_piece(98));
check_equivalence(&alpha_piece(99));
}

#[test]
fn equivalence_at_100b_cutoff() {
// 100 bytes is the boundary: dispatches to `_byte_pair_merge_large`.
check_equivalence(&alpha_piece(100));
check_equivalence(&alpha_piece(101));
}

#[test]
fn equivalence_long_pieces_heap_path() {
// Well into the heap-based path.
check_equivalence(&alpha_piece(200));
check_equivalence(&alpha_piece(500));
check_equivalence(&alpha_piece(1000));
}

#[test]
fn equivalence_repeated_pairs() {
// Many identical 2-byte pairs to stress the initial-scan path.
check_equivalence(&[b'x'; 5]);
check_equivalence(&[b'x'; 99]);
check_equivalence(&[b'x'; 150]);
}
}