Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions simpeg/directives/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@
SaveLogFilesGeoH5,
SaveLPModelGroup,
SaveModelGeoH5,
SavePGIModel,
SavePropertyGroup,
SaveSensitivityGeoH5,
)
Expand Down
19 changes: 14 additions & 5 deletions simpeg/directives/_regularization.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,10 @@ def endIter(self):
if not isinstance(reg, Sparse):
continue

for obj in reg.objfcts:
for multi, obj in reg:
if multi == 0:
continue

obj.irls_threshold /= self.irls_cooling_factor

self.metrics.irls_iteration_count += 1
Expand Down Expand Up @@ -327,10 +330,16 @@ def start_irls(self):
if not isinstance(reg, Sparse):
continue

for obj in reg.objfcts:
threshold = np.percentile(
np.abs(obj.mapping * obj._delta_m(self.invProb.model)),
self.percentile,
for multi, obj in reg:
if multi == 0:
continue

threshold = (
np.percentile(
np.abs(obj.mapping * obj._delta_m(self.invProb.model)),
self.percentile,
)
+ 1e-16
)
if isinstance(obj, SmoothnessFirstOrder):
threshold /= reg.regularization_mesh.base_length
Expand Down
68 changes: 68 additions & 0 deletions simpeg/directives/_save_geoh5.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@

import numpy as np
from scipy.sparse import csc_matrix, csr_matrix
from simpeg.regularization import PGIsmallness

from .directives import InversionDirective
from simpeg.maps import IdentityMap

from geoh5py.data import NumericData
from geoh5py.data.data_type import ReferencedValueMapType
from geoh5py.groups.property_group import GroupTypeEnum
from geoh5py.groups import UIJsonGroup
from geoh5py.objects import ObjectBase
Expand Down Expand Up @@ -503,3 +506,68 @@ def get_names(
base_name = "LP models"

return channel_name, base_name


class SavePGIModel(SaveArrayGeoH5):
"""
Save the model as a property group in the geoh5 file
"""

def __init__(
self,
h5_object,
Comment thread
domfournier marked this conversation as resolved.
Outdated
pgi_reg: PGIsmallness,
Comment thread
domfournier marked this conversation as resolved.
Outdated
unit_map: dict,
physical_properties: list[str],
reference_type: ReferencedValueMapType | None = None,
**kwargs,
):
self.pgi_reg = pgi_reg
Comment thread
domfournier marked this conversation as resolved.
Outdated
self.unit_map: dict = unit_map
self.reference_type = reference_type
self.physical_properties = physical_properties
super().__init__(h5_object, **kwargs)

def get_values(self, values: list[np.ndarray] | None):

if values is None:
values = self.invProb.model

modellist = self.pgi_reg.wiresmap * values
model = np.c_[[a * b for a, b in zip(self.pgi_reg.maplist, modellist)]].T
Comment thread
benk-mira marked this conversation as resolved.
Outdated
membership = self.pgi_reg.gmm._estimate_log_prob(model).argmax(axis=1)
return membership

def write(self, iteration: int, values: list[np.ndarray] = None):
Comment thread
domfournier marked this conversation as resolved.
Outdated
"""
Method to write the reference model with data map.
"""
petro_model = self.get_values(values)
petro_model = self.apply_transformations(petro_model).flatten()
channel_name, base_name = self.get_names("petrophysics", "", iteration)
with fetch_active_workspace(self._geoh5, mode="r+") as w_s:
h5_object = w_s.get_entity(self.h5_object)[0]
data = h5_object.add_data(
{
channel_name: {
"association": self.association,
"values": petro_model,
"type": "REFERENCED",
Comment thread
domfournier marked this conversation as resolved.
Outdated
}
}
)

if self.reference_type is not None:
data.entity_type.value_map = self.reference_type.value_map
data.entity_type.color_map = self.reference_type.color_map

# TODO: Add the means of the transformed models
# means = self.pgi_reg.gmm.means_
# for ii, phys_prop in enumerate(self.physical_properties):
# data.add_data_map(
# f"Mean {phys_prop}",
# {
# ind: f"{mean:.3e}"
# for ind, mean in zip(self.unit_map, means[:, ii])
# },
# )
3 changes: 2 additions & 1 deletion simpeg/directives/pgi_directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ def endIter(self):
self.pgi_reg.gmm = clfupdate
membership = self.pgi_reg.gmm.predict(model)

if self.fixed_membership is not None:
if clfupdate.fixed_membership is not None:
self.fixed_membership = clfupdate.fixed_membership
membership[self.fixed_membership[:, 0]] = self.fixed_membership[:, 1]

mref = mkvc(self.pgi_reg.gmm.means_[membership])
Expand Down