Gen2 Trainer (and other downstream functions) does not support `levels: null` for selecting all levels, despite `LocalDataset` supporting it
Consider setting the `levels` parameter to `null` in the config, as in the following:

    data:
      source:
        CAM6:
          dataset_type: "local"
          level_coord: "lev"
          levels: null
`LocalDataset` correctly interprets this as selecting all available levels, in this snippet (`miles-credit/credit/datasets/local.py`, lines 181 to 186 at commit 595cf70):

    for vname in vars_3D:
        if self.levels is None:
            arr = ds_t[vname].values
            if self.level_coord in ds_t.coords:
                self.levels = ds_t[self.level_coord].values.tolist()
                self.static_metadata["levels"] = self.levels
However, the Gen2 Trainer raises an error at the line below (`miles-credit/credit/trainers/trainer_gen2.py`, line 57 at commit 595cf70). Because the `levels` key exists (even though its value is `null`), `source.get("levels", [])` returns `None` instead of the `[]` default, and `len(None)` raises a `TypeError`:

    num_levels = len(source.get("levels", []))
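The failure comes down to `dict.get` semantics: the default is used only when the key is missing, not when it maps to `None`. Below is a minimal sketch in plain Python, assuming (as described above) that the YAML loader turns `levels: null` into `None` and that `LocalDataset` auto-discovers levels whenever the value it receives is `None`. The `sources` dict and both helper functions are hypothetical stand-ins for a parsed config section, the trainer line, and the dataset check:

```python
# Hypothetical stand-ins for a parsed `data.source.<name>` config section.
sources = {
    "omitted": {"level_coord": "lev"},
    "null": {"level_coord": "lev", "levels": None},  # YAML `null` is loaded as None
    "empty": {"level_coord": "lev", "levels": []},
}


def trainer_num_levels(source: dict):
    """Mirror the Gen2 Trainer expression; return the exception name if it fails."""
    try:
        return len(source.get("levels", []))  # default applies only if the key is missing
    except TypeError as exc:
        return type(exc).__name__


def local_dataset_discovers_all(source: dict) -> bool:
    """Assumed LocalDataset behavior: discover all levels only when levels is None."""
    return source.get("levels") is None


for name, src in sources.items():
    print(name, trainer_num_levels(src), local_dataset_discovers_all(src))
```

The trainer either crashes (`null`) or reports 0 levels while `LocalDataset` loads all of them (omitted); the only case where both sides agree (`[]`) yields no vertical levels at all.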
If you instead pass `levels: []`, `LocalDataset` interprets it as requesting 0 levels, since `[]` is not `None`. The Gen2 Trainer also computes 0 levels, so it does not crash here, but it concludes there are no vertical levels, which causes crashes later in the pipeline.
Finally, if you omit the `levels` key entirely, `LocalDataset` correctly loads all the vertical levels; however, the trainer again falls back to the `[]` default and concludes there are 0 vertical levels. That incorrect count feeds into the line below (`miles-credit/credit/trainers/trainer_gen2.py`, line 58 at commit 95861e7) and into many other functions, leading to buggy behavior:

    self.varnum_diag = (len(diag.get("vars_3D", [])) * num_levels + len(diag.get("vars_2D", []))) if diag else 0
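To make the impact concrete, here is the same `varnum_diag` expression evaluated with a hypothetical diagnostic section (four 3D and three 2D variables). The variable names and the 32-level count are illustrative assumptions, not values from the config above:

```python
# Hypothetical diagnostic section: 4 three-dimensional and 3 two-dimensional variables.
diag = {"vars_3D": ["Q", "T", "U", "V"], "vars_2D": ["PS", "TS", "PRECT"]}


def varnum_diag(diag: dict, num_levels: int) -> int:
    # Same expression as the trainer: each 3D variable contributes one channel per level.
    return (len(diag.get("vars_3D", [])) * num_levels + len(diag.get("vars_2D", []))) if diag else 0


print(varnum_diag(diag, 0))   # levels omitted, so num_levels == 0: 3D variables vanish
print(varnum_diag(diag, 32))  # count if the 32 loaded levels were used
```

With the level count stuck at 0, the trainer would expect 3 diagnostic channels while the data carries 131 (4 × 32 + 3).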
The following functions are among those affected by the incorrect vertical level count:
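`preflight.py` (`miles-credit/credit/trainers/preflight.py`, lines 53 to 85 at commit 95861e7; see lines 62-66 and 84-85):
If `levels: null`, the `TypeError` from `len(None)` is caught by the broad `except Exception`, so the memory estimate silently comes back as 0. If `levels: []` or the key is omitted, memory is estimated as if there were no vertical levels.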

    try:
        trainer_conf = conf.get("trainer", {})
        data_conf = conf.get("data", {})
        model_conf = conf.get("model", {})
        src = next(iter(data_conf.get("source", {}).values()), {})
        v = src.get("variables", {})
        prog = v.get("prognostic") or {}
        diag = v.get("diagnostic") or {}

        n_levels = len(src.get("levels", []))
        n_vars_3d = len(prog.get("vars_3D", []))
        n_vars_2d = len(prog.get("vars_2D", []))
        n_diag_2d = len(diag.get("vars_2D", []))
        total_ch = n_vars_3d * n_levels + n_vars_2d + n_diag_2d

        if total_ch == 0:
            return 0.0

        H = model_conf.get("image_height", 721)
        W = model_conf.get("image_width", 1440)

        bytes_per_sample = H * W * total_ch * 4  # float32 4 bytes
        bytes_per_sample *= 2  # input + target

        workers = trainer_conf.get("thread_workers", 4)
        prefetch = trainer_conf.get("prefetch_factor", 4)
        batch_size = trainer_conf.get("train_batch_size", 1)

        total_bytes = workers * prefetch * batch_size * bytes_per_sample
        return total_bytes / 2**30

    except Exception:
        return 0.0
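The silent-failure path can be reproduced with a condensed copy of the estimate above. This is a simplified sketch: the config lookups are flattened into function arguments (defaults copied from the snippet), and the `variables` section and the 32-level list are hypothetical:

```python
def estimate_gib(src: dict, H: int = 721, W: int = 1440,
                 workers: int = 4, prefetch: int = 4, batch_size: int = 1) -> float:
    """Condensed preflight memory estimate in GiB (simplified sketch)."""
    try:
        v = src.get("variables", {})
        prog = v.get("prognostic") or {}
        diag = v.get("diagnostic") or {}
        n_levels = len(src.get("levels", []))  # raises TypeError when levels is None
        total_ch = (len(prog.get("vars_3D", [])) * n_levels
                    + len(prog.get("vars_2D", []))
                    + len(diag.get("vars_2D", [])))
        if total_ch == 0:
            return 0.0
        bytes_per_sample = H * W * total_ch * 4 * 2  # float32 (4 bytes), input + target
        return workers * prefetch * batch_size * bytes_per_sample / 2**30
    except Exception:  # swallows the TypeError, so the caller only sees 0.0
        return 0.0


# Hypothetical variables: one 3D and one 2D prognostic variable.
variables = {"prognostic": {"vars_3D": ["T"], "vars_2D": ["PS"]}}
cases = {
    "omitted": {"variables": variables},
    "null": {"variables": variables, "levels": None},
    "32 levels": {"variables": variables, "levels": list(range(32))},
}
for name, src in cases.items():
    print(f"{name}: {estimate_gib(src):.3f} GiB")
```

In this example the `null` case reports 0.0 GiB, indistinguishable from an empty configuration, and omitting `levels` under-reports by a factor of 33 (1 channel counted instead of 33).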
`channel_layout.py` (`miles-credit/credit/datasets/channel_layout.py`, lines 61 to 62 at commit 95861e7):
If `levels: null`, `len(None)` raises a `TypeError`. If `levels: []`, all the channel slices will be empty. If `levels` is omitted, a `KeyError` is raised.

    src_conf = next(iter(conf["data"]["source"].values()))
    n_levels = len(src_conf["levels"])
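To illustrate the `levels: []` case, here is a simplified, hypothetical channel layout (not the real `build_channel_layout`, whose internals are not shown here) in which each 3D variable occupies `n_levels` consecutive channels. With a level count of 0, every 3D slice collapses to zero width:

```python
def build_slices(vars_3d, vars_2d, n_levels):
    """Hypothetical layout: 3D vars get n_levels channels each, 2D vars get one."""
    slices, start = {}, 0
    for name in vars_3d:
        slices[name] = slice(start, start + n_levels)
        start += n_levels
    for name in vars_2d:
        slices[name] = slice(start, start + 1)
        start += 1
    return slices


for n in (0, 32):
    layout = build_slices(["T", "Q"], ["PS"], n)
    widths = {name: s.stop - s.start for name, s in layout.items()}
    print(n, widths)
```

In this simplified model the 2D variable still receives a channel, so nothing errors at layout time even though the 3D variables have no channels.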
`rollout_to_netcdf_gen2.py` (`miles-credit/credit/applications/rollout_to_netcdf_gen2.py` at commit 95861e7), line 72:
If `levels: null`, `level_ids` will be `None` where a list is expected. If `levels: []`, `level_ids` will be an empty list. If `levels` is omitted, `level_ids` becomes a list of integer indices whose length comes from a different part of the config (`conf["model"]["levels"]`).

    conf["data"]["level_ids"] = src.get("levels", list(range(conf["model"]["levels"])))
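The three outcomes of the `level_ids` fallback can be checked directly. A minimal sketch, assuming a hypothetical `model.levels` value of 32; `level_ids_for` is a hypothetical wrapper around the line 72 expression:

```python
conf = {"model": {"levels": 32}}  # hypothetical model section


def level_ids_for(src: dict, conf: dict):
    # Same expression as line 72: the range() default applies only when the key is absent.
    return src.get("levels", list(range(conf["model"]["levels"])))


cases = {"omitted": {}, "null": {"levels": None}, "empty": {"levels": []}}
for name, src in cases.items():
    ids = level_ids_for(src, conf)
    print(name, None if ids is None else (len(ids), ids[:3]))
```

Only the omitted case produces a non-empty list, and it holds positional indices rather than the level coordinate values that `LocalDataset` discovers.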
Line 84 has the same `len(None)` or `n_levels = 0` problem as the Gen2 Trainer code:

    n_levels = len(src.get("levels", []))
Lines 111 to 113: if `levels: null`, `len(None)` raises a `TypeError`. If `levels: []`, `n_levels` will incorrectly be 0. If `levels` is omitted, a `KeyError` is raised.

    levels = src["levels"]
    level_coord = src["level_coord"]
    n_levels = len(levels)
Lines 126 to 128: the error from lines 111-113 propagates here. This code also needs the actual list of level coordinate values (not just a count) to pass to `.sel`, which may be a problem when the levels are floats.

    if is_3d:
        m = torch.tensor(mean_ds[varname].sel({level_coord: levels}).values, dtype=dtype)
        s = torch.tensor(std_ds[varname].sel({level_coord: levels}).values, dtype=dtype)
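On the float concern: label-based selection without a `method` needs each requested value to match a coordinate value exactly, and decimal values written by hand do not always equal computed ones (for example, `0.1 + 0.2 != 0.3` for IEEE 754 doubles). One possible workaround, sketched here with a hypothetical `level_positions` helper, is to map requested levels to positions within a tolerance and then select by position (xarray's `.isel`); xarray's `.sel` also accepts `method="nearest"` together with a `tolerance`:

```python
import math
from typing import List, Sequence


def level_positions(coord_values: Sequence[float], requested: Sequence[float],
                    rel_tol: float = 1e-6) -> List[int]:
    """Return the position of each requested level in coord_values, within rel_tol."""
    positions = []
    for target in requested:
        matches = [i for i, value in enumerate(coord_values)
                   if math.isclose(value, target, rel_tol=rel_tol)]
        if not matches:
            raise KeyError(f"level {target!r} not found (rel_tol={rel_tol})")
        positions.append(matches[0])
    return positions


# Illustrative values: a computed coordinate vs. the same level typed into a config.
coord = [0.1 + 0.2, 992.5]
requested = [0.3, 992.5]
print(coord[0] == requested[0])           # exact comparison fails
print(level_positions(coord, requested))  # tolerance-based lookup succeeds
```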
Lines 214 to 215 have the same issues as line 84:

    n_levels = len(src["levels"])
    varnum_diag = len(diag.get("vars_2D", [])) + len(diag.get("vars_3D", [])) * n_levels
`rollout_realtime_gen2.py` (`miles-credit/credit/applications/rollout_realtime_gen2.py` at commit 95861e7). These are very similar to the previous examples, so full explanations are omitted.

Line 74:

    conf["data"]["level_ids"] = src.get("levels", list(range(conf["model"]["levels"])))

Line 85:

    n_levels = len(src.get("levels", []))

Lines 106 to 108:

    levels = src["levels"]
    level_coord = src["level_coord"]
    n_levels = len(levels)

Lines 121 to 123:

    if is_3d:
        m = torch.tensor(mean_ds[varname].sel({level_coord: levels}).values, dtype=dtype)
        s = torch.tensor(std_ds[varname].sel({level_coord: levels}).values, dtype=dtype)

Line 266 calls `build_channel_layout`, so it inherits the `channel_layout.py` errors explained above:

    slices, _ = build_channel_layout(conf)
`_plot.py` (`miles-credit/credit/cli/_plot.py` at commit 95861e7):

Line 15:

    n_levels = len(src.get("levels", []))

Lines 40 to 42:

    levels = src["levels"]
    level_coord = src["level_coord"]
    n_levels = len(levels)

Lines 64 to 66:

    if is_3d:
        m = mean_ds[varname].sel({level_coord: levels}).values.astype(np.float32)
        s = std_ds[varname].sel({level_coord: levels}).values.astype(np.float32)
Possible Solutions
The issue boils down to how each way of reading `levels` behaves for each way of specifying it in the config:

| `levels` in config | `get("levels")` | `get("levels", [])` | `["levels"]` |
|---|---|---|---|
| omitted | `None` | `[]` | `KeyError` |
| `null` | `None` | `None` | `None` |
| `[]` | `[]` | `[]` | `[]` |
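The table can be regenerated with a few lines of plain Python, which may be handy as a regression test for whichever convention is adopted; `access_table` is a hypothetical helper:

```python
def access_table() -> dict:
    """Evaluate the three access patterns against the three config variants."""
    configs = {"omitted": {}, "null": {"levels": None}, "[]": {"levels": []}}
    table = {}
    for label, src in configs.items():
        row = []
        for access in (lambda s: s.get("levels"),
                       lambda s: s.get("levels", []),
                       lambda s: s["levels"]):
            try:
                row.append(repr(access(src)))
            except KeyError:
                row.append("KeyError")
        table[label] = row
    return table


for label, row in access_table().items():
    print(f"{label:8s}", row)
```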
The most intuitive option would be to make `levels: null` the official way of selecting all levels, since it behaves consistently across all three access patterns (each returns `None`). In addition, the `model` section of the config already includes the number of levels, which could be used instead of `len(cfg['levels'])`. Helper functions like the following may also prove useful:

    def resolve_num_levels(source: dict, conf: dict) -> int:
        levels = source.get("levels")
        if levels is not None:
            return len(levels)
        return int(conf.get("model", {}).get("levels", 0))

    def resolve_level_ids(source: dict, conf: dict) -> list:
        levels = source.get("levels")
        if levels is not None:
            return levels
        return list(range(conf["model"]["levels"]))  # or dataset.static_metadata["levels"]
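One possible extension of these helpers is sketched below. It assumes the levels discovered by `LocalDataset` (its `static_metadata["levels"]`) can be passed in, prefers explicit config values, then discovered levels, then `model.levels`, and raises instead of silently returning 0. The `resolve_levels` name and the `discovered` parameter are hypothetical:

```python
from typing import Optional, Sequence


def resolve_levels(source: dict, conf: dict,
                   discovered: Optional[Sequence] = None) -> list:
    """Hypothetical resolver: treat an omitted `levels` key and `levels: null` alike."""
    levels = source.get("levels")  # None for both `levels: null` and a missing key
    if levels is not None:
        return list(levels)  # explicit list from the config (may be empty)
    if discovered is not None:
        return list(discovered)  # e.g. coordinate values found by LocalDataset
    n = conf.get("model", {}).get("levels")
    if n is None:
        raise ValueError(
            "levels not set in data.source, not discovered from the dataset, "
            "and model.levels is missing"
        )
    return list(range(int(n)))  # positional fallback, like resolve_level_ids


conf = {"model": {"levels": 32}}
print(len(resolve_levels({"levels": None}, conf)))        # falls back to model.levels
print(resolve_levels({}, conf, discovered=[3.64, 7.59]))  # discovered values take priority
```

An explicit `levels: []` still resolves to an empty list here; whether that should also be rejected is a separate design choice.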
Concluding Remarks
There may be other fixes, or existing handling in the codebase that I missed, but I looked around and could not find any. This is largely a blocking issue for anyone who wants to use a `LocalDataset` without listing explicit pressure levels, which is often the case when the levels are float values rather than integers. For easy debugging and testing, the credit-chem repo (found here) contains the code where I ran into this issue: run `uv run python scripts/train.py -c config/credit-chem-cubed-v0_1.yaml` locally, or `bash bash/submit_train.sh -c config/credit-chem-cubed-v0_1.yaml` to submit a job.