55 FluxImagePipeline ,
66 FluxRedux ,
77 fetch_model ,
8+ FluxStateDicts
89)
9- from typing import List , Tuple , Optional , Callable
10+ from typing import List , Tuple , Optional , Callable , Dict
1011from PIL import Image
1112import torch
1213
@@ -19,8 +20,19 @@ class FluxReplaceByControlTool:
1920
2021 def __init__ (
2122 self ,
23+ flux_pipe : FluxImagePipeline ,
24+ redux : FluxRedux ,
25+ controlnet : FluxControlNet ,
26+ ):
27+ self .pipe = flux_pipe
28+ self .pipe .load_redux (redux )
29+ self .controlnet = controlnet
30+
31+ @classmethod
32+ def from_pretrained (
33+ cls ,
2234 flux_model_path : str ,
23- load_text_encoder = True ,
35+ load_text_encoder : bool = True ,
2436 device : str = "cuda:0" ,
2537 dtype : torch .dtype = torch .bfloat16 ,
2638 offload_mode : Optional [str ] = None ,
@@ -32,17 +44,42 @@ def __init__(
3244 device = device ,
3345 offload_mode = offload_mode ,
3446 )
35- self . pipe : FluxImagePipeline = FluxImagePipeline .from_pretrained (config )
47+ flux_pipe = FluxImagePipeline .from_pretrained (config )
3648 redux_model_path = fetch_model ("muse/flux1-redux-dev" , path = "flux1-redux-dev.safetensors" , revision = "v1" )
37- flux_redux = FluxRedux .from_pretrained (redux_model_path , device = device )
38- self .pipe .load_redux (flux_redux )
39- self .controlnet = FluxControlNet .from_pretrained (
49+ redux = FluxRedux .from_pretrained (redux_model_path , device = device , dtype = dtype )
50+ controlnet = FluxControlNet .from_pretrained (
4051 fetch_model (
41- "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta" , path = "diffusion_pytorch_model.safetensors"
52+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta" ,
53+ path = "diffusion_pytorch_model.safetensors"
4254 ),
4355 device = device ,
4456 dtype = torch .bfloat16 ,
4557 )
58+ return cls (flux_pipe , redux , controlnet )
59+
60+ @classmethod
61+ def from_state_dict (
62+ cls ,
63+ flux_state_dicts : FluxStateDicts ,
64+ redux_state_dict : Dict [str , torch .Tensor ],
65+ controlnet_state_dict : Dict [str , torch .Tensor ],
66+ load_text_encoder : bool = True ,
67+ device : str = "cuda:0" ,
68+ dtype : torch .dtype = torch .bfloat16 ,
69+ offload_mode : Optional [str ] = None ,
70+ ):
71+ config = FluxPipelineConfig (
72+ model_path = "" ,
73+ model_dtype = dtype ,
74+ load_text_encoder = load_text_encoder ,
75+ device = device ,
76+ offload_mode = offload_mode ,
77+ )
78+ flux_pipe = FluxImagePipeline .from_state_dict (flux_state_dicts , config )
79+ redux = FluxRedux .from_state_dict (redux_state_dict , device = device , dtype = dtype )
80+ controlnet = FluxControlNet .from_state_dict (controlnet_state_dict , device = device , dtype = dtype )
81+ return cls (flux_pipe , redux , controlnet )
82+
4683
4784 def load_loras (self , lora_list : List [Tuple [str , float ]], fused : bool = True , save_original_weight : bool = False ):
4885 self .pipe .load_loras (lora_list , fused , save_original_weight )
0 commit comments