Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
617 changes: 452 additions & 165 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ members = [
]

[profile.release]
lto = true
debug = true
lto = false
debug = false
62 changes: 18 additions & 44 deletions benchmark_test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,29 +34,23 @@ def gen_data_single_model(t_dtype, dim, seed):
sqrt_pi = sqrt_pi[0:1]
return S, sqrt_pi

def helper_test(dtype, dim : int, gradients: bool, single_model: bool = False, gpu: bool = False):
if dtype == "f32":
t_dtype = torch.float32
np_dtype = np.float32
else:
t_dtype = torch.float64
np_dtype = np.float64
def helper_test(dim : int, gradients: bool, single_model: bool = False):
t_dtype = torch.float64
np_dtype = np.float64

if dtype == "f32":
rtol = 1e-2
atol = 1e-2
else:
rtol = 1e-4
atol = 1e-4
rtol = 1e-4
atol = 1e-4

torch_tree, parent_list, branch_lengths, leaf_log_p = gen_tree(t_dtype, dim)

leaf_pl = torch.exp(leaf_log_p)

if single_model:
data = [gen_data_single_model(t_dtype, dim, seed) for seed in range(10)]
else:
data = [gen_data(t_dtype, dim, seed) for seed in range(10)]

rust_tree = phylo_grad.FelsensteinTree(parent_list, branch_lengths.astype(np_dtype), leaf_log_p.numpy(), 1e-4, gpu)
rust_tree = phylo_grad.FelsensteinTree(parent_list, branch_lengths.astype(np_dtype), leaf_pl.numpy(), 1e-4)


for i in range(10):
Expand All @@ -65,10 +59,7 @@ def helper_test(dtype, dim : int, gradients: bool, single_model: bool = False, g


result = rust_tree.calculate_gradients(S.numpy(), sqrt_pi.numpy())
if gpu == False:
likelihoods = rust_tree.calculate_log_likelihoods(S.numpy(), sqrt_pi.numpy())
assert(np.allclose(likelihoods, result['log_likelihood'], rtol=1e-5))


assert(np.allclose(result['log_likelihood'], torch_logP.numpy(), rtol=rtol))

if gradients:
Expand All @@ -93,31 +84,14 @@ def helper_test(dtype, dim : int, gradients: bool, single_model: bool = False, g
assert(np.allclose(result['grad_s'], torch_S_grad, rtol=rtol, atol=atol))

def test_likelihood():
helper_test("f32", 4, False)
helper_test("f32", 20, False)
helper_test("f64", 4, False)
helper_test("f64", 20, False)

def test_gpu_likelihood():
helper_test("f64", 4, False, gpu = True)
helper_test("f64", 20, False, gpu= True)
helper_test("f32", 4, False, gpu= True)
helper_test("f32", 20, False, gpu= True)
helper_test(4, False, True)
helper_test(4, False, False)
helper_test(20, False, True)
helper_test(20, False, False)

def test_grads():
helper_test("f32", 4, True)
helper_test("f32", 20, True)
helper_test("f64", 4, True)
helper_test("f64", 20, True)

def test_grads_single_model():
helper_test("f64", 4, True, single_model=True)
helper_test("f64", 20, True, single_model=True)
helper_test("f32", 4, True, single_model=True)
helper_test("f32", 20, True, single_model=True)

def test_gpu_grads():
helper_test("f64", 4, True, gpu = True)
helper_test("f64", 20, True, gpu= True)
helper_test("f32", 4, True, gpu= True)
helper_test("f32", 20, True, gpu= True)
helper_test(4, True, True)
helper_test(4, True, False)
helper_test(20, True, True)
helper_test(20, True, False)

2 changes: 1 addition & 1 deletion benchmark_test/test_tree.newick
Original file line number Diff line number Diff line change
@@ -1 +1 @@
((((Tip_21:0.6650676023435903,Tip_22:0.024885436124515284):0.2056669529961157,((Tip_37:0.0337728838700196,Tip_38:0.9484958108725318):0.7273031959389747,Tip_36:0.4958302887613361):0.33757637305798605):0.8499954308656537,(Tip_4:0.8141476837951772,Tip_5:0.7402798051460897):0.70741682436221):0.813355,((((((Tip_83:0.17643378492960254,Tip_84:0.7953558100843726):0.14557382727056703,(Tip_43:0.8104945919986606,Tip_44:0.31449340792858616):0.5734673065270548):0.5607844652217767,Tip_18:0.17900274566087643):0.5063065847624981,(((Tip_32:0.6491605888497303,Tip_33:0.3429582140411475):0.14180550845435472,Tip_28:0.08365819037239011):0.5365343800267653,Tip_24:0.8800154043806051):0.998762888840805):0.8304610217387793,(Tip_0:0.7039506434317381,(Tip_25:0.7331478876918618,(Tip_27:0.9441692227658312,((Tip_57:0.38896701363191544,Tip_58:0.13379392008599617):0.8754816391828532,(Tip_60:0.352209387501356,Tip_61:0.3956010586841662):0.4583533711457218):0.5460177635944152):0.6361110453676238):0.6602589686792034):0.027727204798649534):0.6704649990971151,(((Tip_3:0.21104851572654998,(((((((((Tip_73:0.9483105414473861,Tip_74:0.6626788692704468):0.7822936616109846,Tip_59:0.12249378702437955):0.08222970026113245,(Tip_96:0.4290800932938288,Tip_97:0.24924474555329348):0.9179999611656243):0.7548297895532528,(Tip_98:0.9167802235932272,Tip_99:0.7204509004585486):0.8287579648322002):0.05271172049884924,((Tip_81:0.5413926753801398,Tip_82:0.4118750056383439):0.9486624612083683,Tip_39:0.29899889747606284):0.055068995802558016):0.2501971861428551,((Tip_86:0.0380499276137973,Tip_87:0.4256548943232936):0.07123509097099835,((((Tip_55:0.5710466520679875,Tip_56:0.14775864025032526):0.32953437140242836,Tip_47:0.7864549772499915):0.4280637902774943,Tip_45:0.13776127824637646):0.16954936978362714,Tip_23:0.17176851427365566):0.18988101382133601):0.9138259583612118):0.09048036473551212,(((Tip_92:0.404429134512315,Tip_93:0.047246278915062306):0.06224988908172349,((Tip_75:0.6566656102220733,Tip_76:0.6381056061984103):0.29787087139597435,Tip_40:0.5279535103620298):0.8286398644888646):0.19446009582787896,(Tip_34:0.26670755857819356,(Tip_69:0.10832233286814032,Tip_70:0.9365264322108605):0.41936902078708815):0.27879362246002265):0.5567981824116087):0.2694506279899209,Tip_6:0.9312226738371098):0.9215301310000562,(Tip_10:0.4514212322743643,Tip_11:0.18865283598284918):0.8275734698074786):0.23518611049075663):0.992212745439618,((Tip_19:0.8295636279502516,Tip_20:0.34263641161658875):0.9114454035382507,((((Tip_51:0.9220107008370401,Tip_52:0.8176402629230148):0.006951888793413944,(((Tip_94:0.3251911236094175,Tip_95:0.9900649501189612):0.2558363577627372,Tip_85:0.5404121010340436):0.9010280119121413,Tip_29:0.7328316282669919):0.9489623456648526):0.669567713578707,Tip_26:0.8016779879698788):0.6729191308869212,Tip_17:0.29215200135456015):0.31179956452592034):0.2543242891578164):0.5584080530741793,(Tip_66:0.8947516564428961,Tip_67:0.5767430884730821):0.2489036977476457):0.9656347634788244):0.239977,((((Tip_2:0.9464313983922135,(Tip_88:0.41442648404621196,Tip_89:0.5101163097014396):0.5466571358298574):0.6334947392181386,((((((Tip_12:0.11158676296777406,Tip_13:0.7279909147393833):0.2242187756664422,(Tip_41:0.7690871167673865,Tip_42:0.7384233696004908):0.9385579338831107):0.5705174119616021,(Tip_7:0.7360096495183802,Tip_8:0.25431554405864565):0.5988161405333698):0.9687384235590413,((Tip_14:0.49387953934218143,(Tip_62:0.9806096658151946,Tip_63:0.04665946024200317):0.21648740028859653):0.12806441473128458,(((Tip_71:0.4820889062355187,Tip_72:0.669077938124829):0.6907792707402931,Tip_46:0.7915022800264976):0.3224651326307264,Tip_1:0.37807097934208633):0.08262352954341846):0.5415349680132456):0.5368870623401885,(((Tip_9:0.30112516594260985,(Tip_30:0.8010675030911405,Tip_31:0.2197855899426761):0.2979712085110747):0.20561375454872075,(Tip_79:0.322022616870698,Tip_80:0.8405404258659477):0.3736763838229683):0.15489065294790658,((Tip_77:0.1631714880198354,Tip_78:0.14347186085240984):0.6350189977427474,((((Tip_90:0.8211951990555387,Tip_91:0.30263563690395334):0.6502375388726693,Tip_68:0.7812120071752737):0.36830312938864523,Tip_48:0.7430015399338462):0.17280797268648893,Tip_35:0.4635593017282353):0.7149095458288683):0.5852674316858171):0.8996082718667177):0.17732798167953295,((Tip_49:0.33471854305440063,Tip_50:0.11184498641576858):0.8048081839696574,(Tip_53:0.6764255828447392,Tip_54:0.11645530351988755):0.47997626175532904):0.7911201160390272):0.806790252892112):0.43691677945398477,(Tip_64:0.6545202475677102,Tip_65:0.8679242823464907):0.6299303286521293):0.7947232572644647,(Tip_15:0.7247905356618102,Tip_16:0.7397524628980965):0.10514061923847424):1.13067);
(((Tip_3:0.10805328350042452,Tip_4:0.3315005520372256):0.7775108391547321,(((Tip_10:0.9543578556535617,Tip_11:0.46520848427708783):0.19055443061683935,Tip_5:0.8608787164185184):0.2282830452051495,(Tip_8:0.45877603959910257,Tip_9:0.2188740043797414):0.977427116635245):0.8761171479386695):1.02879,((((Tip_14:0.3495434398130265,Tip_15:0.7072552650535753):0.44843039151645087,(Tip_6:0.8903069753526421,Tip_7:0.11856013796286433):0.6507572690541433):0.7600580845054129,Tip_1:0.2230762517502976):0.04289178403051824,Tip_0:0.7254337418961002):0.143964,((Tip_12:0.535570065080058,Tip_13:0.9288624624322107):0.771915926537751,Tip_2:0.9590719826516791):0.939461);
3 changes: 2 additions & 1 deletion gtr_optimize/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ phylotree = "0.1.3"
lazy_static = "1.5.0"
gosh-lbfgs = "0.1.0"
seq_io = "0.3.4"
nalgebra = "0.33.2"
nalgebra = "0.34.1"
num-traits = {version="0.2.19", default-features=false, features = ["libm"]}
rayon = "1.11.0"
logsumexp = "0.1.0"
45 changes: 28 additions & 17 deletions gtr_optimize/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
//! It has a global matrix mode a per column matrix mode, caled "local" here.

use lazy_static::lazy_static;
use logsumexp::LogSumExp;
use nalgebra as na;
use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};
use std::collections::HashMap;

use phylo_grad::{FelsensteinTree, FloatTrait};
use phylo_grad::FelsensteinTree;

lazy_static! {
static ref AMINO_MAPPING: HashMap<u8, u8> = {
Expand Down Expand Up @@ -39,16 +40,15 @@ lazy_static! {
};
}

pub fn seq2pll<F: FloatTrait>(seq: impl Iterator<Item = u8>) -> Vec<na::SVector<F, 20>> {
pub fn seq2pll(seq: impl Iterator<Item = u8>) -> Vec<na::SVector<f64, 20>> {
seq.map(|c| *AMINO_MAPPING.get(&c).unwrap_or(&20))
.map(|idx| {
let mut v =
na::SVector::<F, 20>::from_element(<F as FloatTrait>::from_f64(f64::NEG_INFINITY));
let mut v = na::SVector::<f64, 20>::from_element(0.0);
if idx < 20 {
v[idx as usize] = F::zero();
v[idx as usize] = 1.0;
} else {
for i in 0..20 {
v[i] = F::zero();
v[i] = 1.0;
}
}
v
Expand All @@ -60,7 +60,7 @@ pub fn process_newick_alignment(
newick: &str,
sequences: &HashMap<String, Vec<u8>>,
) -> (
phylo_grad::FelsensteinTree<f64, 20>,
phylo_grad::FelsensteinTree<20>,
Vec<Vec<na::SVector<f64, 20>>>,
) {
let tree = phylotree::tree::Tree::from_newick(newick).unwrap();
Expand Down Expand Up @@ -112,12 +112,23 @@ pub fn process_newick_alignment(
column_seq[new_idx as usize] = seq[i];
}
}
leaf_pll.push(seq2pll::<f64>(column_seq.into_iter()));
leaf_pll.push(seq2pll(column_seq.into_iter()));
}

let felsenstein = FelsensteinTree::<f64, 20>::new(&parents, &distances);
let felsenstein = FelsensteinTree::<20>::new(&parents, &distances);
(felsenstein, leaf_pll)
}
/// Numerical stable softmax
pub fn softmax<const N: usize>(x: &na::SVector<f64, N>) -> na::SVector<f64, N> {
let x_max = x.max();

let result = x.add_scalar(-x_max);

let mut result = result.map(|x| num_traits::Float::exp(x));

result /= result.sum();
result
}

pub fn optimize_gtr_local(newick: &str, sequences: &HashMap<String, Vec<u8>>) -> f64 {
let (felsenstein, mut leaf_pll) = process_newick_alignment(newick, sequences);
Expand All @@ -135,7 +146,7 @@ pub fn optimize_gtr_local(newick: &str, sequences: &HashMap<String, Vec<u8>>) ->
}

fn optimize_gtr_single_side(
felsenstein: &FelsensteinTree<f64, 20>,
felsenstein: &FelsensteinTree<20>,
log_pi_init: &[f64],
log_p: &mut [na::SVector<f64, 20>],
) -> f64 {
Expand Down Expand Up @@ -198,7 +209,7 @@ fn optimize_gtr_single_side(

pub fn optimize_gtr_global(newick: &str, sequences: &HashMap<String, Vec<u8>>) -> f64 {
let (mut felsenstein, leaf_pll) = process_newick_alignment(newick, sequences);
felsenstein.bind_leaf_log_p(leaf_pll);
felsenstein.bind_leaf_pl(leaf_pll);

let evaluate = |x: &[f64], g: &mut [f64]| {
let log_R = &x[..190];
Expand Down Expand Up @@ -282,7 +293,7 @@ fn rate_matrix_backward(
let d_logM = -d_logS.sum();

let piRpi_max = data.piRpi.max();
let piRpi_exp: na::SMatrix<f64, 20, 20> = data.piRpi.map(|x| (x - piRpi_max).scalar_exp());
let piRpi_exp: na::SMatrix<f64, 20, 20> = data.piRpi.map(|x| (x - piRpi_max).exp());
let piRpi_sum = piRpi_exp.sum();

let d_piRpi = piRpi_exp.map(|x| x * d_logM / piRpi_sum);
Expand All @@ -304,7 +315,7 @@ fn rate_matrix_backward(
let d_log_pi =
d_log_pi + sqrt_pi_cotangent.component_mul(&data.log_pi.map(|x| 0.5 * (0.5 * x).exp()));

let softmax_log_pi_unorm = phylo_grad::softmax(&data.log_pi_unormalized);
let softmax_log_pi_unorm = softmax(&data.log_pi_unormalized);

for i in 0..20 {
grad_log_pi[i] = d_log_pi[i] - softmax_log_pi_unorm[i] * d_log_pi.sum();
Expand All @@ -316,7 +327,7 @@ fn rate_matrix_backward(
fn rate_matrix(log_R: &[f64], log_pi_unormalized: &[f64]) -> RateMatrixData {
let log_pi_unormalized: na::SVector<f64, 20> =
na::SVector::<f64, 20>::from_iterator(log_pi_unormalized.iter().copied());
let log_pi = log_pi_unormalized.add_scalar(-FloatTrait::logsumexp(log_pi_unormalized.iter()));
let log_pi = log_pi_unormalized.add_scalar(-log_pi_unormalized.iter().ln_sum_exp());

let log_R_mat = {
let mut mat = na::SMatrix::<f64, 20, 20>::zeros();
Expand Down Expand Up @@ -345,7 +356,7 @@ fn rate_matrix(log_R: &[f64], log_pi_unormalized: &[f64]) -> RateMatrixData {
mat
};

let logM = FloatTrait::logsumexp(piRpi.iter());
let logM = piRpi.iter().ln_sum_exp();

let logS = {
let mut mat = na::SMatrix::<f64, 20, 20>::zeros();
Expand All @@ -364,7 +375,7 @@ fn rate_matrix(log_R: &[f64], log_pi_unormalized: &[f64]) -> RateMatrixData {
let sqrt_pi = {
let mut v = na::SVector::<f64, 20>::zeros();
for i in 0..20 {
v[i] = (log_pi[i] * 0.5).scalar_exp();
v[i] = (log_pi[i] * 0.5).exp();
}
v
};
Expand All @@ -373,7 +384,7 @@ fn rate_matrix(log_R: &[f64], log_pi_unormalized: &[f64]) -> RateMatrixData {
let mut mat = na::SMatrix::<f64, 20, 20>::zeros();
for i in 0..20 {
for j in 0..20 {
mat[(i, j)] = logS[(i, j)].scalar_exp();
mat[(i, j)] = logS[(i, j)].exp();
}
}
mat
Expand Down
13 changes: 5 additions & 8 deletions phylo_grad/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
[package]
name = "phylo_grad"
version = "1.1.0"
edition = "2021"
version = "2.0.0"
edition = "2024"
authors = ["Benjamin Lieser <benjamin.lieser@mpinat.mpg.de>", "Jora Belousov <jorabelousov@gmail.com>"]
description = "Fast gradient calculation of the Felsenstein algorithm with respect to the rate matrix"
repository = "https://github.com/soedinglab/phylo_grad"
license = "MIT OR Apache-2.0"
keywords = ["math", "felsenstein", "machine-learning", "bioinformatics"]
categories = ["mathematics", "science::bioinformatics"]
rust-version = "1.85"

[dependencies]
logsumexp = "0.1.0"
nalgebra = "0.33.2"
num-traits = {version="0.2.19", default-features=false, features = ["libm"]}
nalgebra = {version ="0.34.1", features = []}
glam = {version="0.30.10", features = ["fast-math"]}
rayon = "1.10.0"
itertools = "0.13.0"
libm = "0.2.10"
sleef = "0.3.2"
faer = {version="0.23.2", default-features=false}
15 changes: 15 additions & 0 deletions phylo_grad/src/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Changelog
All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [2.0.0] — 2026-

### Changes

- The input to phylo_grad are now the probabilties at the leaves and not the log probabilties anymore!
- The implementation keeps partial likelihoods in linspace internally. It rescales if nessesacry to avoid under and overflows. This leads to a factor 2 speedup, more for the single model case. The output is still a log probability (natural log)
- We dropped support for f32
- The parallization of the single model case has been changed to be dramatically more memory efficient.
- The project does not relay on a nighlty compiler anymore.
Loading
Loading