Skip to content

Commit 76aea84

Browse files
committed
replace dtype with compatible torch_dtype for transformers v4.53.0
1 parent 205bdf4 commit 76aea84

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

scripts/eval_script.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def main():
9494

9595
# Loop: models -> strategies -> subsets
9696
for strategy_name in args.strategies:
97-
if strategy_name == "efficient_kascade" and model.config.dtype != torch.float16:
97+
if strategy_name == "efficient_kascade" and model.config.torch_dtype != torch.float16:
9898
raise ValueError("Efficient Kascade strategy requires model to be in float16 precision. Please go to line 17 in src/model_utils.py and change torch_dtype=torch.float16 when loading the model for running with efficient_kascade.")
9999

100100
set_seed(args.seed) # Ensure reproducibility per run

0 commit comments

Comments
 (0)