Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Final2x_core/SRclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ def __init__(self, config: SRConfig) -> None:
self._SR_class: SRBaseModel = AutoModel.from_pretrained(
self.config.pretrained_model_name,
device=get_device(self.config.device),
fp16=False,
fp16=(self.config.precision == "fp16"),
bf16=(self.config.precision == "bf16"),
Comment on lines +29 to +30
tile=tile,
gh_proxy=self.config.gh_proxy,
)
Expand Down
8 changes: 8 additions & 0 deletions Final2x_core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class SRConfig(BaseModel):
pretrained_model_name: Union[ConfigType, str]
device: str
use_tile: Optional[bool] = None
precision: str = "fp32" # "fp32" | "fp16" | "bf16"
gh_proxy: Optional[str] = None
target_scale: Optional[Union[int, float]] = None
output_path: DirectoryPath
Expand Down Expand Up @@ -61,3 +62,10 @@ def device_match(cls, v: str) -> str:
return v

raise ValueError(f"device must start with {device_list}")

@field_validator("precision")
def precision_match(cls, v: str) -> str:
precision_list = ["fp32", "fp16", "bf16"]
if v not in precision_list:
raise ValueError(f"precision must be one of {precision_list}")
return v
Comment on lines +66 to +71
Comment on lines +67 to +71

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To make the configuration more robust and user-friendly, consider converting the input precision string to lowercase before validation. This prevents validation failures if a user specifies uppercase values like FP16 or BF16 in their configuration file.

Suggested change
def precision_match(cls, v: str) -> str:
precision_list = ["fp32", "fp16", "bf16"]
if v not in precision_list:
raise ValueError(f"precision must be one of {precision_list}")
return v
def precision_match(cls, v: str) -> str:
v = v.lower()
precision_list = ["fp32", "fp16", "bf16"]
if v not in precision_list:
raise ValueError(f"precision must be one of {precision_list}")
return v

1 change: 1 addition & 0 deletions scripts/gen_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def gen_config() -> None:
"pretrained_model_name": ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x.value,
"device": _DEVICE_,
"use_tile": True,
"precision": "fp32",
"gh_proxy": None,
"target_scale": None,
"output_path": str(projectPATH / "assets"),
Expand Down
25 changes: 25 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,31 @@ def test_from_base64(self) -> None:
config = SRConfig.from_base64(b_str)
print(config)

def test_precision_fp16(self) -> None:
config: SRConfig = SRConfig.from_yaml(CONFIG_PATH)
config.precision = "fp16"
config = SRConfig.from_json_str(config.model_dump_json())
assert config.precision == "fp16"

def test_precision_bf16(self) -> None:
config: SRConfig = SRConfig.from_yaml(CONFIG_PATH)
config.precision = "bf16"
config = SRConfig.from_json_str(config.model_dump_json())
assert config.precision == "bf16"
Comment on lines +29 to +39

def test_precision_default(self) -> None:
config: SRConfig = SRConfig.from_yaml(CONFIG_PATH)
config_dict = config.model_dump()
config_dict.pop("precision", None)
config = SRConfig(**config_dict)
assert config.precision == "fp32"

def test_precision_invalid(self) -> None:
with pytest.raises(ValueError):
config: SRConfig = SRConfig.from_yaml(CONFIG_PATH)
config.precision = "invalid"
SRConfig.from_json_str(config.model_dump_json())

def test_error_device(self) -> None:
config: SRConfig
with pytest.raises(ValueError):
Expand Down
Loading