Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 20 additions & 19 deletions sup3r/models/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from sup3r.utilities import VERSION_RECORD
from sup3r.utilities.utilities import safe_cast

from .utilities import SUP3R_EXO_LAYERS, SUP3R_OBS_LAYERS
from .utilities import SUP3R_LAYERS

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -355,6 +355,20 @@ def _combine_fwp_output(self, hi_res, exogenous_data=None):
hi_res = np.concatenate((hi_res, exo_output), axis=-1)
return hi_res

def _get_layer_features(self):
"""Get the list of features used in the model based on layer
attributes. This is used to check that the features provided in
exogenous_data match the features expected by the model
architecture."""
features = []
if hasattr(self, '_gen'):
for layer in self._gen.layers:
if isinstance(layer, SUP3R_LAYERS):
layer_feats = getattr(layer, 'features', [layer.name])
layer_feats = [f for f in layer_feats if f not in features]
features.extend(layer_feats)
return features

@property
@abstractmethod
def meta(self):
Expand All @@ -377,16 +391,9 @@ def hr_out_features(self):
@property
def obs_features(self):
"""Get list of exogenous observation feature names the model uses.
These come from the names of the ``Sup3rObs..`` layers."""
# pylint: disable=E1101
features = []
if hasattr(self, '_gen'):
for layer in self._gen.layers:
if isinstance(layer, SUP3R_OBS_LAYERS):
obs_feats = getattr(layer, 'features', [layer.name])
obs_feats = [f for f in obs_feats if f not in features]
features.extend(obs_feats)
return features
These are the features with an '_obs' suffix"""
features = self._get_layer_features()
return [f for f in features if '_obs' in f]

@property
def hr_exo_features(self):
Expand All @@ -397,14 +404,8 @@ def hr_exo_features(self):
[..., topo, sza], and the model has 2 concat or add layers, exo
features will be [topo, sza]. Topo will then be used in the first
concat layer and sza will be used in the second"""
# pylint: disable=E1101
features = []
if hasattr(self, '_gen'):
features = [
layer.name
for layer in self._gen.layers
if isinstance(layer, SUP3R_EXO_LAYERS)
]
features = self._get_layer_features()
features = [f for f in features if '_obs' not in f]
obs_feats = [feat.replace('_obs', '') for feat in self.obs_features]
features += [f for f in obs_feats if f not in self.hr_out_features]
return features
Expand Down
17 changes: 12 additions & 5 deletions sup3r/models/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Sup3rAdder,
Sup3rConcat,
Sup3rConcatObs,
Sup3rCrossAttention,
Sup3rObsModel,
)
from scipy.interpolate import RegularGridInterpolator
Expand All @@ -20,11 +21,17 @@

logger = logging.getLogger(__name__)

SUP3R_OBS_LAYERS = Sup3rObsModel, Sup3rConcatObs

SUP3R_EXO_LAYERS = Sup3rAdder, Sup3rConcat

SUP3R_LAYERS = (*SUP3R_EXO_LAYERS, *SUP3R_OBS_LAYERS)
# These are special layers that are used to injest exogenous data and
# observations. They are checked for feature or name attributes to
# determine what features the model uses as exogenous inputs and what
# features the model uses as observation inputs.
SUP3R_LAYERS = (
Sup3rObsModel,
Sup3rConcatObs,
Sup3rCrossAttention,
Sup3rAdder,
Sup3rConcat,
)


class TrainingSession:
Expand Down
Loading