@@ -128,10 +128,19 @@ def __init__(
128128 vae : WanVideoVAE ,
129129 image_encoder : WanImageEncoder ,
130130 batch_cfg : bool = False ,
131+ vae_tiled : bool = True ,
132+ vae_tile_size : Tuple [int , int ] = (34 , 34 ),
133+ vae_tile_stride : Tuple [int , int ] = (18 , 16 ),
131134 device = "cuda" ,
132135 dtype = torch .bfloat16 ,
133136 ):
134- super ().__init__ (device = device , dtype = dtype )
137+ super ().__init__ (
138+ vae_tiled = vae_tiled ,
139+ vae_tile_size = vae_tile_size ,
140+ vae_tile_stride = vae_tile_stride ,
141+ device = device ,
142+ dtype = dtype ,
143+ )
135144 self .noise_scheduler = RecifitedFlowScheduler (shift = 5.0 , sigma_min = 0.001 , sigma_max = 0.999 )
136145 self .sampler = FlowMatchEulerSampler ()
137146 self .tokenizer = tokenizer
@@ -202,22 +211,26 @@ def tensor2video(self, frames):
202211 frames = [Image .fromarray (frame ) for frame in frames ]
203212 return frames
204213
205- def encode_video (self , videos : torch .Tensor , tiled = True , tile_size = ( 34 , 34 ), tile_stride = ( 18 , 16 ) ):
214+ def encode_video (self , videos : torch .Tensor ):
206215 videos = videos .to (dtype = self .config .vae_dtype , device = self .device )
207- latents = self .vae .encode (videos , device = self .device , tiled = tiled , tile_size = tile_size , tile_stride = tile_stride )
216+ latents = self .vae .encode (
217+ videos ,
218+ device = self .device ,
219+ tiled = self .vae_tiled ,
220+ tile_size = self .vae_tile_size ,
221+ tile_stride = self .vae_tile_stride ,
222+ )
208223 latents = latents .to (dtype = self .config .dit_dtype , device = self .device )
209224 return latents
210225
211- def decode_video (
212- self , latents , tiled = True , tile_size = (34 , 34 ), tile_stride = (18 , 16 ), progress_callback = None
213- ) -> List [torch .Tensor ]:
226+ def decode_video (self , latents , progress_callback = None ) -> List [torch .Tensor ]:
214227 latents = latents .to (dtype = self .config .vae_dtype , device = self .device )
215228 videos = self .vae .decode (
216229 latents ,
217230 device = self .device ,
218- tiled = tiled ,
219- tile_size = tile_size ,
220- tile_stride = tile_stride ,
231+ tiled = self . vae_tiled ,
232+ tile_size = self . vae_tile_size ,
233+ tile_stride = self . vae_tile_stride ,
221234 progress_callback = progress_callback ,
222235 )
223236 videos = [video .to (dtype = self .config .dit_dtype , device = self .device ) for video in videos ]
@@ -297,9 +310,6 @@ def prepare_latents(
297310 input_video ,
298311 denoising_strength ,
299312 num_inference_steps ,
300- tiled = True ,
301- tile_size = (34 , 34 ),
302- tile_stride = (18 , 16 ),
303313 ):
304314 if input_video is not None :
305315 total_steps = num_inference_steps
@@ -311,9 +321,7 @@ def prepare_latents(
311321 noise = latents
312322 input_video = self .preprocess_images (input_video )
313323 input_video = torch .stack (input_video , dim = 2 )
314- latents = self .encode_video (input_video , tiled = tiled , tile_size = tile_size , tile_stride = tile_stride ).to (
315- dtype = latents .dtype , device = latents .device
316- )
324+ latents = self .encode_video (input_video ).to (dtype = latents .dtype , device = latents .device )
317325 init_latents = latents .clone ()
318326 latents = self .sampler .add_noise (latents , noise , sigma_start )
319327 else :
@@ -336,9 +344,6 @@ def __call__(
336344 num_frames = 81 ,
337345 cfg_scale = 5.0 ,
338346 num_inference_steps = 50 ,
339- tiled = True ,
340- tile_size = (34 , 34 ),
341- tile_stride = (18 , 16 ),
342347 progress_callback : Optional [Callable ] = None , # def progress_callback(current, total, status)
343348 ):
344349 assert height % 16 == 0 and width % 16 == 0 , "height and width must be divisible by 16"
@@ -353,9 +358,6 @@ def __call__(
353358 input_video ,
354359 denoising_strength ,
355360 num_inference_steps ,
356- tiled = tiled ,
357- tile_size = tile_size ,
358- tile_stride = tile_stride ,
359361 )
360362 self .sampler .initialize (init_latents = init_latents , timesteps = timesteps , sigmas = sigmas )
361363 # Encode prompts
@@ -392,9 +394,7 @@ def __call__(
392394
393395 # Decode
394396 self .load_models_to_device (["vae" ])
395- frames = self .decode_video (
396- latents , tiled = tiled , tile_size = tile_size , tile_stride = tile_stride , progress_callback = progress_callback
397- )
397+ frames = self .decode_video (latents , progress_callback = progress_callback )
398398 frames = self .tensor2video (frames [0 ])
399399 return frames
400400
0 commit comments