@@ -27,7 +27,7 @@ def _to_2d_centers(self, latvar, lonvar, datavar):
2727 """
2828 lat = np .asarray (latvar )
2929 lon = np .asarray (lonvar )
30- A = np .asarray (datavar )
30+ A = np .asarray (datavar )
3131
3232 # --- CASE 1: 3-D curvilinear (lon, lat, tile) → choose one tile to get 2-D ---
3333 if lat .ndim == 3 and lon .ndim == 3 and lat .shape == lon .shape :
@@ -72,7 +72,8 @@ def _to_2d_centers(self, latvar, lonvar, datavar):
7272 if A2 .T .shape == lat2d .shape :
7373 A2 = A2 .T
7474 else :
75- raise ValueError (f"Data shape { A2 .shape } incompatible with tile lat/lon { lat2d .shape } " )
75+ raise ValueError (f"Data shape { A2 .shape } incompatible "
76+ f"with tile lat/lon { lat2d .shape } " )
7677
7778 return lat2d , lon2d , A2
7879
@@ -96,7 +97,12 @@ def _to_2d_centers(self, latvar, lonvar, datavar):
9697 lat_axis = next ((i for i , s in enumerate (shape ) if s == lat1d .size ), None )
9798 lon_axis = next ((i for i , s in enumerate (shape ) if s == lon1d .size ), None )
9899 if lat_axis is not None and lon_axis is not None :
99- order = [lat_axis , lon_axis ] + [i for i in range (3 ) if i not in (lat_axis , lon_axis )]
100+ axes = (lat_axis , lon_axis )
101+ extra = [
102+ i for i in range (A2 .ndim )
103+ if i not in axes
104+ ]
105+ order = [* axes , * extra ]
100106 A2 = np .transpose (A2 , order )
101107 lev_idx = int (self .config .get ("level_index" , 0 ))
102108 if A2 .ndim == 3 :
@@ -112,7 +118,8 @@ def _to_2d_centers(self, latvar, lonvar, datavar):
112118 A2 = A2 .T
113119 else :
114120 raise ValueError (
115- f"Data shape { A2 .shape } incompatible with lat { lat1d .size } / lon { lon1d .size } "
121+ f"Data shape { A2 .shape } incompatible with "
122+ f"lat { lat1d .size } / lon { lon1d .size } "
116123 )
117124
118125 LAT2D , LON2D = np .meshgrid (lat1d , lon1d , indexing = "ij" )
0 commit comments