1313import os
1414from time import perf_counter
1515
16+ import matplotlib .pyplot as plt
17+
1618from pyxtal .db import database
1719from pyxtal .molecule import generate_molecules
1820from pyxtal .optimize import QRS
@@ -46,6 +48,83 @@ def filter_similar_molecules(mols, rmsd_tol=0.5):
4648 return unique_mols
4749
4850
51+ def get_match_points (qrs ):
52+ """Map QRS match (gen, pop) records to visited-structure IDs and energies."""
53+ if not hasattr (qrs , "matches" ) or not hasattr (qrs , "stats" ):
54+ return [], []
55+
56+ match_keys = set ()
57+ for match in qrs .matches :
58+ if len (match ) >= 2 :
59+ match_keys .add ((int (match [0 ]), int (match [1 ])))
60+
61+ if not match_keys :
62+ return [], []
63+
64+ match_ids = []
65+ match_energies = []
66+ visited_id = 0
67+ for gen in range (qrs .N_gen ):
68+ for pop in range (qrs .N_pop ):
69+ energy = float (qrs .stats [gen , pop , 0 ])
70+ if energy < qrs .E_max :
71+ visited_id += 1
72+ if (gen , pop ) in match_keys :
73+ match_ids .append (visited_id )
74+ match_energies .append (energy )
75+
76+ return match_ids , match_energies
77+
78+
79+ def plot_id_vs_energy (code , energies , match_ids = None , match_energies = None , out_dir = "qrs_plots" , time_cost_s = None ):
80+ """Save a plot of visited-structure ID vs energy for one QRS run."""
81+ if not energies :
82+ print (f"No energies collected for { code } ; skipping plot." )
83+ return
84+
85+ os .makedirs (out_dir , exist_ok = True )
86+ ids = list (range (1 , len (energies ) + 1 ))
87+
88+ fig , ax = plt .subplots (figsize = (7 , 4.5 ))
89+ ax .scatter (ids , energies , s = 20 , alpha = 0.8 , label = "Visited" )
90+ ax .plot (ids , energies , linewidth = 0.8 , alpha = 0.6 )
91+ if match_ids and match_energies :
92+ ax .scatter (
93+ match_ids ,
94+ match_energies ,
95+ s = 70 ,
96+ marker = "*" ,
97+ c = "crimson" ,
98+ edgecolors = "black" ,
99+ linewidths = 0.6 ,
100+ zorder = 3 ,
101+ label = "Match" ,
102+ )
103+ ax .set_xlabel ("Visited Structure ID" )
104+ ax .set_ylabel ("Energy (kcal/mol)" )
105+ if time_cost_s is not None :
106+ ax .set_title (f"{ code } : visited structure ID vs energy (time: { time_cost_s :.2f} s)" )
107+ else :
108+ ax .set_title (f"{ code } : visited structure ID vs energy" )
109+ ax .grid (alpha = 0.25 , linestyle = "--" , linewidth = 0.6 )
110+ ax .legend (loc = "best" )
111+
112+ ymin = min (energies )
113+ ymax = max (energies )
114+ if ymin == ymax :
115+ margin = max (abs (ymin ) * 0.05 , 1.0 )
116+ ax .set_ylim (ymin - margin , ymax + margin )
117+ else :
118+ y_max = min (ymin + 50 , ymax )
119+ ax .set_ylim (ymin - 1 , y_max )
120+ fig .tight_layout ()
121+
122+ fig_path = os .path .join (out_dir , f"{ code } _id_vs_energy.png" )
123+ fig .savefig (fig_path , dpi = 200 )
124+ plt .close (fig )
125+ print (f"Saved plot: { fig_path } " )
126+
127+
49128if __name__ == "__main__" :
50129 db = database ("pyxtal/database/test.db" )
51130 os .makedirs ("Tests" , exist_ok = True )
@@ -65,7 +144,8 @@ def filter_similar_molecules(mols, rmsd_tol=0.5):
65144 ]
66145 )
67146
68- for code in db .get_all_codes ()[5 :8 ]:
147+ for code in db .get_all_codes ():
148+ #if code not in ['KONTIQ']: continue
69149 row = db .get_row (code = code )
70150 ref_xtal = db .get_pyxtal (code = code )
71151 if ref_xtal .has_special_site ():
@@ -82,39 +162,62 @@ def filter_similar_molecules(mols, rmsd_tol=0.5):
82162 os .makedirs (workdir , exist_ok = True )
83163 sites = build_sites_from_reference (ref_xtal )
84164
85- # Pregenerate molecular conformers that are compatible with the target WP.
86- # We provide the resulting conformer pool to QRS so torsions are fixed per
87- # sampled molecule choice and QRS only samples WP xyz + orientation DOFs.
88- target_wps = [site .wp for site in ref_xtal .mol_sites ]
89- p_mols = generate_molecules (
90- row .mol_smi ,
91- wps = target_wps ,
92- N_iter = 8 ,
93- N_conf = 50 ,
94- tol = 0.5 ,
95- )
96- if len (p_mols ) == 0 :
97- print ("No valid pregenerated conformers; skipping this entry." )
165+ # Pregenerate per-component conformer pools aligned with site.type.
166+ # QRS expects one pool per molecular component in row.mol_smi.split('.').
167+ smiles_parts = row .mol_smi .split ("." )
168+ n_types = len (ref_xtal .numMols )
169+ if len (smiles_parts ) != n_types :
170+ print (
171+ f"SMILES/component mismatch for { code } : "
172+ f"{ len (smiles_parts )} smiles parts vs { n_types } crystal components; skipping."
173+ )
98174 continue
99- p_mols = filter_similar_molecules (p_mols , rmsd_tol = 0.5 )
100- print (f"Unique conformers after filtering: { len (p_mols )} " )
101- if len (p_mols ) == 0 :
102- print ("All pregenerated conformers were filtered out; skipping this entry." )
175+
176+ type_wps = [[] for _ in range (n_types )]
177+ for site in ref_xtal .mol_sites :
178+ type_wps [site .type ].append (site .wp )
179+
180+ molecules = []
181+ n_pregen_total = 0
182+ for type_idx , smi in enumerate (smiles_parts ):
183+ p_mols = generate_molecules (
184+ smi ,
185+ wps = type_wps [type_idx ],
186+ N_iter = 8 ,
187+ N_conf = 50 ,
188+ tol = 0.5 ,
189+ )
190+ if len (p_mols ) == 0 :
191+ print (f"No valid pregenerated conformers for component { type_idx } ({ smi } ); skipping." )
192+ molecules = None
193+ break
194+
195+ p_mols = filter_similar_molecules (p_mols , rmsd_tol = 0.5 )
196+ print (f"Component { type_idx } ({ smi } ) unique conformers: { len (p_mols )} " )
197+ if len (p_mols ) == 0 :
198+ print (f"All conformers filtered out for component { type_idx } ({ smi } ); skipping." )
199+ molecules = None
200+ break
201+
202+ molecules .append (p_mols )
203+ n_pregen_total += len (p_mols )
204+
205+ if molecules is None :
103206 continue
104- print (f"Pregenerated conformers: { len ( p_mols ) } " )
207+ print (f"Total pregenerated conformers across components : { n_pregen_total } " )
105208
106209 param_xml = os .path .join (workdir , "parameters.xml" )
107210 if os .path .exists (param_xml ):
108211 os .remove (param_xml )
109-
110212 qrs = QRS (
111213 smiles = row .mol_smi ,
112214 workdir = workdir ,
113215 sg = ref_xtal .group .hall_number ,
114216 tag = row .csd_code .lower (),
115217 use_hall = True ,
116218 lattice = ref_xtal .lattice , # Fixed cell.
117- molecules = [p_mols ], # One molecular component with many conformers.
219+ composition = [int (a ) for a in ref_xtal .get_zprime ()],
220+ molecules = molecules ,
118221 sites = sites ,
119222 N_gen = 20 ,
120223 N_pop = 50 ,
@@ -135,6 +238,15 @@ def filter_similar_molecules(mols, rmsd_tol=0.5):
135238 else :
136239 print ("No match found within the given generations/population." )
137240 print (f"Time cost: { time_cost_s :.2f} s" )
241+ match_ids , match_energies = get_match_points (qrs )
242+ plot_id_vs_energy (
243+ code ,
244+ qrs .engs ,
245+ match_ids = match_ids ,
246+ match_energies = match_energies ,
247+ out_dir = "Tests/qrs_plots" ,
248+ time_cost_s = time_cost_s ,
249+ )
138250
139251 with open (csv_path , "a" , newline = "" ) as fcsv :
140252 writer = csv .writer (fcsv )
@@ -144,7 +256,7 @@ def filter_similar_molecules(mols, rmsd_tol=0.5):
144256 row .mol_smi ,
145257 ref_xtal .group .number ,
146258 vector_dim ,
147- len ( p_mols ) ,
259+ n_pregen_total ,
148260 f"{ time_cost_s :.2f} " ,
149261 f"{ success_rate :.4f} " if success_rate is not None else "" ,
150262 ]
0 commit comments