Skip to content

Commit 37e4bab

Browse files
committed
fix ex09
1 parent c6c4a7b commit 37e4bab

3 files changed

Lines changed: 312 additions & 59 deletions

File tree

examples/example_08_QRS_known_cell.py

Lines changed: 68 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def get_vector_dimension(xtal):
5050
return sum(len(site.get_bounds()) for site in xtal.mol_sites)
5151

5252

53-
def plot_id_vs_energy(code, energies, match_ids=None, match_energies=None, out_dir="qrs_plots"):
53+
def plot_id_vs_energy(code, energies, match_ids=None, match_energies=None, out_dir="qrs_plots", time_cost_s=None):
5454
"""Save a plot of visited-structure ID vs energy for one QRS run."""
5555
if not energies:
5656
print(f"No energies collected for {code}; skipping plot.")
@@ -76,72 +76,81 @@ def plot_id_vs_energy(code, energies, match_ids=None, match_energies=None, out_d
7676
)
7777
ax.set_xlabel("Visited Structure ID")
7878
ax.set_ylabel("Energy (kcal/mol)")
79-
ax.set_title(f"{code}: visited structure ID vs energy")
79+
if time_cost_s is not None:
80+
ax.set_title(f"{code}: visited structure ID vs energy (time: {time_cost_s:.2f} s)")
81+
else:
82+
ax.set_title(f"{code}: visited structure ID vs energy")
8083
ax.grid(alpha=0.25, linestyle="--", linewidth=0.6)
8184
ax.legend(loc="best")
85+
ax.set_ylim(min(energies) - 5, max(energies) + 50)
8286
fig.tight_layout()
8387

8488
fig_path = os.path.join(out_dir, f"{code}_id_vs_energy.png")
8589
fig.savefig(fig_path, dpi=200)
8690
plt.close(fig)
8791
print(f"Saved plot: {fig_path}")
8892

89-
db = database("pyxtal/database/test.db")
90-
91-
csv_path = "qrs_results.csv"
92-
with open(csv_path, "w", newline="") as fcsv:
93-
writer = csv.writer(fcsv)
94-
writer.writerow(["code", "smiles", "sg", "vector_dim", "time_cost_s", "success_rate"])
95-
96-
for code in db.get_all_codes()[2:]:
97-
# ── Molecule ──────────────────────────────────────────────────────────────────
98-
row = db.get_row(code=code)
99-
ref_xtal = db.get_pyxtal(code=code)
100-
if ref_xtal.has_special_site(): ref_xtal = ref_xtal.to_subgroup()
101-
vector_dim = get_vector_dimension(ref_xtal)
102-
ref_pmg = ref_xtal.to_pymatgen()
103-
ref_pmg.remove_species(["H"]) # ignore H for matching since positions are less certain
104-
print(ref_xtal)
105-
106-
sites = [[] for _ in range(len(ref_xtal.numMols))]
107-
for site in ref_xtal.mol_sites:
108-
sites[site.type].append(site.wp.get_label())
109-
110-
# ── QRS setup ─────────────────────────────────────────────────────────────────
111-
qrs = QRS(
112-
smiles = row.mol_smi, # molecule as SMILES string
113-
workdir = row.csd_code, # working directory for this QRS run
114-
sg = ref_xtal.group.hall_number, # space group number (P2_1/c = 81)
115-
tag = row.csd_code.lower(), # tag for output files
116-
use_hall = True, # interpret sg as a Hall number
117-
lattice = ref_xtal.lattice, # fix the cell; only WP positions are sampled
118-
N_gen = 10, # number of QRS generations
119-
N_pop = 50, # structures per generation
120-
N_cpu = 1,
121-
cif = "all.cif", # save all relaxed structures
122-
skip_mlp = True, # no machine-learning potential relaxation
123-
verbose = False,
124-
sites = sites,
125-
delta_length = 1.5, # grid spacing for fractional coords (Å)
126-
delta_angle = 60.0, # grid spacing for Euler/torsion angles (°)
127-
)
128-
129-
# ── Run and check for match ───────────────────────────────────────────────────
130-
t0 = perf_counter()
131-
success_rate = qrs.run(ref_pmg=ref_pmg)
132-
time_cost_s = perf_counter() - t0
133-
134-
if success_rate is not None and success_rate > 0:
135-
print(f"\nMatch found! Success rate: {success_rate}%")
136-
else:
137-
print("\nNo match found within the given generations/population.")
138-
print(f"Time cost: {time_cost_s:.2f} s")
139-
140-
match_ids, match_energies = get_match_points(qrs)
141-
plot_id_vs_energy(code, qrs.engs, match_ids, match_energies)
93+
if __name__ == "__main__":
94+
db = database("pyxtal/database/test.db")
95+
os.makedirs("Tests", exist_ok=True)
96+
csv_path = "Tests/qrs_results.csv"
14297

143-
with open(csv_path, "a", newline="") as fcsv:
98+
with open(csv_path, "w", newline="") as fcsv:
14499
writer = csv.writer(fcsv)
145-
writer.writerow([code, row.mol_smi, ref_xtal.group.number, vector_dim,
146-
f"{time_cost_s:.2f}",
147-
f"{success_rate:.4f}" if success_rate is not None else ""])
100+
writer.writerow(["code", "smiles", "sg", "vector_dim", "time_cost_s", "success_rate"])
101+
102+
for code in db.get_all_codes()[4:]:
103+
# ── Molecule ──────────────────────────────────────────────────────────────────
104+
row = db.get_row(code=code)
105+
ref_xtal = db.get_pyxtal(code=code)
106+
if ref_xtal.has_special_site(): ref_xtal = ref_xtal.to_subgroup()
107+
vector_dim = get_vector_dimension(ref_xtal)
108+
ref_pmg = ref_xtal.to_pymatgen()
109+
ref_pmg.remove_species(["H"]) # ignore H for matching since positions are less certain
110+
print(ref_xtal)
111+
workdir = os.path.join("Tests", row.csd_code)
112+
113+
sites = [[] for _ in range(len(ref_xtal.numMols))]
114+
for site in ref_xtal.mol_sites:
115+
sites[site.type].append(site.wp.get_label())
116+
117+
# ── QRS setup ─────────────────────────────────────────────────────────────────
118+
if os.path.exists(workdir + '/parameters.xml'):
119+
os.remove(workdir + '/parameters.xml')
120+
qrs = QRS(
121+
smiles = row.mol_smi, # molecule as SMILES string
122+
workdir = workdir, # working directory for this QRS run
123+
sg = ref_xtal.group.hall_number, # space group number (P2_1/c = 81)
124+
tag = row.csd_code.lower(), # tag for output files
125+
use_hall = True, # interpret sg as a Hall number
126+
lattice = ref_xtal.lattice, # fix the cell; only WP positions are sampled
127+
N_gen = 10, # number of QRS generations
128+
N_pop = 50, # structures per generation
129+
N_cpu = 2,
130+
cif = "all.cif", # save all relaxed structures
131+
skip_mlp = True, # no machine-learning potential relaxation
132+
verbose = False,
133+
sites = sites,
134+
delta_length = 1.5, # grid spacing for fractional coords (Å)
135+
delta_angle = 60.0, # grid spacing for Euler/torsion angles (°)
136+
)
137+
138+
# ── Run and check for match ───────────────────────────────────────────────────
139+
t0 = perf_counter()
140+
success_rate = qrs.run(ref_pmg=ref_pmg)
141+
time_cost_s = perf_counter() - t0
142+
143+
if success_rate is not None and success_rate > 0:
144+
print(f"\nMatch found! Success rate: {success_rate}%")
145+
else:
146+
print("\nNo match found within the given generations/population.")
147+
print(f"Time cost: {time_cost_s:.2f} s")
148+
149+
match_ids, match_energies = get_match_points(qrs)
150+
plot_id_vs_energy(code, qrs.engs, match_ids, match_energies, time_cost_s=time_cost_s)
151+
152+
with open(csv_path, "a", newline="") as fcsv:
153+
writer = csv.writer(fcsv)
154+
writer.writerow([code, row.mol_smi, ref_xtal.group.number, vector_dim,
155+
f"{time_cost_s:.2f}",
156+
f"{success_rate:.4f}" if success_rate is not None else ""])

examples/example_09_QRS.py

Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
"""
2+
Example: Run QRS with pregenerated molecular conformers.
3+
4+
This script iterates over entries in pyxtal/database/test.db, precomputes a pool
5+
of pyxtal_molecule conformers via generate_molecules, and passes them to QRS.
6+
With the lattice fixed and conformers pregenerated, QRS samples only Wyckoff-site
7+
fractional coordinates and molecular orientations for each trial crystal.
8+
9+
Results are appended to Tests/qrs_results_pregen_mols.csv.
10+
"""
11+
12+
import csv
13+
import os
14+
from time import perf_counter
15+
16+
import matplotlib.pyplot as plt
17+
18+
from pyxtal.db import database
19+
from pyxtal.molecule import generate_molecules
20+
from pyxtal.optimize import QRS
21+
22+
23+
def get_vector_dimension(xtal):
24+
"""Return the total dimension of sampling vectors for this crystal."""
25+
return sum(len(site.get_bounds()) for site in xtal.mol_sites)
26+
27+
28+
def build_sites_from_reference(ref_xtal):
29+
"""Build QRS site labels grouped by molecule type from a reference crystal."""
30+
sites = [[] for _ in range(len(ref_xtal.numMols))]
31+
for site in ref_xtal.mol_sites:
32+
sites[site.type].append(site.wp.get_label())
33+
return sites
34+
35+
36+
def filter_similar_molecules(mols, rmsd_tol=0.5):
37+
"""Remove near-duplicate pyxtal_molecule conformers using pairwise RMSD."""
38+
unique_mols = []
39+
for mol in mols:
40+
is_duplicate = False
41+
for ref_mol in unique_mols:
42+
rmsd, _ = mol.get_rmsd2(mol.mol.cart_coords, ref_mol.mol.cart_coords)
43+
if rmsd < rmsd_tol:
44+
is_duplicate = True
45+
break
46+
if not is_duplicate:
47+
unique_mols.append(mol)
48+
return unique_mols
49+
50+
51+
def get_match_points(qrs):
52+
"""Map QRS match (gen, pop) records to visited-structure IDs and energies."""
53+
if not hasattr(qrs, "matches") or not hasattr(qrs, "stats"):
54+
return [], []
55+
56+
match_keys = set()
57+
for match in qrs.matches:
58+
if len(match) >= 2:
59+
match_keys.add((int(match[0]), int(match[1])))
60+
61+
if not match_keys:
62+
return [], []
63+
64+
match_ids = []
65+
match_energies = []
66+
visited_id = 0
67+
for gen in range(qrs.N_gen):
68+
for pop in range(qrs.N_pop):
69+
energy = float(qrs.stats[gen, pop, 0])
70+
if energy < qrs.E_max:
71+
visited_id += 1
72+
if (gen, pop) in match_keys:
73+
match_ids.append(visited_id)
74+
match_energies.append(energy)
75+
76+
return match_ids, match_energies
77+
78+
79+
def plot_id_vs_energy(code, energies, match_ids=None, match_energies=None, out_dir="qrs_plots", time_cost_s=None):
80+
"""Save a plot of visited-structure ID vs energy for one QRS run."""
81+
if not energies:
82+
print(f"No energies collected for {code}; skipping plot.")
83+
return
84+
85+
os.makedirs(out_dir, exist_ok=True)
86+
ids = list(range(1, len(energies) + 1))
87+
88+
fig, ax = plt.subplots(figsize=(7, 4.5))
89+
ax.scatter(ids, energies, s=20, alpha=0.8, label="Visited")
90+
ax.plot(ids, energies, linewidth=0.8, alpha=0.6)
91+
if match_ids and match_energies:
92+
ax.scatter(
93+
match_ids,
94+
match_energies,
95+
s=70,
96+
marker="*",
97+
c="crimson",
98+
edgecolors="black",
99+
linewidths=0.6,
100+
zorder=3,
101+
label="Match",
102+
)
103+
ax.set_xlabel("Visited Structure ID")
104+
ax.set_ylabel("Energy (kcal/mol)")
105+
if time_cost_s is not None:
106+
ax.set_title(f"{code}: visited structure ID vs energy (time: {time_cost_s:.2f} s)")
107+
else:
108+
ax.set_title(f"{code}: visited structure ID vs energy")
109+
ax.grid(alpha=0.25, linestyle="--", linewidth=0.6)
110+
ax.legend(loc="best")
111+
112+
ymin = min(energies)
113+
ymax = max(energies)
114+
if ymin == ymax:
115+
margin = max(abs(ymin) * 0.05, 1.0)
116+
ax.set_ylim(ymin - margin, ymax + margin)
117+
else:
118+
y_max = min(ymin + 50, ymax)
119+
ax.set_ylim(ymin - 1, y_max)
120+
fig.tight_layout()
121+
122+
fig_path = os.path.join(out_dir, f"{code}_id_vs_energy.png")
123+
fig.savefig(fig_path, dpi=200)
124+
plt.close(fig)
125+
print(f"Saved plot: {fig_path}")
126+
127+
128+
if __name__ == "__main__":
129+
db = database("pyxtal/database/test.db")
130+
os.makedirs("Tests", exist_ok=True)
131+
csv_path = "Tests/qrs_results_pregen_mols.csv"
132+
133+
with open(csv_path, "w", newline="") as fcsv:
134+
writer = csv.writer(fcsv)
135+
writer.writerow(
136+
[
137+
"code",
138+
"smiles",
139+
"sg",
140+
"vector_dim",
141+
"n_pregenerated_confs",
142+
"time_cost_s",
143+
"success_rate",
144+
]
145+
)
146+
147+
for code in db.get_all_codes()[25:50]:
148+
#if code not in ['XATMOV']: continue
149+
row = db.get_row(code=code)
150+
ref_xtal = db.get_pyxtal(code=code)
151+
if ref_xtal.has_special_site():
152+
ref_xtal = ref_xtal.to_subgroup()
153+
154+
vector_dim = get_vector_dimension(ref_xtal)
155+
ref_pmg = ref_xtal.to_pymatgen()
156+
ref_pmg.remove_species(["H"]) # Ignore H for matching robustness.
157+
158+
print(f"\n=== {code} ===")
159+
print(ref_xtal)
160+
161+
workdir = os.path.join("Tests", row.csd_code)
162+
os.makedirs(workdir, exist_ok=True)
163+
sites = build_sites_from_reference(ref_xtal)
164+
165+
# Pregenerate molecular conformers that are compatible with the target WP.
166+
# We provide the resulting conformer pool to QRS so torsions are fixed per
167+
# sampled molecule choice and QRS only samples WP xyz + orientation DOFs.
168+
target_wps = [site.wp for site in ref_xtal.mol_sites]
169+
p_mols = generate_molecules(
170+
row.mol_smi,
171+
wps=target_wps,
172+
N_iter=8,
173+
N_conf=50,
174+
tol=0.5,
175+
)
176+
if len(p_mols) == 0:
177+
print("No valid pregenerated conformers; skipping this entry.")
178+
continue
179+
p_mols = filter_similar_molecules(p_mols, rmsd_tol=0.5)
180+
print(f"Unique conformers after filtering: {len(p_mols)}")
181+
if len(p_mols) == 0:
182+
print("All pregenerated conformers were filtered out; skipping this entry.")
183+
continue
184+
print(f"Pregenerated conformers: {len(p_mols)}")
185+
186+
param_xml = os.path.join(workdir, "parameters.xml")
187+
if os.path.exists(param_xml):
188+
os.remove(param_xml)
189+
qrs = QRS(
190+
smiles=row.mol_smi,
191+
workdir=workdir,
192+
sg=ref_xtal.group.hall_number,
193+
tag=row.csd_code.lower(),
194+
use_hall=True,
195+
lattice=ref_xtal.lattice, # Fixed cell.
196+
composition = [int(a) for a in ref_xtal.get_zprime()],
197+
molecules=[p_mols], # One molecular component with many conformers.
198+
sites=sites,
199+
N_gen=20,
200+
N_pop=50,
201+
N_cpu=2,
202+
cif="all.cif",
203+
skip_mlp=True,
204+
verbose=False,
205+
delta_length=1.5,
206+
delta_angle=45.0,
207+
)
208+
209+
t0 = perf_counter()
210+
success_rate = qrs.run(ref_pmg=ref_pmg)
211+
time_cost_s = perf_counter() - t0
212+
213+
if success_rate is not None and success_rate > 0:
214+
print(f"Match found! Success rate: {success_rate}%")
215+
else:
216+
print("No match found within the given generations/population.")
217+
print(f"Time cost: {time_cost_s:.2f} s")
218+
match_ids, match_energies = get_match_points(qrs)
219+
plot_id_vs_energy(
220+
code,
221+
qrs.engs,
222+
match_ids=match_ids,
223+
match_energies=match_energies,
224+
out_dir="Tests/qrs_plots",
225+
time_cost_s=time_cost_s,
226+
)
227+
228+
with open(csv_path, "a", newline="") as fcsv:
229+
writer = csv.writer(fcsv)
230+
writer.writerow(
231+
[
232+
code,
233+
row.mol_smi,
234+
ref_xtal.group.number,
235+
vector_dim,
236+
len(p_mols),
237+
f"{time_cost_s:.2f}",
238+
f"{success_rate:.4f}" if success_rate is not None else "",
239+
]
240+
)
241+
242+
print(f"\nSaved summary CSV: {csv_path}")

0 commit comments

Comments
 (0)