Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions simpeg/directives/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@
SaveLogFilesGeoH5,
SaveLPModelGroup,
SaveModelGeoH5,
SavePGIModel,
SavePropertyGroup,
SaveSensitivityGeoH5,
)
Expand Down
19 changes: 14 additions & 5 deletions simpeg/directives/_regularization.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,10 @@ def endIter(self):
if not isinstance(reg, Sparse):
continue

for obj in reg.objfcts:
for multi, obj in reg:
if multi == 0:
continue

obj.irls_threshold /= self.irls_cooling_factor

self.metrics.irls_iteration_count += 1
Expand Down Expand Up @@ -327,10 +330,16 @@ def start_irls(self):
if not isinstance(reg, Sparse):
continue

for obj in reg.objfcts:
threshold = np.percentile(
np.abs(obj.mapping * obj._delta_m(self.invProb.model)),
self.percentile,
for multi, obj in reg:
if multi == 0:
continue

threshold = (
np.percentile(
np.abs(obj.mapping * obj._delta_m(self.invProb.model)),
self.percentile,
)
+ 1e-16
)
if isinstance(obj, SmoothnessFirstOrder):
threshold /= reg.regularization_mesh.base_length
Expand Down
72 changes: 72 additions & 0 deletions simpeg/directives/_save_geoh5.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@

import numpy as np
from scipy.sparse import csc_matrix, csr_matrix
from simpeg.regularization import PGIsmallness

from .directives import InversionDirective
from simpeg.maps import IdentityMap

from geoh5py.data import NumericData
from geoh5py.data.data_type import ReferencedValueMapType
from geoh5py.groups.property_group import GroupTypeEnum
from geoh5py.groups import UIJsonGroup
from geoh5py.objects import ObjectBase
Expand Down Expand Up @@ -503,3 +506,72 @@ def get_names(
base_name = "LP models"

return channel_name, base_name


class SavePGIModel(SaveArrayGeoH5):
"""
Save the model as a property group in the geoh5 file
"""

def __init__(
self,
h5_object: ObjectBase,
pgi_regularization: PGIsmallness,
unit_map: dict,
physical_properties: list[str],
reference_type: ReferencedValueMapType | None = None,
**kwargs,
):
self.pgi_regularization = pgi_regularization
self.unit_map: dict = unit_map
self.reference_type = reference_type
self.physical_properties = physical_properties
super().__init__(h5_object, **kwargs)

def get_values(self, values: list[np.ndarray] | None):

if values is None:
values = self.invProb.model

modellist = self.pgi_regularization.wiresmap * values
model = np.c_[
[a * b for a, b in zip(self.pgi_regularization.maplist, modellist)]
].T
membership = self.pgi_regularization.gmm._estimate_log_prob(model).argmax(
axis=1
)
return membership

def write(self, iteration: int, values: list[np.ndarray] | None = None):
"""
Method to write the reference model with data map.
"""
petro_model = self.get_values(values)
petro_model = self.apply_transformations(petro_model).flatten()
channel_name, base_name = self.get_names("petrophysics", "", iteration)
with fetch_active_workspace(self._geoh5, mode="r+") as w_s:
h5_object = w_s.get_entity(self.h5_object)[0]
data = h5_object.add_data(
{
channel_name: {
"association": self.association,
"values": petro_model,
"type": "referenced",
}
}
)

if self.reference_type is not None:
data.entity_type.value_map = self.reference_type.value_map
data.entity_type.color_map = self.reference_type.color_map

# TODO: Add the means of the transformed models
# means = self.pgi_regularization.gmm.means_
# for ii, phys_prop in enumerate(self.physical_properties):
# data.add_data_map(
# f"Mean {phys_prop}",
# {
# ind: f"{mean:.3e}"
# for ind, mean in zip(self.unit_map, means[:, ii])
# },
# )
3 changes: 2 additions & 1 deletion simpeg/directives/pgi_directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ def endIter(self):
self.pgi_reg.gmm = clfupdate
membership = self.pgi_reg.gmm.predict(model)

if self.fixed_membership is not None:
if clfupdate.fixed_membership is not None:
self.fixed_membership = clfupdate.fixed_membership
membership[self.fixed_membership[:, 0]] = self.fixed_membership[:, 1]

mref = mkvc(self.pgi_reg.gmm.means_[membership])
Expand Down