
Slow computation and sometimes biased results with warm starting #155

Description

@jacopok
  • UltraNest version: 4.3.4
  • Python version: 3.11
  • Operating System: Ubuntu 24.04

Summary

I have been experimenting with warm starting on a toy problem, and I found some strange behaviour.
The sampling seems to reach the correct region quickly, as expected, but then it takes a long time to converge to the right value of $Z$.
Also, the estimate of $Z$ is sometimes biased (the true value lies several sigmas from the estimate).

Description

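The toy problem is an $n$-dimensional Gaussian likelihood with mean $\mu = 0.5$ on every axis and a small width $\sigma \sim 10^{-2}$. The prior transform is the identity on the unit cube, i.e. a flat prior on $[0, 1]^n$.
Since the likelihood is normalized and essentially all of its mass lies inside the cube, the evidence is expected to be $Z \approx 1$ (to a very good approximation), i.e. $\log Z \approx 0$.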
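As a quick check of that claim: with a flat prior on the unit cube, the evidence of a normalized isotropic Gaussian factorizes into the per-dimension Gaussian mass inside $[0, 1]$, which can be written with the error function. The helper name `log_evidence` below is hypothetical and only for illustration; it uses the standard library alone.

```python
import math

def log_evidence(n_dim: int, mu: float = 0.5, sigma: float = 1e-2) -> float:
    """Exact log Z for a normalized isotropic Gaussian likelihood and a flat prior on [0, 1]^n.

    Per dimension, Z_1 = Phi((1 - mu) / sigma) - Phi((0 - mu) / sigma),
    with Phi(x) = 0.5 * (1 + erf(x / sqrt(2))).
    """
    s = sigma * math.sqrt(2.0)
    z_1d = 0.5 * (math.erf((1.0 - mu) / s) - math.erf((0.0 - mu) / s))
    # Independent dimensions: log Z is the sum of the per-dimension log masses
    return n_dim * math.log(z_1d)

for n in (1, 3):
    print(n, log_evidence(n))
```

For $\sigma = 10^{-2}$ the truncation at the cube edges is far below double precision, so the reference value for every scenario is $\log Z = 0$, and the "error in sigmas" printed by the script is simply $|\log Z| / \sigma_{\log Z}$.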

The issue also shows up in 1D, but the script works in any number of dimensions (set by N_par). The run times are reported for the 3D case; the trace plot is for the 1D case.

I am comparing a regular NS run to runs that use the auxiliary likelihood and prior transform obtained with get_auxiliary_contbox_parameterization, with the contours built from simulated posterior samples ("guesses") drawn from:

  • the correct posterior
  • the correct posterior with a reduced variance
  • the correct posterior with an increased variance

In the language of the SuperNest paper by Petrosyan and Handley, the KL divergence between the original prior and the posterior is (in the $\sigma \ll 1$ approximation)
$$\mathcal{D}_{\pi}(\mathcal{P}) \approx -\frac{1}{2} (1+\log(2\pi)) - \log \sigma$$
per dimension, which comes out to about 3.2 nats for $\sigma=10^{-2}$.
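The per-dimension value can be checked numerically: since the flat prior density is 1 on $[0, 1]$, $\mathcal{D}_{\pi}(\mathcal{P})$ is just the expectation of $\log \mathcal{P}(x)$ under the posterior, i.e. minus the Gaussian's differential entropy. The sketch below compares the closed form with a plain Monte Carlo average; the function names are hypothetical, and it assumes $\sigma \ll 1$ so that samples essentially never leave the unit interval.

```python
import math
import random

def kl_posterior_vs_flat_prior(sigma: float) -> float:
    """Closed-form D_pi(P) per dimension for P = N(0.5, sigma^2) and pi = U(0, 1), sigma << 1."""
    return -0.5 * (1.0 + math.log(2.0 * math.pi)) - math.log(sigma)

def kl_monte_carlo(sigma: float, n_samples: int = 100_000, seed: int = 1) -> float:
    """Monte Carlo estimate of E_P[log P(x) - log pi(x)], using log pi(x) = 0 on [0, 1]."""
    rng = random.Random(seed)
    mu = 0.5
    total = 0.0
    for _ in range(n_samples):
        x = rng.gauss(mu, sigma)
        log_p = -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2.0 * math.pi))
        total += log_p  # the flat prior contributes log(1) = 0
    return total / n_samples

print(kl_posterior_vs_flat_prior(1e-2))  # analytic, in nats
print(kl_monte_carlo(1e-2))              # should agree to about 0.01 nats
```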

With the guesses, on the other hand, we go from a Gaussian modified prior of width $k \sigma$ to a Gaussian posterior of width $\sigma$, so the KL divergence per dimension is (the same result as here, but in nats)
$$\mathcal{D}_{\tilde{\pi}}(\mathcal{P}) = \log k + \frac{1}{2} \left( \frac{1}{k^{2}} - 1 \right)$$

The examples I'm considering are $k = [0.5, 1, 2]$, with corresponding divergences $[0.81, 0, 0.32]$ nats per dimension, compared with about 3.2 nats for the original prior.
The prior being equal to the posterior ($k = 1$) is a degenerate case, of course, but this still indicates that we should expect a substantial speed-up!
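The three values follow from the closed-form KL divergence between two Gaussians with the same mean and standard deviations $\sigma$ and $k\sigma$. The sketch below evaluates that formula and cross-checks it with a Monte Carlo average over standardized posterior samples; the helper names are hypothetical, and the common $-\frac{1}{2}\log(2\pi)$ terms are dropped because they cancel in the log-ratio.

```python
import math
import random

def kl_gauss_width_ratio(k: float) -> float:
    """D(P || pi~) in nats per dimension, P = N(mu, sigma^2), pi~ = N(mu, (k * sigma)^2)."""
    return math.log(k) + 0.5 * (1.0 / k**2 - 1.0)

def kl_gauss_monte_carlo(k: float, sigma: float = 1e-2, n_samples: int = 100_000, seed: int = 2) -> float:
    """Monte Carlo estimate of E_P[log P(x) - log pi~(x)] for the same pair of Gaussians."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        z = rng.gauss(0.0, 1.0)  # standardized posterior sample: x = mu + sigma * z
        log_p = -0.5 * z**2 - math.log(sigma)
        log_q = -0.5 * (z / k) ** 2 - math.log(k * sigma)
        total += log_p - log_q
    return total / n_samples

for k in (0.5, 1.0, 2.0):
    print(k, round(kl_gauss_width_ratio(k), 2), round(kl_gauss_monte_carlo(k), 2))
```

Note that the formula is not symmetric in $k \to 1/k$: a too-narrow guess ($k = 0.5$) costs more information than a too-wide one ($k = 2$).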

Instead, when I run with the auxiliary sampler, the run time is sometimes worse than for regular sampling.

Also, although the evidence errors are indeed smaller (as they should be), when the auxiliary prior is too narrow the evidence is underestimated, and the reported error does not account for this bias.

Ultranest 4.3.4
Scenario: Regular sampling
t=29.16s
logZ=-0.11 +- 0.14
error: 0.78 sigmas

Scenario: Underestimated standard deviation
t=36.75s
logZ=-0.45 +- 0.05
error: 8.89 sigmas

Scenario: Correct standard deviation
t=35.25s
logZ=-0.02 +- 0.06
error: 0.32 sigmas

Scenario: Overestimated standard deviation
t=35.40s
logZ=0.02 +- 0.09
error: 0.25 sigmas

The script is as shown below.

This is what happens when frac_remain is set to a low value ($10^{-3}$) in all cases; if it is higher ($0.5$), things are closer to expectations:

Ultranest 4.3.4
Scenario: Regular sampling
t=18.25s
logZ=-0.06 +- 0.43
error: 0.15 sigmas

Scenario: Underestimated standard deviation
t=3.96s
logZ=-0.56 +- 0.41
error: 1.37 sigmas

Scenario: Correct standard deviation
t=6.34s
logZ=-0.12 +- 0.41
error: 0.29 sigmas

Scenario: Overestimated standard deviation
t=9.72s
logZ=-0.06 +- 0.42
error: 0.15 sigmas
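One rough way to put numbers on the frac_remain dependence is the textbook nested-sampling estimate with ideal compression $\log X_i \approx -i/N_\mathrm{live}$: the run can stop once $L_\mathrm{max} X_i < f_\mathrm{remain} Z$, i.e. after about $N_\mathrm{live}\left[\log(L_\mathrm{max}/Z) + \log(1/f_\mathrm{remain})\right]$ iterations. This is only a back-of-envelope sketch under those assumptions (UltraNest's actual termination logic is more involved), and the function name is hypothetical; it plugs in the peak of the normalized Gaussian toy likelihood, $L_\mathrm{max} = (2\pi\sigma^2)^{-n/2}$.

```python
import math

def iterations_to_terminate(n_live: int, n_dim: int, sigma: float, frac_remain: float,
                            log_z: float = 0.0) -> float:
    """Idealized iteration count until L_max * X_i < frac_remain * Z, assuming X_i = exp(-i / n_live)."""
    # Peak of the normalized isotropic Gaussian likelihood used in the toy problem
    log_l_max = -0.5 * n_dim * math.log(2.0 * math.pi * sigma**2)
    return n_live * (log_l_max - log_z - math.log(frac_remain))

for f in (1e-3, 0.5):
    print(f, round(iterations_to_terminate(1000, 3, 1e-2, f)))
```

Under this idealization, the $N_\mathrm{live}\log(1/f_\mathrm{remain})$ term does not depend on how good the starting guess is, so lowering frac_remain from 0.5 to $10^{-3}$ adds the same $N_\mathrm{live}\log 500$ iterations to every scenario.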

However, this does not resolve the issue: why does the sampler "get stuck" in the same region, making very slow progress on the last contributions to the integral?
Here is a trace plot for the same problem with frac_remain=1e-3, in 1D, for the correct-standard-deviation case.
What is going on? Why are the points getting more spread out?

[Figure: trace plot]

import ultranest
import numpy as np
from ultranest.hotstart import get_auxiliary_contbox_parameterization, get_auxiliary_problem, get_extended_auxiliary_independent_problem
from time import perf_counter
import string
import matplotlib.pyplot as plt

N_live = 1000
N_post = 1000
N_par = 3
scale = 1e-2
frac_remain=1e-3

param_names = [string.ascii_lowercase[i] for i in range(N_par)]

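def loglike(par):
    # Normalized isotropic Gaussian centred at 0.5 with width `scale`; par has shape (n_points, N_par)
    # because the sampler is vectorized. Its mass inside the unit cube is ~1, so log Z ~ 0.
    return -0.5 * np.sum((par-0.5)**2, axis=1) / scale**2 - N_par / 2 * np.log(2*np.pi*scale**2)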

def prior_transform(cube):
    # flat prior in [0, 1]^n
    return cube


run_times = []
z_estimates = []
z_errors = []

sampler = ultranest.ReactiveNestedSampler(
    param_names, 
    loglike, 
    prior_transform,
    log_dir="regular_sampling", 
    resume='overwrite', 
    vectorized=True
)

t1 = perf_counter()
result = sampler.run(min_num_live_points=N_live, frac_remain=frac_remain)
t2 = perf_counter()

sampler.print_results()

sampler.plot_trace()
plt.savefig('regular_sampling.png')
plt.close()

z_estimates.append(result['logz'])
z_errors.append(result['logzerr'])

run_times.append(t2-t1)

rng = np.random.default_rng(1)

for scale_factor in [0.5, 1, 2.]:
    
    simulated_posterior = rng.normal(
        loc=0.5*np.ones(N_par), 
        scale=scale*scale_factor, 
        size=(N_post, N_par)
    )

    aux_param_names, aux_loglike, aux_transform, vectorized = get_auxiliary_contbox_parameterization(
            param_names, 
            loglike, 
            prior_transform, 
            simulated_posterior, 
            np.ones(N_post) / N_post,
            vectorized=True,
        )
    
    sampler_aux = ultranest.ReactiveNestedSampler(
        aux_param_names, 
        aux_loglike, 
        aux_transform,
        log_dir=f"aux_sampling_{scale_factor}", 
        resume='overwrite', 
        vectorized=True
    )

    t3 = perf_counter()
    result_aux = sampler_aux.run(min_num_live_points=N_live, frac_remain=frac_remain)
    t4 = perf_counter()
    run_times.append(t4-t3)

    sampler_aux.print_results()
    sampler_aux.plot_trace()
    z_estimates.append(result_aux['logz'])
    z_errors.append(result_aux['logzerr'])


scenarios = [
    'Regular sampling',
    'Underestimated standard deviation',
    'Correct standard deviation',
    'Overestimated standard deviation',
]

print(f'Ultranest {ultranest.__version__}')

for i in range(4):
    print(f'Scenario: {scenarios[i]}')
    print(f't={run_times[i]:.2f}s')
    print(f'logZ={z_estimates[i]:.2f} +- {z_errors[i]:.2f}')
    # the true log-evidence is ~0, so this is |logZ - 0| / logzerr
    print(f'error: {abs(z_estimates[i])/z_errors[i]:.2f} sigmas\n')
