
Gen2 Trainer (and other derivative functions) do not support levels: null for selecting all levels, despite LocalDataset supporting it #432

Description

@lj-dunphy


Consider a config where the levels parameter is set to null, for example:

data:
  source:
    CAM6:
      dataset_type: "local"
      level_coord: "lev"
      levels: null

LocalDataset correctly interprets this as selecting all available levels, in this snippet:

for vname in vars_3D:
    if self.levels is None:
        arr = ds_t[vname].values
        if self.level_coord in ds_t.coords:
            self.levels = ds_t[self.level_coord].values.tolist()
            self.static_metadata["levels"] = self.levels

However, the Gen2 Trainer raises an error at the following line. Because the levels key is present (with value null), source.get("levels", []) returns None instead of the [] default, and len(None) then raises a TypeError:

num_levels = len(source.get("levels", []))
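The root cause is ordinary dict.get semantics: the default is used only when the key is absent, not when it maps to None. A minimal sketch, using plain dicts as hypothetical stand-ins for the parsed data.source.CAM6 section:

```python
# Hypothetical stand-ins for the parsed data.source.<name> config section.
explicit_null = {"level_coord": "lev", "levels": None}  # levels: null
omitted = {"level_coord": "lev"}                         # levels key absent

# The [] default applies only when the key is missing.
print(omitted.get("levels", []))        # []
print(explicit_null.get("levels", []))  # None: the key exists, so the default is skipped

# Reproduces the trainer failure on the line above.
try:
    len(explicit_null.get("levels", []))
except TypeError as exc:
    print(f"TypeError: {exc}")  # TypeError: object of type 'NoneType' has no len()
```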

If you instead pass levels: [], both LocalDataset and the Gen2 Trainer interpret it as 0 levels (since [] is not None). Nothing crashes at this point, but the trainer believes there are no vertical levels, which causes crashes later in the pipeline.

Finally, if you omit the levels: parameter entirely, LocalDataset correctly loads all vertical levels; however, the trainer falls back to the [] default and concludes there are 0 vertical levels (as shown below). That incorrect value is passed on to many other functions, leading to buggy behavior:

self.varnum_diag = (len(diag.get("vars_3D", [])) * num_levels + len(diag.get("vars_2D", []))) if diag else 0
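To make the undercount concrete, here is the same formula evaluated with a hypothetical set of diagnostics (two 3D fields, one 2D field) and an illustrative 32 levels on disk. The variable names and level count are assumptions for the example only:

```python
def varnum_diag(diag: dict, num_levels: int) -> int:
    # Same expression as the Gen2 Trainer line above.
    return (len(diag.get("vars_3D", [])) * num_levels + len(diag.get("vars_2D", []))) if diag else 0

diag = {"vars_3D": ["T_tend", "Q_tend"], "vars_2D": ["PRECT"]}  # hypothetical diagnostics

print(varnum_diag(diag, 32))  # 65: 2 * 32 + 1 with the levels actually on disk
print(varnum_diag(diag, 0))   # 1: levels omitted, so num_levels fell back to 0
```

With levels omitted, every 3D diagnostic silently contributes zero channels.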

The following files are among those affected by the incorrect vertical-level count:

preflight.py (see lines 62-66, 84-85):

If levels: null, the TypeError from len(None) is swallowed by the except clause and the memory estimate silently returns 0.0. If levels: [] or levels is omitted, memory is estimated as if there were no vertical levels.

try:
    trainer_conf = conf.get("trainer", {})
    data_conf = conf.get("data", {})
    model_conf = conf.get("model", {})
    src = next(iter(data_conf.get("source", {}).values()), {})
    v = src.get("variables", {})
    prog = v.get("prognostic") or {}
    diag = v.get("diagnostic") or {}
    n_levels = len(src.get("levels", []))
    n_vars_3d = len(prog.get("vars_3D", []))
    n_vars_2d = len(prog.get("vars_2D", []))
    n_diag_2d = len(diag.get("vars_2D", []))
    total_ch = n_vars_3d * n_levels + n_vars_2d + n_diag_2d
    if total_ch == 0:
        return 0.0
    H = model_conf.get("image_height", 721)
    W = model_conf.get("image_width", 1440)
    bytes_per_sample = H * W * total_ch * 4  # float32 4 bytes
    bytes_per_sample *= 2  # input + target
    workers = trainer_conf.get("thread_workers", 4)
    prefetch = trainer_conf.get("prefetch_factor", 4)
    batch_size = trainer_conf.get("train_batch_size", 1)
    total_bytes = workers * prefetch * batch_size * bytes_per_sample
    return total_bytes / 2**30
except Exception:
    return 0.0
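A trimmed copy of the estimate above (wrapped in a hypothetical estimate_gib function, with the same defaults) shows all three outcomes side by side. The variable lists and the 32-level count are illustrative assumptions:

```python
def estimate_gib(src, H=721, W=1440, workers=4, prefetch=4, batch_size=1):
    """Trimmed copy of the preflight memory estimate (defaults copied from it)."""
    try:
        v = src.get("variables", {})
        prog = v.get("prognostic") or {}
        diag = v.get("diagnostic") or {}
        n_levels = len(src.get("levels", []))
        total_ch = (len(prog.get("vars_3D", [])) * n_levels
                    + len(prog.get("vars_2D", []))
                    + len(diag.get("vars_2D", [])))
        if total_ch == 0:
            return 0.0
        bytes_per_sample = H * W * total_ch * 4 * 2  # float32, input + target
        return workers * prefetch * batch_size * bytes_per_sample / 2**30
    except Exception:
        return 0.0  # the len(None) TypeError is swallowed here

# Hypothetical variable lists; 32 levels is an illustrative count.
variables = {"prognostic": {"vars_3D": ["T", "U", "V", "Q"], "vars_2D": ["PS"]}}
full = estimate_gib({"variables": variables, "levels": list(range(32))})
null = estimate_gib({"variables": variables, "levels": None})
omitted = estimate_gib({"variables": variables})

print(f"{full:.2f} GiB")  # 15.97 GiB for 4 * 32 + 1 = 129 channels
print(null)               # 0.0: silent failure
print(omitted)            # counts only the single 2D channel
```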

channel_layout.py:

If levels: null, len(None) raises a TypeError. If levels: [], all the channel slices will be empty. If levels is omitted, a KeyError is raised.

src_conf = next(iter(conf["data"]["source"].values()))
n_levels = len(src_conf["levels"])
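The three outcomes can be reproduced with the same access pattern on minimal configs. The make_conf and n_levels_strict helpers are hypothetical, written only for this illustration:

```python
def n_levels_strict(conf: dict) -> int:
    # Same access pattern as channel_layout.py above.
    src_conf = next(iter(conf["data"]["source"].values()))
    return len(src_conf["levels"])

def make_conf(**src) -> dict:
    # Hypothetical minimal config with a single CAM6 source.
    return {"data": {"source": {"CAM6": {"level_coord": "lev", **src}}}}

cases = {"null": make_conf(levels=None), "empty": make_conf(levels=[]), "omitted": make_conf()}
for name, conf in cases.items():
    try:
        print(name, "->", n_levels_strict(conf))
    except (TypeError, KeyError) as exc:
        print(name, "->", type(exc).__name__)
# null -> TypeError
# empty -> 0
# omitted -> KeyError
```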

rollout_to_netcdf_gen2.py

If levels: null, level_ids will be None where it should be a list. If levels: [], level_ids will be an empty list. If levels is omitted, level_ids falls back to integer indices from range(conf["model"]["levels"]), so the count comes from the model section of the config rather than from the data source.

conf["data"]["level_ids"] = src.get("levels", list(range(conf["model"]["levels"])))
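Evaluating that expression directly shows the three behaviors. The level_ids wrapper and the 4-level model section are hypothetical, chosen to keep the output short:

```python
def level_ids(src: dict, conf: dict):
    # Same expression as rollout_to_netcdf_gen2.py above.
    return src.get("levels", list(range(conf["model"]["levels"])))

conf = {"model": {"levels": 4}}  # hypothetical model section

print(level_ids({"levels": None}, conf))            # None
print(level_ids({"levels": []}, conf))              # []
print(level_ids({}, conf))                          # [0, 1, 2, 3]
print(level_ids({"levels": [100.0, 250.5]}, conf))  # [100.0, 250.5]
```

Note that the omitted case yields positional indices, while an explicit list yields coordinate values, so the two are not interchangeable downstream.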

Same len(None) or n_levels = 0 problem as in the Gen2 Trainer code:
n_levels = len(src.get("levels", []))

If levels: null, len(None) raises a TypeError. If levels: [], n_levels will incorrectly be 0. If levels is omitted, src["levels"] raises a KeyError.
levels = src["levels"]
level_coord = src["level_coord"]
n_levels = len(levels)

The error propagates here from the previous lines 111-113. This code also requires the actual list of level coordinate values (which may be a problem when the levels are floats):
if is_3d:
    m = torch.tensor(mean_ds[varname].sel({level_coord: levels}).values, dtype=dtype)
    s = torch.tensor(std_ds[varname].sel({level_coord: levels}).values, dtype=dtype)
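On the float question: xarray's .sel matches labels exactly by default (method=None), so a float level list typed into the config with fewer digits than the stored coordinate will not match. A numpy-only sketch of the effect, with hypothetical coordinate values; xarray's method="nearest" plus tolerance arguments to .sel are one possible mitigation, not something the current code uses:

```python
import numpy as np

# Hypothetical float-valued level coordinate as stored in the dataset.
lev_coord = np.array([3.64346569404006, 7.59481964632869, 14.3566322512925])

# The same levels written into a config with fewer significant digits.
requested = [3.643466, 7.59482, 14.356632]

# Exact-match lookup (what label selection does by default) finds nothing.
exact_hits = [len(np.flatnonzero(lev_coord == v)) for v in requested]
print(exact_hits)  # [0, 0, 0]

# A tolerance-based lookup recovers the positions.
close = [int(np.flatnonzero(np.isclose(lev_coord, v, rtol=0, atol=1e-5))[0]) for v in requested]
print(close)  # [0, 1, 2]
```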

Same issues as line 84
n_levels = len(src["levels"])
varnum_diag = len(diag.get("vars_2D", [])) + len(diag.get("vars_3D", [])) * n_levels

rollout_realtime_gen2.py (the remaining examples closely mirror those above, so full explanations are omitted)

conf["data"]["level_ids"] = src.get("levels", list(range(conf["model"]["levels"])))

n_levels = len(src.get("levels", []))

levels = src["levels"]
level_coord = src["level_coord"]
n_levels = len(levels)

if is_3d:
    m = torch.tensor(mean_ds[varname].sel({level_coord: levels}).values, dtype=dtype)
    s = torch.tensor(std_ds[varname].sel({level_coord: levels}).values, dtype=dtype)

(calls build_channel_layout, which hits the channel_layout.py error explained above)
slices, _ = build_channel_layout(conf)

_plot.py

n_levels = len(src.get("levels", []))

levels = src["levels"]
level_coord = src["level_coord"]
n_levels = len(levels)

if is_3d:
    m = mean_ds[varname].sel({level_coord: levels}).values.astype(np.float32)
    s = std_ds[varname].sel({level_coord: levels}).values.astype(np.float32)

Possible Solutions

The issue boils down to how each way of reading levels behaves for each way of specifying it in the config:

| levels in config | get("levels") | get("levels", []) | ["levels"] |
|------------------|---------------|-------------------|------------|
| omitted          | None          | []                | KeyError   |
| null             | None          | None              | None       |
| []               | []            | []                | []         |
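The table can be regenerated mechanically, which is handy as a regression check. MISSING is a hypothetical sentinel used here to represent an omitted key:

```python
MISSING = object()  # sentinel meaning "levels key not present in the config"

def access_patterns(value) -> list:
    src = {} if value is MISSING else {"levels": value}
    results = []
    for access in (lambda s: s.get("levels"),
                   lambda s: s.get("levels", []),
                   lambda s: s["levels"]):
        try:
            results.append(access(src))
        except KeyError:
            results.append("KeyError")
    return results

for label, value in (("omitted", MISSING), ("null", None), ("[]", [])):
    print(label, access_patterns(value))
# omitted [None, [], 'KeyError']
# null [None, None, None]
# [] [[], [], []]
```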

The intuitive option would be to make levels: null the official way of selecting all pressure levels, since it is the only input that behaves consistently across all three access patterns (each returns None). In addition, the model section of the config already includes the number of levels, which can be used instead of len(cfg['levels']). Helper functions like the following may also prove useful:

def resolve_num_levels(source: dict, conf: dict) -> int:
    levels = source.get("levels")
    if levels is not None:
        return len(levels)
    return int(conf.get("model", {}).get("levels", 0))

def resolve_level_ids(source: dict, conf: dict) -> list:
    levels = source.get("levels")
    if levels is not None:
        return levels
    return list(range(conf["model"]["levels"]))  # or dataset.static_metadata["levels"]
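A quick usage check of the proposed helpers (copied verbatim), assuming a hypothetical model section with 32 levels:

```python
def resolve_num_levels(source: dict, conf: dict) -> int:
    levels = source.get("levels")
    if levels is not None:
        return len(levels)
    return int(conf.get("model", {}).get("levels", 0))

def resolve_level_ids(source: dict, conf: dict) -> list:
    levels = source.get("levels")
    if levels is not None:
        return levels
    return list(range(conf["model"]["levels"]))

conf = {"model": {"levels": 32}}  # hypothetical model section
for source in ({"levels": None}, {}, {"levels": [850.0, 500.0]}):
    print(resolve_num_levels(source, conf), resolve_level_ids(source, conf)[:3])
# 32 [0, 1, 2]
# 32 [0, 1, 2]
# 2 [850.0, 500.0]
```

With these helpers, null and omitted behave identically, while an explicit levels: [] still resolves to 0 levels.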

Concluding Remarks

There may be other fixes, or something already built into the codebase to handle this, but I looked around and couldn't find anything. This is a largely blocking issue for anyone who wants to use a LocalDataset without listing explicit pressure levels, which is often the case when the level values are floats rather than ints. For easy debugging/testing, the credit-chem repo (found here) contains the code in which I ran into this issue: run uv run python scripts/train.py -c config/credit-chem-cubed-v0_1.yaml locally, or bash bash/submit_train.sh -c config/credit-chem-cubed-v0_1.yaml to submit a job.
