@@ -419,9 +419,10 @@ class ControlType(Enum):
419419 normal = "normal"
420420 bfl_control = "bfl_control"
421421 bfl_fill = "bfl_fill"
422+ bfl_kontext = "bfl_kontext"
422423
423424 def get_in_channel (self ):
424- if self == ControlType .normal :
425+ if self in [ ControlType .normal , ControlType . bfl_kontext ] :
425426 return 64
426427 elif self == ControlType .bfl_control :
427428 return 128
@@ -764,9 +765,15 @@ def predict_noise(
764765 current_step : int ,
765766 total_step : int ,
766767 ):
768+ origin_latents_shape = latents .shape
767769 if self .control_type != ControlType .normal :
768770 controlnet_param = controlnet_params [0 ]
769- latents = torch .cat ((latents , controlnet_param .image * controlnet_param .scale ), dim = 1 )
771+ if self .control_type == ControlType .bfl_kontext :
772+ latents = torch .cat ((latents , controlnet_param .image * controlnet_param .scale ), dim = 2 )
773+ image_ids = image_ids .repeat (1 , 2 , 1 )
774+ image_ids [:, image_ids .shape [1 ] // 2 :, 0 ] += 1
775+ else :
776+ latents = torch .cat ((latents , controlnet_param .image * controlnet_param .scale ), dim = 1 )
770777 latents = latents .to (self .dtype )
771778 controlnet_params = []
772779
@@ -797,6 +804,8 @@ def predict_noise(
797804 controlnet_double_block_output = double_block_output ,
798805 controlnet_single_block_output = single_block_output ,
799806 )
807+ if self .control_type == ControlType .bfl_kontext :
808+ noise_pred = noise_pred [:, :, : origin_latents_shape [2 ], : origin_latents_shape [3 ]]
800809 return noise_pred
801810
802811 def prepare_latents (
0 commit comments