-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgmm_worker.py
More file actions
142 lines (120 loc) · 5.49 KB
/
Copy pathgmm_worker.py
File metadata and controls
142 lines (120 loc) · 5.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#!/usr/bin/env python
"""Hyperpipe-compatible worker: evaluate the GMM log-likelihood on a sub-range of a parameter grid.

Follows the exact worker API used by create_eos_posterior_pipeline
(see MonteCarloMarginalizeCode/Code/demo/hyperpipe/example_gaussian2.py).

The full per-grid-point log-likelihood is

    lnL = lnL_events + N_obs * ln(R) - R * VT_eff

where:
    lnL_events = sum_k log p(m_k | GMM params)   (from GMMArchive)
    VT_eff     = E_{p(m|params)}[VT(m)]          (from GMMArchive, Monte Carlo integral)
    R          = 10**log_R                       (log_R is a grid parameter)
    N_obs      = number of observations

If the grid has no log_R column, only lnL_events is reported
(backwards-compatible mode).

Typical call from a DAG node:

    python gmm_worker.py \\
        --using-eos file:/path/to/grid.dat \\
        --using-eos-index 0 --n-events-to-analyze 50 \\
        --observations-file /path/to/observations.dat \\
        --archive /path/to/gmm_archive \\
        --vt-model uniform \\
        --outdir gmm_out \\
        --fname-output-integral MARG-0-0 \\
        --conforming-output-name
"""
import argparse
import os
import sys
from pathlib import Path
import numpy as np
# ---------------------------------------------------------------------------
# Command-line interface.
# The first table reproduces the hyperpipe worker API verbatim: names, types
# and defaults must not change, because the DAG builds this command line.
# --fname and --fname-output-samples are accepted only so that command line
# parses; nothing below reads them.
# The second table holds the options specific to the GMM likelihood.
# ---------------------------------------------------------------------------
_HYPERPIPE_API_OPTIONS = (
    ("--fname", dict(type=str, help="Dummy argument required by API")),
    ("--using-eos", dict(type=str, required=True,
                         help="Grid file, prefixed with 'file:'")),
    ("--using-eos-index", dict(type=int, default=None,
                               help="Single-index mode start row")),
    ("--n-events-to-analyze", dict(type=int, default=1,
                                   help="Chunk size when using --using-eos-index")),
    ("--eos_start_index", dict(type=int, default=None)),
    ("--eos_end_index", dict(type=int, default=None)),
    ("--outdir", dict(type=str, default=".")),
    ("--fname-output-integral", dict(type=str, required=True)),
    ("--fname-output-samples", dict(type=str, default=None,
                                    help="Dummy argument required by API")),
    ("--conforming-output-name", dict(action="store_true",
                                      help="Append +annotation.dat to output filename")),
)

_GMM_OPTIONS = (
    ("--observations-file", dict(type=str, required=True,
                                 help="Path to observations.dat (two-column m1 m2)")),
    ("--archive", dict(type=str, required=True,
                       help="Path to GMMArchive directory")),
    ("--vt-model", dict(type=str, default="uniform",
                        choices=["uniform", "chirp_mass"],
                        help="Selection function model")),
)

parser = argparse.ArgumentParser()
# Registration order is preserved so --help output matches the API listing.
for _flag, _spec in _HYPERPIPE_API_OPTIONS + _GMM_OPTIONS:
    parser.add_argument(_flag, **_spec)
opts = parser.parse_args()
# Resolve the half-open row range [eos_start_index, eos_end_index) to evaluate.
# Hyperpipe single-index mode (--using-eos-index plus --n-events-to-analyze)
# overrides any explicit --eos_start_index / --eos_end_index pair.
if opts.using_eos_index is not None:
    opts.eos_start_index = opts.using_eos_index
    opts.eos_end_index = opts.using_eos_index + opts.n_events_to_analyze
if opts.eos_start_index is None or opts.eos_end_index is None:
    print("ERROR: must specify either --using-eos-index or --eos_start_index/--eos_end_index",
          file=sys.stderr)
    sys.exit(1)
# A negative start would silently wrap around to the end of the grid through
# NumPy negative indexing. An inverted range (e.g. a negative
# --n-events-to-analyze) is a caller error. Reject both rather than emit
# wrong rows. An empty range (end == start) is still allowed.
if opts.eos_start_index < 0 or opts.eos_end_index < opts.eos_start_index:
    print(f"ERROR: invalid row range [{opts.eos_start_index}, {opts.eos_end_index})",
          file=sys.stderr)
    sys.exit(1)
# ---------------------------------------------------------------------------
# Load inputs: parameter grid, observed events and the GMM archive.
# ---------------------------------------------------------------------------
# gmm_archive is imported from this script's own directory, independent of cwd.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from gmm_archive import GMMArchive

# Strip only a leading "file:" scheme. str.replace would also remove any later
# occurrence of "file:" inside the path itself.
fname_grid = opts.using_eos
if fname_grid.startswith("file:"):
    fname_grid = fname_grid[len("file:"):]

# The first line is the column header, e.g. "# lnL sigma_lnL <param> ...".
with open(fname_grid) as f:
    header = f.readline().strip()
all_param_names = header.lstrip("#").split()[2:]  # drop lnL, sigma_lnL

# genfromtxt squeezes a single data row into a 1-D array. Restore the row axis
# so grid[i] is always a full row rather than one field of the only row.
grid = np.genfromtxt(fname_grid, dtype=str)
if grid.ndim == 1 and grid.size:
    grid = grid.reshape(1, -1)

# ndmin=2 keeps one row per event even when the file holds a single
# observation. Without it, len(data) would count columns instead of events.
data = np.loadtxt(opts.observations_file, comments="#", ndmin=2)
N_obs = len(data)

archive = GMMArchive(opts.archive)

# Identify which column holds log_R (may be absent for backwards-compat)
gmm_param_names = [n for n in all_param_names if n != "log_R"]
has_log_R = "log_R" in all_param_names
log_R_col = all_param_names.index("log_R") if has_log_R else None
# ---------------------------------------------------------------------------
# Evaluate lnL for each row in [start, end)
# ---------------------------------------------------------------------------
# In single-index mode the end index is start + n_events_to_analyze. For the
# last chunk this can run past the final grid row. Clamp the end index so that
# node writes the rows that do exist instead of dying with IndexError.
eval_stop = min(opts.eos_end_index, len(grid))
ln10 = np.log(10.0)

rows = []
for i in range(opts.eos_start_index, eval_stop):
    row = grid[i]

    # Columns 0 and 1 are lnL and sigma_lnL, so parameter j is column 2 + j.
    # log_R is not a GMM parameter: it is handled below and not sent to the archive.
    gmm_params = {
        name: float(row[2 + j])
        for j, name in enumerate(all_param_names)
        if name != "log_R"
    }
    result = archive.get_or_compute(gmm_params, data, vt_model=opts.vt_model)
    lnL_events = result["lnL_events"]
    VT_eff = result["VT_eff"]

    if has_log_R:
        log_R = float(row[2 + log_R_col])
        R = 10.0 ** log_R
        # ln(R) = log_R * ln(10) exactly. This avoids the round-off of
        # log(10**log_R) and the -inf produced when R underflows to 0.
        total_lnL = lnL_events + N_obs * log_R * ln10 - R * VT_eff
    else:
        # No rate parameter: plain per-event likelihood (backwards-compatible)
        total_lnL = lnL_events

    # Output keeps the input row's parameter columns and overwrites lnL and sigma_lnL.
    new_row = list(row)
    new_row[0] = f"{total_lnL:.6f}"
    new_row[1] = "0.001"  # constant placeholder; no error estimate is computed here
    rows.append(new_row)
# ---------------------------------------------------------------------------
# Output: one line per evaluated grid row, laid out as
# "lnL sigma_lnL <params...>" to mirror the input grid (standard RIFT format).
# ---------------------------------------------------------------------------
Path(opts.outdir).mkdir(parents=True, exist_ok=True)

out_name = opts.fname_output_integral
if opts.conforming_output_name:
    # Hyperpipe's conforming naming convention appends this suffix.
    out_name += "+annotation.dat"
out_path = os.path.join(opts.outdir, out_name)

# np.savetxt prefixes the header with "# ", so it parses like the input header.
column_header = f"lnL sigma_lnL {' '.join(all_param_names)}"
np.savetxt(out_path, rows, fmt="%10s", header=column_header)

print(f"Wrote {len(rows)} rows to {out_path}")