JAX API Reference

The JAX implementation provides GPU-accelerated, differentiable phylogenetic computations in float64 precision. All operations support batched models via optional leading *H batch dimensions.

Types

Tree

from subby.jax.types import Tree

tree = Tree(
    parentIndex=jnp.array([-1, 0, 0, 1, 1], dtype=jnp.int32),
    distanceToParent=jnp.array([0.0, 0.1, 0.2, 0.15, 0.25]),
)
Field Type Shape Description
parentIndex int32 (R,) Preorder parent indices; parentIndex[0] = -1
distanceToParent float (R,) Branch lengths

DiagModel

from subby.jax.types import DiagModel

model = DiagModel(
    eigenvalues=jnp.array([0.0, -1.333, -1.333, -1.333]),
    eigenvectors=V,  # (A, A) orthogonal matrix
    pi=jnp.array([0.25, 0.25, 0.25, 0.25]),
)
Field Type Shape Description
eigenvalues float (*H, A) Eigenvalues of symmetrized rate matrix
eigenvectors float (*H, A, A) v[a,k] = component $a$ of eigenvector $k$
pi float (*H, A) Equilibrium distribution

IrrevDiagModel

For irreversible rate matrices (non-symmetric, complex eigendecomposition).

Field Type Shape Description
eigenvalues complex128 (*H, A) Complex eigenvalues of rate matrix
eigenvectors complex128 (*H, A, A) Right eigenvectors $V$
eigenvectors_inv complex128 (*H, A, A) $V^{-1}$
pi float (*H, A) Stationary distribution
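
The fields mirror a plain complex eigendecomposition of a non-symmetric rate matrix. A minimal numpy sketch, independent of the library (the 3-state cyclic rate matrix here is made up for illustration):

```python
import numpy as np

# Hypothetical irreversible rate matrix (cyclic drift 0 -> 1 -> 2 -> 0).
Q = np.array([[-1.0, 0.9, 0.1],
              [0.1, -1.0, 0.9],
              [0.9, 0.1, -1.0]])

# Non-symmetric Q generally requires a complex eigendecomposition
# Q = V diag(mu) V^{-1}.
mu, V = np.linalg.eig(Q.astype(np.complex128))
V_inv = np.linalg.inv(V)

# exp(Q t) = V diag(exp(mu t)) V^{-1} is real up to rounding,
# and its rows sum to 1 (a stochastic matrix).
t = 0.5
M = (V * np.exp(mu * t)) @ V_inv
assert np.allclose(M.imag, 0, atol=1e-10)
assert np.allclose(M.real.sum(axis=1), 1.0)
```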

RateModel

from subby.jax.types import RateModel
Field Type Shape Description
subRate float (*H, A, A) Rate matrix $Q$
rootProb float (*H, A) Equilibrium distribution

Automatically diagonalized when passed to high-level functions. Reversibility is auto-detected.
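
One plausible way such a reversibility check can work is detailed balance, $\pi_i Q_{ij} = \pi_j Q_{ji}$ — a sketch, not necessarily the library's exact test:

```python
import numpy as np

def is_reversible(Q, pi, tol=1e-10):
    """Detailed balance: pi_i * Q_ij == pi_j * Q_ji for all i, j.
    Illustrative; the library's actual detection may differ."""
    flux = pi[:, None] * Q
    return np.allclose(flux, flux.T, atol=tol)

# Jukes-Cantor (reversible) vs. a cyclic chain (irreversible).
Q_rev = np.full((4, 4), 1.0)
np.fill_diagonal(Q_rev, -3.0)
Q_irrev = np.array([[-1.0, 1.0, 0.0],
                    [0.0, -1.0, 1.0],
                    [1.0, 0.0, -1.0]])
assert is_reversible(Q_rev, np.full(4, 0.25))
assert not is_reversible(Q_irrev, np.full(3, 1 / 3))
```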


High-level API

LogLike(alignment, tree, model, maxChunkSize=128)

Compute per-column log-likelihoods via Felsenstein pruning.

Parameters:

Name Type Description
alignment int32 (R, C) Token-encoded alignment
tree Tree Phylogenetic tree
model DiagModel, RateModel, list, or grid Substitution model (see below)
maxChunkSize int Column chunk size for memory control

Returns: (*H, C) float array of log-likelihoods.

Model parameter forms:

Form Description
Single model Same model for all columns and branches
[m_0, m_1, ..., m_{C-1}] Per-column: each column uses its own model
[[m_0, m_1, ..., m_{R-1}]] (1, R) per-row: each branch uses its own model, broadcast to all columns
[[...], [...], ...] (C lists of R) (C, R) grid: different model at every column and branch

The (1, R) form broadcasts the same per-branch configuration to all columns. The (C, R) form allows fully independent models at every (column, branch) pair.

When model is a list of C models, each column uses its own substitution model (per-column substitution matrices). This enables position-specific rates, e.g., from a neural network predicting rates per site.

Example — CNN-predicted per-column rates:

A 1D convolutional network reads the one-hot-encoded leaf sequences of an MSA and predicts per-column rate multipliers for a Jukes-Cantor model. Gradients flow from LogLike through the per-column model list back into the CNN parameters.

import jax
import jax.numpy as jnp
from subby.jax import LogLike
from subby.jax.types import Tree
from subby.jax.models import jukes_cantor_model, scale_model

# --- 1D CNN: leaf one-hots -> per-column rates ---
def conv1d(x, w, b):
    out = jax.lax.conv_general_dilated(
        x[None, ...], w, window_strides=(1,), padding="SAME")[0]
    return out + b[:, None]

def predict_rates(params, leaf_one_hot):
    """(n_leaves * A, C) -> (C,) positive rates."""
    h = jax.nn.relu(conv1d(leaf_one_hot, params["w1"], params["b1"]))
    h = jax.nn.relu(conv1d(h, params["w2"], params["b2"]))
    h = conv1d(h, params["w3"], params["b3"])  # (1, C)
    return jax.nn.softplus(h[0])

# --- Loss: negative log-likelihood with per-column rates ---
def loss_fn(params, alignment, tree, base_model, leaf_idx):
    A = base_model.pi.shape[0]
    leaves = alignment[leaf_idx]                         # (n_leaves, C)
    oh = jax.nn.one_hot(leaves, A)                       # (n_leaves, C, A)
    oh_input = oh.transpose(0, 2, 1).reshape(-1, alignment.shape[1])
    rates = predict_rates(params, oh_input)              # (C,)
    models = [scale_model(base_model, rates[c])
              for c in range(rates.shape[0])]
    return -jnp.sum(LogLike(alignment, tree, models))

# --- Training loop (SGD, overfitting one example) ---
grad_fn = jax.grad(loss_fn)
for step in range(500):
    grads = grad_fn(params, alignment, tree, base_model, leaf_idx)
    params = jax.tree.map(lambda p, g: p - 1e-3 * g, params, grads)

See examples/conv_rate_prediction.py for a complete runnable script that simulates an alignment with slow/fast columns and trains the CNN to recover the rate pattern.

Example — fitting per-branch transition/transversion ratio:

Each branch of the tree can have its own kappa (transition/transversion ratio) while sharing the same equilibrium frequencies. Wrap R models in a single-element list [[m_0, ..., m_{R-1}]] to create a (1, R) per-row grid.

import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize
from subby.jax import LogLike, BranchCounts
from subby.jax.types import Tree
from subby.jax.models import hky85_diag

# Star tree: 4 leaves, all branches directly from root
tree = Tree(
    parentIndex=jnp.array([-1, 0, 0, 0, 0], dtype=jnp.int32),
    distanceToParent=jnp.array([0.0, 0.15, 0.25, 0.1, 0.2]),
)
R, A = 5, 4
pi = jnp.array([0.3, 0.2, 0.25, 0.25])

# --- Simulate 10000 columns with known per-branch kappas ---
true_kappas = [2.0, 1.5, 4.0, 0.8, 3.0]
C = 10000
key = jax.random.PRNGKey(42)

# Substitution matrix M(t) = exp(Q*t) from eigendecomposition
def sub_matrix(model, t):
    V, mu = model.eigenvectors, model.eigenvalues
    S = jnp.einsum('ak,k,bk->ab', V, jnp.exp(mu * t), V)
    sp = jnp.sqrt(model.pi)
    return S * (1.0 / sp)[:, None] * sp[None, :]

true_models = [hky85_diag(k, pi) for k in true_kappas]
Ms = [sub_matrix(true_models[r], float(tree.distanceToParent[r]))
      for r in range(R)]

# Propagate states root → leaves (star tree: every leaf's parent is the root)
key, k1 = jax.random.split(key)
states = [jax.random.categorical(k1, jnp.log(pi), shape=(C,))]
for n in range(1, R):
    key, k = jax.random.split(key)
    states.append(jax.random.categorical(
        k, jnp.log(jnp.clip(Ms[n], 1e-30))[states[0]], axis=-1,
    ))

# Alignment: leaves observed, root unobserved (token A)
alignment = jnp.stack([jnp.full(C, A, dtype=jnp.int32)] +
                       [states[n] for n in range(1, R)])

# --- Fit kappa per branch by maximum likelihood ---
def neg_ll(log_kappas):
    models = [hky85_diag(float(np.exp(lk)), pi) for lk in log_kappas]
    return -float(jnp.sum(LogLike(alignment, tree, [models])))

result = minimize(neg_ll, x0=np.log(2.0) * np.ones(R), method='Nelder-Mead')
fitted_kappas = np.exp(result.x)

# Branch 0 (root, t=0) is unidentifiable; leaf branches recover true values
for r in range(1, R):
    print(f"Branch {r}: true κ={true_kappas[r]:.1f}, fitted κ={fitted_kappas[r]:.2f}")
# Branch 1: true κ=1.5, fitted κ≈1.36
# Branch 2: true κ=4.0, fitted κ≈3.75
# Branch 3: true κ=0.8, fitted κ≈0.88
# Branch 4: true κ=3.0, fitted κ≈2.74

# Per-branch substitution counts at the fitted values
models_row = [hky85_diag(k, pi) for k in fitted_kappas]
bc = BranchCounts(alignment, tree, [models_row])  # (R, 4, 4, C)

LogLikeCustomGrad(alignment, tree, model, maxChunkSize=128)

Like LogLike but with a custom VJP for faster distance gradients.

Uses the Fisher identity: the gradient of log-likelihood w.r.t. branch lengths equals a contraction of expected substitution counts, computed via the downward pass and eigenbasis projection without tracing through the full computation graph. The forward pass is identical to LogLike; only the backward pass differs.
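
A sketch of why the identity holds, via time-rescaling (a standard argument, not taken from the library's code): a path of length $t_n$ under rate matrix $Q$ is equivalent to a unit-time path under rate matrix $t_n Q$, whose complete-data log-likelihood is $\sum_{i \neq j} N_{ij} \log(t_n Q_{ij}) + t_n \sum_i \hat{T}_i Q_{ii}$ with $\hat{T}_i = T_i / t_n$. Differentiating in $t_n$ and applying the Fisher identity $\partial_\theta \log P(x) = \mathbb{E}[\partial_\theta \log P(x, z) \mid x]$ gives

$$\frac{\partial \log P(x)}{\partial t_n} = \frac{1}{t_n}\left(\sum_{i \neq j} \mathbb{E}[N^{(n)}_{ij} \mid x] + \sum_i Q_{ii}\, \mathbb{E}[T^{(n)}_i \mid x]\right),$$

a contraction of exactly the expected substitution counts and dwell times described under Counts.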

Parameters: Same as LogLike (single model only, not per-column).

Returns: (*H, C) float array of log-likelihoods.

Example:

import jax
import jax.numpy as jnp
from subby.jax import LogLikeCustomGrad
from subby.jax.types import Tree

def loss(distances):
    tree = Tree(parentIndex=parent_idx, distanceToParent=distances)
    return jnp.sum(LogLikeCustomGrad(alignment, tree, model))

# Gradient via Fisher identity (faster than autograd)
grad = jax.grad(loss)(distances)

Counts(alignment, tree, model, maxChunkSize=128, f81_fast_flag=False)

Compute expected substitution counts and dwell times per column.

Parameters:

Name Type Description
alignment int32 (R, C) Token-encoded alignment
tree Tree Phylogenetic tree
model model, list, or grid Substitution model (same forms as LogLike)
maxChunkSize int Column chunk size
f81_fast_flag bool Use $O(CRA^2)$ fast path (F81/JC only; not with per-row models)

Returns: (*H, A, A, C) float tensor. Diagonal entries are dwell times $E[w_i(c)]$; off-diagonal entries are substitution counts $E[s_{ij}(c)]$.

BranchCounts(alignment, tree, model, maxChunkSize=128, f81_fast_flag=False)

Compute per-branch expected substitution counts and dwell times per column. Returns the same quantities as Counts but broken down per branch rather than summed.

Parameters:

Name Type Description
alignment int32 (R, C) Token-encoded alignment
tree Tree Phylogenetic tree
model model, list, or grid Substitution model (same forms as LogLike)
maxChunkSize int Column chunk size
f81_fast_flag bool Use $O(CRA^2)$ fast path (F81/JC only; not with per-row models)

Returns: (*H, R, A, A, C) float tensor. Branch 0 (root) is zeros. Diagonal entries are dwell times; off-diagonal entries are substitution counts. Summing over the R axis recovers Counts.

ExpectedCounts(model, t)

Expected substitution counts and dwell times for a single CTMC branch, independent of any alignment or tree.

Computes $\mathbb{E}[N_{i \to j}(t) \mid X(0)=a, X(t)=b]$ (off-diagonal) and $\mathbb{E}[T_i(t) \mid X(0)=a, X(t)=b]$ (diagonal) for all $(a, b, i, j)$.

Parameters:

Name Type Description
model DiagModel, IrrevDiagModel, or RateModel Substitution model
t float Branch length

Returns: (*H, A, A, A, A) float tensor. result[..., a, b, i, j] is the expected number of $i \to j$ substitutions (off-diagonal) or the expected dwell time in state $i$ (diagonal), conditioned on endpoint states $a$ and $b$.

Properties:

- Dwell times sum to $t$: $\sum_i \text{result}[a, b, i, i] = t$ for every reachable $(a, b)$.
- All entries are non-negative.
- At $t = 0$, all entries are zero.
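
The dwell-time property can be checked independently of the library for a two-state chain, using the integral representation of conditional dwell times (illustrative numpy, not the library implementation):

```python
import numpy as np

# 2-state symmetric chain, Q = [[-1, 1], [1, -1]]; P(t) has a closed form.
def P(t):
    e = np.exp(-2.0 * t)
    return 0.5 * np.array([[1 + e, 1 - e], [1 - e, 1 + e]])

def dwell(a, b, i, t, n=2000):
    """E[T_i(t) | X(0)=a, X(t)=b] = int_0^t P_ai(s) P_ib(t-s) ds / P_ab(t),
    via midpoint quadrature."""
    s = (np.arange(n) + 0.5) * t / n
    vals = np.array([P(si)[a, i] * P(t - si)[i, b] for si in s])
    return vals.sum() * (t / n) / P(t)[a, b]

t = 0.7
for a in (0, 1):
    for b in (0, 1):
        total = dwell(a, b, 0, t) + dwell(a, b, 1, t)
        assert abs(total - t) < 1e-6   # dwell times sum to the branch length
```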

expected_counts_eigen(eigenvalues, eigenvectors, pi, t)

Inner function for reversible models. Takes pre-computed eigendecomposition, so it can be called repeatedly for different $t$ without re-diagonalizing.

Parameter Type Shape Description
eigenvalues float (*H, A) Eigenvalues of symmetrized rate matrix
eigenvectors float (*H, A, A) Orthogonal eigenvectors
pi float (*H, A) Equilibrium distribution
t float scalar Branch length

expected_counts_eigen_irrev(eigenvalues, eigenvectors, eigenvectors_inv, pi, t)

Inner function for irreversible models. Takes pre-computed complex eigendecomposition.

Parameter Type Shape Description
eigenvalues complex128 (*H, A) Complex eigenvalues
eigenvectors complex128 (*H, A, A) Right eigenvectors $V$
eigenvectors_inv complex128 (*H, A, A) $V^{-1}$
pi float (*H, A) Stationary distribution
t float scalar Branch length

RootProb(alignment, tree, model, maxChunkSize=128)

Compute posterior root state distribution per column.

$$q_a(c) = \frac{\pi_a \cdot U^{(0)}_a(c)}{P(x_c)}$$

Parameters:

Name Type Description
alignment int32 (R, C) Token-encoded alignment
tree Tree Phylogenetic tree
model model, list, or grid Substitution model (same forms as LogLike)
maxChunkSize int Column chunk size

Returns: (*H, A, C) float array. Sums to 1 over the $A$ dimension for each column.

MixturePosterior(alignment, tree, models, log_weights, maxChunkSize=128)

Compute posterior over mixture components per column.

$$P(k \mid x_c) = \text{softmax}_k(\log P(x_c \mid k) + \log w_k)$$

Parameters:

Name Type Description
alignment int32 (R, C) Token-encoded alignment
tree Tree Phylogenetic tree
models list[DiagModel] $K$ substitution models (e.g., rate-scaled)
log_weights float (K,) Log prior weights
maxChunkSize int Column chunk size

Returns: (K, C) float array of posterior probabilities.


Model constructors

hky85_diag(kappa, pi)

HKY85 model with closed-form eigendecomposition.

Parameter Type Description
kappa scalar Transition/transversion ratio
pi float (4,) Equilibrium frequencies $[\pi_A, \pi_C, \pi_G, \pi_T]$

Returns: DiagModel with 4 distinct eigenvalues.

jukes_cantor_model(A)

Jukes-Cantor model for an $A$-state alphabet. Equal rates, uniform equilibrium.

Parameter Type Description
A int Alphabet size

Returns: DiagModel with eigenvalues $\mu_0 = 0$, $\mu_k = -A/(A-1)$ for $k \geq 1$.
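
The stated spectrum follows directly from the rate-1-normalized JC matrix (off-diagonal $1/(A-1)$, diagonal $-1$); a quick numpy check:

```python
import numpy as np

A = 4
# Jukes-Cantor rate matrix normalized to expected rate 1.
Q = np.full((A, A), 1.0 / (A - 1))
np.fill_diagonal(Q, -1.0)

mu = np.sort(np.linalg.eigvalsh(Q))   # Q is symmetric here
assert np.isclose(mu[-1], 0.0)                 # mu_0 = 0
assert np.allclose(mu[:-1], -A / (A - 1))      # mu_k = -A/(A-1) for k >= 1
```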

f81_model(pi)

F81 model: $R_{ij} = \mu \cdot \pi_j$ for $i \neq j$, normalized to expected rate 1.

Parameter Type Description
pi float (A,) Equilibrium frequencies

Returns: DiagModel with eigenvalues $\mu_0 = 0$, $\mu_k = -\mu$ for $k \geq 1$.
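
With the expected rate normalized to 1, the scale works out to $\mu = 1/(1 - \sum_i \pi_i^2)$; verifying the spectrum in numpy:

```python
import numpy as np

pi = np.array([0.4, 0.3, 0.2, 0.1])
mu_norm = 1.0 / (1.0 - np.sum(pi**2))   # makes the expected rate 1

Q = mu_norm * np.tile(pi, (4, 1))       # R_ij = mu * pi_j for i != j
np.fill_diagonal(Q, 0.0)
np.fill_diagonal(Q, -Q.sum(axis=1))

# Expected rate: -sum_i pi_i Q_ii = 1
assert np.isclose(-(pi * np.diag(Q)).sum(), 1.0)
# Eigenvalues: one zero, the rest -mu
ev = np.sort(np.linalg.eigvals(Q).real)
assert np.isclose(ev[-1], 0.0)
assert np.allclose(ev[:-1], -mu_norm)
```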

gamma_rate_categories(alpha, K)

Yang (1994) discretized gamma rate categories using quantile medians.

Parameter Type Description
alpha scalar Shape parameter (lower = more rate variation)
K int Number of categories

Returns: (rates, weights) — each (K,). Rates are mean-normalized; weights are uniform $1/K$.
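
A sketch of the quantile-median construction using scipy (illustrative; gamma_rate_categories may normalize slightly differently):

```python
import numpy as np
from scipy.stats import gamma

def gamma_medians(alpha, K):
    """Yang (1994)-style discretization: rate k is the median of the k-th
    1/K quantile bin of a mean-1 gamma, then mean-normalized."""
    q = (np.arange(K) + 0.5) / K
    rates = gamma.ppf(q, a=alpha, scale=1.0 / alpha)   # mean-1 gamma
    return rates / rates.mean(), np.full(K, 1.0 / K)

rates, weights = gamma_medians(alpha=0.5, K=4)
assert np.isclose(rates.mean(), 1.0)   # mean-normalized
assert np.all(np.diff(rates) > 0)      # categories strictly increase
```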

gy94_model(omega, kappa, pi=None)

Goldman-Yang (1994) codon substitution model. Operates on 61 sense codons.

Rate matrix:

- $Q_{ij} = 0$ if codons $i$ and $j$ differ at more than one nucleotide position
- $Q_{ij} = \pi_j \cdot \kappa^{\mathbb{1}[\text{transition}]} \cdot \omega^{\mathbb{1}[\text{nonsynonymous}]}$ if they differ at exactly one position
- Diagonal: $Q_{ii} = -\sum_{j \neq i} Q_{ij}$
- Normalized so $-\sum_i \pi_i Q_{ii} = 1$

Parameter Type Description
omega scalar dN/dS ratio (Ka/Ks)
kappa scalar Transition/transversion ratio
pi float (61,) or None Codon equilibrium frequencies (default: uniform $1/61$)

Returns: DiagModel with $A = 61$ states.

from subby.jax.models import gy94_model
model = gy94_model(omega=0.5, kappa=2.0)
# DiagModel with 61 sense codons

scale_model(model, rate_multiplier)

Scale eigenvalues by a rate multiplier. If rate_multiplier is (K,), adds $K$ as a leading batch dimension.


Preset models

cherryml_siteRM()

Load the CherryML 400x400 site-pair coevolution model (Prillo et al., Nature Methods 2023). Returns a DiagModel with $A = 400$ states representing pairs of amino acids at structurally contacting sites.

State ordering: pair $(i, j) \to i \cdot 20 + j$ using the ARNDCQEGHILKMFPSTWYV alphabet.

Returns: DiagModel with $A = 400$ states.

from subby.jax.presets import cherryml_siteRM
model_400 = cherryml_siteRM()
# model_400.pi has shape (400,)

Format utilities

genetic_code()

Return the standard genetic code as a structured dict. Codons are in ACGT lexicographic order (AAA, AAC, AAG, ..., TTT). Stop codons (TAA=48, TAG=50, TGA=56) are marked with '*'.

Returns: dict with:

Key Type Description
codons list[str] 64 codon strings
amino_acids list[str] 64 amino acid letters (stop = '*')
sense_mask (64,) bool True for sense codons
sense_indices (61,) int Indices of sense codons in 0..63
codon_to_sense (64,) int Maps codon index to sense index (stop -> -1)
sense_codons list[str] 61 sense codon strings
sense_amino_acids list[str] 61 amino acid letters

from subby.formats import genetic_code
gc = genetic_code()
print(gc['sense_codons'][:5])  # ['AAA', 'AAC', 'AAG', 'AAT', 'ACA']
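
The stop-codon indices follow directly from the ACGT lexicographic ordering, which can be reproduced without the library:

```python
from itertools import product

# ACGT lexicographic codon order: AAA, AAC, AAG, ..., TTT
codons = [''.join(c) for c in product('ACGT', repeat=3)]
stops = {c: codons.index(c) for c in ('TAA', 'TAG', 'TGA')}
assert stops == {'TAA': 48, 'TAG': 50, 'TGA': 56}
assert sum(c not in stops for c in codons) == 61   # 61 sense codons
```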

codon_to_sense(alignment, A=64)

Remap a 64-codon tokenized alignment to 61-sense-codon tokens. Stop codons become the gap token. Unobserved and gap tokens are remapped to the new alphabet size.

Parameter Type Description
alignment int32 (N, C) Tokens 0..63 for codons, 64 for ungapped-unobserved, 65 for gap
A int Input codon alphabet size (default 64)

Returns: dict with alignment (int32 (N, C) with $A_\text{sense} = 61$), A_sense (61), alphabet (list of 61 sense codon strings).

split_paired_columns(alignment, paired_columns, A=20)

Split an alignment into paired and single-column alignments. For coevolution models that operate on pairs of columns (e.g., CherryML SiteRM with $A = 400 = 20 \times 20$ amino acid pairs). Internally uses kmer_tokenize with k=2 column tuples for pairs and k=1 for singles.

Parameter Type Description
alignment int32 (N, C) Token-encoded alignment with $A$-state tokens
paired_columns list[(int, int)] List of (col_i, col_j) tuples
A int Single-column alphabet size (default 20 for amino acids)

Returns: dict with:

Key Type Description
paired_alignment int32 (N, P) $A_\text{paired} = A \times A$ states
singles_alignment int32 (N, S) $A_\text{singles} = A$ states
paired_columns list[(int, int)] Echoed back
single_columns list[int] Columns not in any pair
A_paired int $A \times A$
A_singles int $A$
paired_index KmerIndex Tuple ↔ index mapping for paired columns
singles_index KmerIndex Tuple ↔ index mapping for single columns

merge_paired_columns(paired_posterior, singles_posterior, split_info)

Reassemble per-column posteriors from paired and single results. Marginalizes the $A_\text{paired} = A \times A$ dimensional paired posteriors into two $A$-dimensional single-column posteriors, then reassembles into the original column order.

Parameter Type Description
paired_posterior float (A_paired, P) Posterior for paired columns
singles_posterior float (A_singles, S) Posterior for single columns
split_info dict Output from split_paired_columns

Returns: (A, C) array — posterior for all columns in original order.
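
The marginalization step amounts to a reshape and a sum over the partner axis, given the $(i, j) \to i \cdot A + j$ pair indexing — a self-contained numpy sketch:

```python
import numpy as np

A, P = 20, 3
rng = np.random.default_rng(0)
paired = rng.random((A * A, P))
paired /= paired.sum(axis=0)            # posterior over A*A pair states

# Split the pair axis (i*A + j) and marginalize out the partner.
pair_3d = paired.reshape(A, A, P)
post_i = pair_3d.sum(axis=1)            # (A, P): first column of each pair
post_j = pair_3d.sum(axis=0)            # (A, P): second column of each pair

assert np.allclose(post_i.sum(axis=0), 1.0)
assert np.allclose(post_j.sum(axis=0), 1.0)
```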

K-mer tokenization

See the Oracle API reference for full documentation of KmerIndex, sliding_windows, all_column_ktuples, and kmer_tokenize. These functions are implemented in subby.formats and re-exported from both subby.jax and subby.oracle.

from subby.formats import kmer_tokenize, sliding_windows, all_column_ktuples, KmerIndex

Padding utilities

pad_alignment(alignment, bin_size=128)

Pad alignment columns to the next multiple of bin_size with gap tokens (-1). Gap-padded columns are mathematically neutral (logL = 0, zero counts, root prior unchanged), so padding avoids JAX recompilation when C varies across inputs.

Parameters:

Name Type Description
alignment int32 (R, C) Token-encoded alignment
bin_size int Round C up to the next multiple of this (default 128)

Returns: (padded_alignment, C_original) — the padded (R, C_padded) alignment and the original column count.

unpad_columns(result, C_original)

Strip padding columns from a result array: result[..., :C_original].

Parameters:

Name Type Description
result array (..., C_padded) Output from a padded computation
C_original int Original column count from pad_alignment

Returns: (..., C_original) array.

Example — JIT-friendly binning:

from subby.jax import LogLike, pad_alignment, unpad_columns

padded, C_orig = pad_alignment(alignment, bin_size=64)
ll = LogLike(padded, tree, model)       # shape reused across similar C
ll = unpad_columns(ll, C_orig)          # back to original C

InsideOutside

InsideOutside(alignment, tree, model, maxChunkSize=128)

Runs the inside (upward) and outside (downward) passes once and stores the resulting DP tables, enabling efficient queries for log-likelihoods, expected substitution counts, node state posteriors, and branch endpoint joint posteriors without recomputation.

Parameters:

Name Type Description
alignment int32 (R, C) Token-encoded alignment
tree Tree Phylogenetic tree
model DiagModel, IrrevDiagModel, RateModel, or list Substitution model
maxChunkSize int Column chunk size

Properties:

Name Type Description
log_likelihood (*H, C) float Per-column log-likelihoods

Methods:

counts(f81_fast_flag=False, branch_mask="auto")

Expected substitution counts and dwell times, reusing stored DP tables.

Returns: (*H, A, A, C) float tensor.

branch_counts(f81_fast_flag=False)

Per-branch expected substitution counts and dwell times, reusing stored DP tables.

Returns: (*H, R, A, A, C) float tensor. Branch 0 (root) is zeros.

node_posterior(node=None)

Posterior state distribution at node(s).

For root: $P(X_0 = a \mid \text{data}) \propto \pi_a \cdot U^{(0)}_a(c)$

For non-root: $P(X_n = j \mid \text{data}) \propto \left[\sum_a D^{(n)}_a(c) \cdot M^{(n)}_{aj}\right] \cdot U^{(n)}_j(c)$

Argument Type Description
node int or None Node index, or None for all nodes

Returns: (*H, A, C) if node is int; (*H, R, A, C) if None.

branch_posterior(node=None)

Joint posterior of parent-child states on a branch.

$$P(X_{\text{parent}(n)}=i,\, X_n=j \mid \text{data}, c) \propto D^{(n)}_i(c) \cdot M^{(n)}_{ij} \cdot U^{(n)}_j(c)$$

Argument Type Description
node int or None Child node index (must be > 0), or None for all

Returns: (*H, A, A, C) if node is int; (*H, R, A, A, C) if None. Branch 0 is zeros.

Example:

from subby.jax import InsideOutside

io = InsideOutside(alignment, tree, model)

ll = io.log_likelihood                  # (*H, C)
root_post = io.node_posterior(0)        # (*H, A, C)
all_posts = io.node_posterior()         # (*H, R, A, C)
branch_joint = io.branch_posterior(3)   # (*H, A, A, C)
counts = io.counts()                    # (*H, A, A, C)
per_branch = io.branch_counts()         # (*H, R, A, A, C)

Low-level functions

diagonalize_rate_matrix(subRate, rootProb)

Convert a RateModel to DiagModel via eigendecomposition of the symmetrized rate matrix.

compute_sub_matrices(model, distanceToParent)

Compute transition probability matrices $M_{ij}(t_n)$ for each branch.

Returns: (*H, R, A, A) — rows sum to 1, $M(0) = I$.
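
The stated invariants hold for any valid rate matrix (rows summing to zero); a library-independent check via scipy:

```python
import numpy as np
from scipy.linalg import expm

# An arbitrary valid 3-state rate matrix (rows sum to 0).
Q = np.array([[-2.0, 1.0, 1.0],
              [0.5, -1.0, 0.5],
              [1.0, 1.0, -2.0]])

for t in (0.0, 0.1, 2.0):
    M = expm(Q * t)                          # M(t) = exp(Q t)
    assert np.allclose(M.sum(axis=1), 1.0)   # rows sum to 1
    assert np.all(M >= -1e-12)               # probabilities are non-negative
assert np.allclose(expm(Q * 0.0), np.eye(3))  # M(0) = I
```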

upward_pass(alignment, tree, subMatrices, rootProb, maxChunkSize, per_column=False)

Felsenstein pruning (postorder, leaves to root) via jax.lax.scan.

When per_column=True, subMatrices has shape (*H, R, C, A, A) — a different substitution matrix per column. Default: (*H, R, A, A).

Returns: (U, logNormU, logLike) where:

- U: (*H, R, C, A) rescaled inside vectors
- logNormU: (*H, R, C) log-normalizers
- logLike: (*H, C) per-column log-likelihoods

downward_pass(U, logNormU, tree, subMatrices, rootProb, alignment, per_column=False)

Outside algorithm (preorder, root to leaves) via jax.lax.scan.

When per_column=True, subMatrices has shape (*H, R, C, A, A). Default: (*H, R, A, A).

Returns: (D, logNormD) where:

- D: (*H, R, C, A) rescaled outside vectors
- logNormD: (*H, R, C) log-normalizers

compute_J(eigenvalues, distanceToParent)

$J$ interaction matrix for eigensubstitution accumulation.

$$J_{kl}(t) = \begin{cases} t \cdot e^{\mu_k t} & \text{if } \mu_k \approx \mu_l \\ \frac{e^{\mu_k t} - e^{\mu_l t}}{\mu_k - \mu_l} & \text{otherwise} \end{cases}$$

Returns: (*H, R, A, A).
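
The formula, including the near-equal-eigenvalue limit, is straightforward to reproduce for a single branch (illustrative numpy; the library version is batched over *H and R):

```python
import numpy as np

def compute_J_ref(mu, t, tol=1e-8):
    """J_kl(t) per the formula above, with the mu_k ~ mu_l limit handled."""
    diff = mu[:, None] - mu[None, :]
    close = np.abs(diff) < tol
    safe = np.where(close, 1.0, diff)                  # avoid division by zero
    J = (np.exp(mu[:, None] * t) - np.exp(mu[None, :] * t)) / safe
    return np.where(close, t * np.exp(mu[:, None] * t), J)

mu = np.array([0.0, -1.0, -1.0, -2.5])   # note the repeated eigenvalue
J = compute_J_ref(mu, t=0.3)
assert np.allclose(np.diag(J), 0.3 * np.exp(mu * 0.3))   # diagonal limit
assert np.allclose(J, J.T)                               # symmetric in k, l
```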

eigenbasis_project(U, D, model)

Project inside/outside vectors into the eigenbasis.

Returns: (U_tilde, D_tilde) — each (*H, R, C, A).

accumulate_C(D_tilde, U_tilde, J, logNormU, logNormD, logLike, parentIndex)

Sum eigenbasis contributions over all non-root branches.

Returns: (*H, A, A, C) eigenbasis counts.

back_transform(C, model)

Transform eigenbasis counts to natural-basis substitution counts and dwell times.

Returns: (*H, A, A, C).

accumulate_C_per_branch(...) / back_transform_per_branch(...)

Per-branch variants of accumulate_C and back_transform. Instead of summing over branches, each branch's contribution is stored separately.

Returns: (*H, R, A, A, C).

f81_counts(U, D, logNormU, logNormD, logLike, distances, pi, parentIndex)

$O(CRA^2)$ direct computation for F81/JC models, avoiding the eigenbasis.

Returns: (*H, A, A, C).

f81_counts_per_branch(U, D, logNormU, logNormD, logLike, distances, pi, parentIndex)

Per-branch variant of f81_counts.

Returns: (*H, R, A, A, C).

mixture_posterior(log_likes, log_weights)

Numerically stable softmax over mixture components.

Returns: (K, C).
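
The stabilization is the standard log-sum-exp trick: subtract the per-column maximum before exponentiating (a reference sketch, not the library's exact code):

```python
import numpy as np

def mixture_posterior_ref(log_likes, log_weights):
    """Stable softmax over K components for each column."""
    z = log_likes + log_weights[:, None]      # (K, C) unnormalized log-post
    z -= z.max(axis=0, keepdims=True)         # shift so the max is 0
    p = np.exp(z)
    return p / p.sum(axis=0, keepdims=True)

# Extreme log-likelihoods would overflow a naive exp; the shifted version is fine.
log_likes = np.array([[-1000.0, -3.0],
                      [-1001.0, -1.0]])
post = mixture_posterior_ref(log_likes, np.log(np.array([0.5, 0.5])))
assert np.allclose(post.sum(axis=0), 1.0)
assert np.all(np.isfinite(post))
```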