|
3 | 3 | from eva.utilities.utils import get_schema, update_object |
4 | 4 | import emcpy.plots.map_plots |
5 | 5 | import os |
| 6 | +import numpy as np |
6 | 7 |
|
7 | 8 | from eva.plotting.batch.base.diagnostics.map_gridded import MapGridded |
8 | 9 |
|
|
11 | 12 |
|
12 | 13 | class EmcpyMapGridded(MapGridded): |
13 | 14 | """ |
14 | | - EmcpyMapGridded class is a subclass of the MapGridded class, tailored for |
15 | | - configuring and plotting gridded map visualizations using the emcpy library. |
| 15 | + EMCPy backend for gridded maps. |
| 16 | + Option A: if latitude/longitude are 1-D centers, convert them to 2-D center grids |
| 17 | + with np.meshgrid and reduce data to a single 2-D level before plotting. |
| 18 | + """ |
16 | 19 |
|
17 | | - Attributes: |
18 | | - Inherits attributes from the MapGridded class. |
| 20 | + def _to_2d_centers(self, latvar, lonvar, datavar): |
| 21 | + """ |
| 22 | + Normalize inputs for EMCPy MapGridded by ensuring: |
| 23 | + - lat/lon are 2-D center grids, |
| 24 | + - data is 2-D (matching lat/lon). |
| 25 | + If lat/lon are already 2-D (curvilinear), just squeeze data to 2-D. |
| 26 | + """ |
| 27 | + lat = np.asarray(latvar) |
| 28 | + lon = np.asarray(lonvar) |
| 29 | + A = np.asarray(datavar) |
19 | 30 |
|
20 | | - Methods: |
21 | | - configure_plot(): Configures the plotting settings for the gridded map. |
22 | | - """ |
| 31 | + # If lat/lon are already 2-D curvilinear, just pick/squeeze one level if needed. |
| 32 | + if lat.ndim == 2 and lon.ndim == 2: |
| 33 | + if A.ndim == 3: |
| 34 | + # default to level 0 unless overridden in config |
| 35 | + lev_idx = int(self.config.get("level_index", 0)) |
| 36 | + # choose first axis as level by default |
| 37 | + A = np.squeeze(A[lev_idx, ...]) |
| 38 | + else: |
| 39 | + A = np.squeeze(A) |
| 40 | + return lat, lon, A |
| 41 | + |
| 42 | + # 1-D center coordinates path |
| 43 | + lat1d = lat.squeeze() |
| 44 | + lon1d = lon.squeeze() |
| 45 | + if lat1d.ndim != 1 or lon1d.ndim != 1: |
| 46 | + raise ValueError( |
| 47 | + f"Expected 1-D or 2-D lat/lon; got lat {lat.shape}, lon {lon.shape}" |
| 48 | + ) |
| 49 | + |
| 50 | + # Reduce data to 2-D (Nlat, Nlon) |
| 51 | + if A.ndim == 3: |
| 52 | + # Try to identify lat/lon axes by matching sizes |
| 53 | + shape = A.shape |
| 54 | + lat_axis = next((i for i, s in enumerate(shape) if s == lat1d.size), None) |
| 55 | + lon_axis = next((i for i, s in enumerate(shape) if s == lon1d.size), None) |
| 56 | + if lat_axis is not None and lon_axis is not None: |
| 57 | + order = [lat_axis, lon_axis] + [i for i in range(3) if i not in (lat_axis, lon_axis)] |
| 58 | + A = np.transpose(A, order) |
| 59 | + lev_idx = int(self.config.get("level_index", 0)) |
| 60 | + if A.ndim == 3: |
| 61 | + A = A[:, :, lev_idx] |
| 62 | + else: |
| 63 | + # Fallback: assume first dim is level |
| 64 | + lev_idx = int(self.config.get("level_index", 0)) |
| 65 | + A = A[lev_idx, ...] |
| 66 | + A = np.squeeze(A) |
| 67 | + |
| 68 | + # Ensure orientation (Nlat, Nlon) |
| 69 | + if A.shape != (lat1d.size, lon1d.size): |
| 70 | + if A.T.shape == (lat1d.size, lon1d.size): |
| 71 | + A = A.T |
| 72 | + else: |
| 73 | + raise ValueError( |
| 74 | + f"Data shape {A.shape} incompatible with lat {lat1d.size} / lon {lon1d.size}" |
| 75 | + ) |
| 76 | + |
| 77 | + # Make 2-D center grids |
| 78 | + LAT2D, LON2D = np.meshgrid(lat1d, lon1d, indexing="ij") |
| 79 | + return LAT2D, LON2D, A |
23 | 80 |
|
24 | 81 | def configure_plot(self): |
25 | 82 | """ |
26 | 83 | Configures the plotting settings for the gridded map. |
27 | | -
|
28 | 84 | Returns: |
29 | | - plotobj: The configured plot object for emcpy gridded maps. |
| 85 | + plotobj: The configured plot object for EMCPy gridded maps. |
30 | 86 | """ |
| 87 | + # Convert to 2-D centers + 2-D data if needed |
| 88 | + lat2d, lon2d, data2d = self._to_2d_centers(self.latvar, self.lonvar, self.datavar) |
| 89 | + |
| 90 | + # Create EMCPy MapGridded object |
| 91 | + self.plotobj = emcpy.plots.map_plots.MapGridded(lat2d, lon2d, data2d) |
31 | 92 |
|
32 | | - # create declarative plotting MapGridded object |
33 | | - self.plotobj = emcpy.plots.map_plots.MapGridded(self.latvar, self.lonvar, self.datavar) |
34 | | - # get defaults from schema |
35 | | - layer_schema = self.config.get('schema', os.path.join(return_eva_path(), 'plotting', |
36 | | - 'batch', 'emcpy', 'defaults', 'map_gridded.yaml')) |
| 93 | + # Apply schema defaults/overrides |
| 94 | + layer_schema = self.config.get( |
| 95 | + "schema", |
| 96 | + os.path.join( |
| 97 | + return_eva_path(), "plotting", "batch", "emcpy", "defaults", "map_gridded.yaml" |
| 98 | + ), |
| 99 | + ) |
37 | 100 | new_config = get_schema(layer_schema, self.config, self.logger) |
38 | | - delvars = ['longitude', 'latitude', 'data', 'type', 'schema'] |
39 | | - for d in delvars: |
| 101 | + for d in ["longitude", "latitude", "data", "type", "schema", "level_index"]: |
40 | 102 | new_config.pop(d, None) |
41 | 103 | self.plotobj = update_object(self.plotobj, new_config, self.logger) |
42 | 104 | return self.plotobj |
|
0 commit comments