feat: add precision option (fp32/fp16/bf16)#44
Conversation
Add a precision config field (fp32/fp16/bf16, default fp32) and map it to cccv's fp16/bf16 arguments. bf16 avoids the fp16 numerical overflow that yields NaN output on some transformer models, at the same VRAM savings as fp16. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
There was a problem hiding this comment.
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
Adds configurable numeric precision to SR configuration and wires it through model initialization to enable fp16/bf16 execution modes.
Changes:
- Introduces
precisiontoSRConfigwith validation and a default of"fp32". - Passes
precisionthrough toAutoModel.from_pretrainedviafp16/bf16flags. - Adds tests for valid/invalid precision values and updates config generation to include precision.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| tests/test_config.py | Adds unit tests to validate precision serialization/defaulting and invalid values. |
| scripts/gen_config.py | Includes precision in generated config output. |
| Final2x_core/config.py | Adds precision field and a validator restricting allowed values. |
| Final2x_core/SRclass.py | Maps precision to model init flags (fp16/bf16). |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| pretrained_model_name: Union[ConfigType, str] | ||
| device: str | ||
| use_tile: Optional[bool] = None | ||
| precision: str = "fp32" # "fp32" | "fp16" | "bf16" |
| @field_validator("precision") | ||
| def precision_match(cls, v: str) -> str: | ||
| precision_list = ["fp32", "fp16", "bf16"] | ||
| if v not in precision_list: | ||
| raise ValueError(f"precision must be one of {precision_list}") | ||
| return v |
| fp16=(self.config.precision == "fp16"), | ||
| bf16=(self.config.precision == "bf16"), |
| def precision_match(cls, v: str) -> str: | ||
| precision_list = ["fp32", "fp16", "bf16"] | ||
| if v not in precision_list: | ||
| raise ValueError(f"precision must be one of {precision_list}") |
| def test_precision_fp16(self) -> None: | ||
| config: SRConfig = SRConfig.from_yaml(CONFIG_PATH) | ||
| config.precision = "fp16" | ||
| config = SRConfig.from_json_str(config.model_dump_json()) | ||
| assert config.precision == "fp16" | ||
|
|
||
| def test_precision_bf16(self) -> None: | ||
| config: SRConfig = SRConfig.from_yaml(CONFIG_PATH) | ||
| config.precision = "bf16" | ||
| config = SRConfig.from_json_str(config.model_dump_json()) | ||
| assert config.precision == "bf16" |
There was a problem hiding this comment.
Code Review
This pull request introduces a new 'precision' configuration option ('fp32', 'fp16', 'bf16') to the super-resolution configuration ('SRConfig') and model initialization, along with corresponding validation and unit tests. Feedback suggests converting the precision input string to lowercase during validation to make it case-insensitive and more robust.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| def precision_match(cls, v: str) -> str: | ||
| precision_list = ["fp32", "fp16", "bf16"] | ||
| if v not in precision_list: | ||
| raise ValueError(f"precision must be one of {precision_list}") | ||
| return v |
There was a problem hiding this comment.
To make the configuration more robust and user-friendly, consider converting the input precision string to lowercase before validation. This prevents validation failures if a user specifies uppercase values like FP16 or BF16 in their configuration file.
| def precision_match(cls, v: str) -> str: | |
| precision_list = ["fp32", "fp16", "bf16"] | |
| if v not in precision_list: | |
| raise ValueError(f"precision must be one of {precision_list}") | |
| return v | |
| def precision_match(cls, v: str) -> str: | |
| v = v.lower() | |
| precision_list = ["fp32", "fp16", "bf16"] | |
| if v not in precision_list: | |
| raise ValueError(f"precision must be one of {precision_list}") | |
| return v |
Add a precision config field (fp32/fp16/bf16, default fp32) and map it to cccv's fp16/bf16 arguments. bf16 avoids the fp16 numerical overflow that yields NaN output on some transformer models, at the same VRAM savings as fp16.