1- # Copyright (c) 2022 Mira Geoscience Ltd.
1+ # Copyright (c) 2023 Mira Geoscience Ltd.
22#
33# This file is part of param-sweeps.
44#
88from __future__ import annotations
99
1010import argparse
11+ import importlib
12+ import inspect
1113import itertools
1214import json
1315import os
14- import subprocess
1516import uuid
1617from dataclasses import dataclass
1718from inspect import signature
@@ -92,6 +93,10 @@ class SweepDriver:
9293
9394 def __init__ (self , params ):
9495 self .params : SweepParams = params
96+ self .workspace = params .geoh5
97+ self .working_directory = os .path .dirname (self .workspace .h5file )
98+ lookup = self .get_lookup ()
99+ self .write_files (lookup )
95100
96101 @staticmethod
97102 def uuid_from_params (params : tuple ) -> str :
@@ -104,75 +109,97 @@ def uuid_from_params(params: tuple) -> str:
104109 """
105110 return str (uuid .uuid5 (uuid .NAMESPACE_DNS , str (hash (params ))))
106111
107- def run (self , files_only = False ):
108- """Execute a sweep."""
112+ def get_lookup (self ):
113+ """Generate lookup table for sweep trials."""
114+
115+ lookup = {}
116+ sets = self .params .parameter_sets ()
117+ iterations = list (itertools .product (* sets .values ()))
118+ for iteration in iterations :
119+ param_uuid = SweepDriver .uuid_from_params (iteration )
120+ lookup [param_uuid ] = dict (zip (sets .keys (), iteration ))
121+ lookup [param_uuid ]["status" ] = "pending"
122+
123+ lookup = self .update_lookup (lookup , gather_first = True )
124+ return lookup
125+
126+ def update_lookup (self , lookup : dict , gather_first : bool = False ):
127+ """Updates lookup with new entries. Ensures any previous runs are incorporated."""
128+ lookup_path = os .path .join (self .working_directory , "lookup.json" )
129+ if os .path .exists (lookup_path ) and gather_first : # In case restarting
130+ with open (lookup_path , encoding = "utf8" ) as file :
131+ lookup .update (json .load (file ))
132+
133+ with open (lookup_path , "w" , encoding = "utf8" ) as file :
134+ json .dump (lookup , file , indent = 4 )
135+
136+ return lookup
137+
138+ def write_files (self , lookup ):
139+ """Write ui.geoh5 and ui.json files for sweep trials."""
109140
110141 ifile = InputFile .read_ui_json (self .params .worker_uijson )
111142 with ifile .data ["geoh5" ].open (mode = "r" ) as workspace :
112- sets = self .params .parameter_sets ()
113- iterations = list (itertools .product (* sets .values ()))
114- print (
115- f"Running parameter sweep for { len (iterations )} "
116- f"trials of the { ifile .data ['title' ]} driver."
117- )
118143
119- param_lookup = {}
120- for count , iteration in enumerate (iterations ):
121- param_uuid = SweepDriver .uuid_from_params (iteration )
122- filepath = os .path .join (
123- os .path .dirname (workspace .h5file ), f"{ param_uuid } .ui.geoh5"
124- )
125- param_lookup [param_uuid ] = dict (zip (sets .keys (), iteration ))
144+ for name , trial in lookup .items ():
126145
127- if os .path .exists (filepath ):
128- print (
129- f"{ count } : Skipping trial: { param_uuid } . "
130- f"Already computed and saved to file."
131- )
146+ if trial ["status" ] != "pending" :
132147 continue
133148
134- print (
135- f"{ count } : Running trial: { param_uuid } . "
136- f"Use lookup.json to map uuid to parameter set."
149+ filepath = os .path .join (
150+ os .path .dirname (workspace .h5file ), f"{ name } .ui.geoh5"
137151 )
138152 with Workspace (filepath ) as iter_workspace :
139153 ifile .data .update (
140- dict (param_lookup [param_uuid ], ** {"geoh5" : iter_workspace })
154+ dict (
155+ {key : val for key , val in trial .items () if key != "status" },
156+ ** {"geoh5" : iter_workspace },
157+ )
141158 )
142159 objects = [v for v in ifile .data .values () if hasattr (v , "uid" )]
143160 for obj in objects :
144161 if not isinstance (obj , Data ):
145162 obj .copy (parent = iter_workspace , copy_children = True )
146163
147- update_lookup (param_lookup , workspace )
148-
149- ifile .name = f"{ param_uuid } .ui.json"
164+ ifile .name = f"{ name } .ui.json"
150165 ifile .path = os .path .dirname (workspace .h5file )
151166 ifile .write_ui_json ()
167+ lookup [name ]["status" ] = "written"
152168
153- if not files_only :
154- call_worker_subprocess (ifile )
169+ _ = self .update_lookup (lookup )
155170
171+ def run (self ):
172+ """Execute a sweep."""
156173
157- def call_worker_subprocess (ifile : InputFile ):
158- """Runs the worker for the sweep parameters contained in 'ifile'."""
159- subprocess .run (
160- ["python" , "-m" , ifile .data ["run_command" ], ifile .path_name ],
161- check = True ,
162- )
174+ lookup_path = os .path .join (self .working_directory , "lookup.json" )
175+ with open (lookup_path , encoding = "utf8" ) as file :
176+ lookup = json .load (file )
163177
178+ for name , trial in lookup .items ():
179+ ifile = InputFile .read_ui_json (
180+ os .path .join (self .working_directory , f"{ name } .ui.json" )
181+ )
182+ status = trial .pop ("status" )
183+ if status != "complete" :
184+ lookup [name ]["status" ] = "processing"
185+ self .update_lookup (lookup )
186+ call_worker (ifile )
187+ lookup [name ]["status" ] = "complete"
188+ self .update_lookup (lookup )
164189
165- def update_lookup (lookup : dict , workspace : Workspace ):
166- """Updates lookup with new entries. Ensures any previous runs are incorporated."""
167- lookup_path = os .path .join (os .path .dirname (workspace .h5file ), "lookup.json" )
168- if os .path .exists (lookup_path ): # In case restarting
169- with open (lookup_path , encoding = "utf8" ) as file :
170- lookup .update (json .load (file ))
171190
172- with open ( lookup_path , "w" , encoding = "utf8" ) as file :
173- json . dump ( lookup , file , indent = 4 )
191+ def call_worker ( ifile : InputFile ) :
192+ """Runs the worker for the sweep parameters contained in 'ifile'."""
174193
175- return lookup
194+ run_cmd = ifile .data ["run_command" ]
195+ module = importlib .import_module (run_cmd )
196+ filt = (
197+ lambda member : inspect .isclass (member )
198+ and member .__module__ == run_cmd
199+ and hasattr (member , "run" )
200+ )
201+ driver = inspect .getmembers (module , filt )[0 ][1 ]
202+ driver .start (ifile .path_name )
176203
177204
178205def file_validation (filepath ):
@@ -188,14 +215,14 @@ def file_validation(filepath):
188215 raise OSError (f"File argument { filepath } must have extension 'ui.json'." )
189216
190217
191- def main (file_path , files_only = False ):
218+ def main (file_path ):
192219 """Run the program."""
193220
194221 file_validation (file_path )
195222 print ("Reading parameters and workspace..." )
196223 input_file = InputFile .read_ui_json (file_path )
197224 sweep_params = SweepParams .from_input_file (input_file )
198- SweepDriver (sweep_params ).run (files_only )
225+ SweepDriver (sweep_params ).run ()
199226
200227
201228if __name__ == "__main__" :
@@ -206,4 +233,4 @@ def main(file_path, files_only=False):
206233 parser .add_argument ("file" , help = "File with ui.json format." )
207234
208235 args = parser .parse_args ()
209- main (args .file )
236+ main (os . path . abspath ( args .file ) )
0 commit comments