# infer_multishot.py — 92 lines (77 loc), 3.93 KB
import torch
from diffsynth.utils.data import save_video
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
import pandas as pd
import ast
import json
import os
import argparse
from util import rgb_to_latent_shot_groups_list, pad_shot_groups_to_4n_plus_1, get_user_wanted_frames, save_video_with_caption
import random, sys
if __name__ == "__main__":
    # Multi-shot text-to-video inference driver.
    #
    # Reads a CSV of test cases (one row per multi-shot video), builds a
    # WanVideoPipeline from a JSON of model paths, and for each row generates
    # a multi-shot video from per-shot Gemini captions, saving both a plain
    # MP4 and a caption-overlaid MP4 into output/<output_name>/.

    def str2bool(value):
        # argparse's `type=bool` treats ANY non-empty string — including
        # "False" — as True. Parse common truthy spellings explicitly so
        # `--use_usp False` actually disables USP.
        return str(value).lower() in ("1", "true", "yes", "y", "t")

    parser = argparse.ArgumentParser()
    parser.add_argument("--test_csv_path", type=str, default=None)   # CSV with columns: shot_groups, gemini_caption
    parser.add_argument("--model_path_json", type=str, default=None) # JSON mapping: dit / t5 / vae / tokenizer -> path
    parser.add_argument("--output_name", type=str, default="1.3b")   # subdirectory name under output/
    parser.add_argument("--target_width", type=int, default=832)
    parser.add_argument("--target_height", type=int, default=480)
    # BUG FIX: was `type=bool`, which made "--use_usp False" parse as True.
    parser.add_argument("--use_usp", type=str2bool, default=False)
    args = parser.parse_args()

    # Standard Wan negative prompt (kept verbatim; the model was tuned with it).
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"

    # init: load model checkpoint paths and build the pipeline on GPU.
    load_path_json = args.model_path_json
    with open(load_path_json, encoding="utf-8") as user_file:
        model_paths = json.load(user_file)
    pipe = WanVideoPipeline.from_pretrained(
        torch_dtype=torch.bfloat16,
        device="cuda",
        use_usp=args.use_usp,  # if use usp (unified sequence parallel) for inference
        model_configs=[
            ModelConfig(path=model_paths["dit"]),
            ModelConfig(path=model_paths["t5"]),
            ModelConfig(path=model_paths["vae"]),
        ],
        tokenizer_config=ModelConfig(path=model_paths["tokenizer"]),
    )

    # read test cases
    df = pd.read_csv(args.test_csv_path)
    output_dir = f"output/{args.output_name}"
    print(f"output_dir: {output_dir}")
    os.makedirs(output_dir, exist_ok=True)

    for index, row in df.iterrows():
        # do not use random seed when use_usp=True (all ranks must agree on the seed)
        seed = 42  # random.randint(0, sys.maxsize)

        # read shot groups and transform to latent shot groups; padding makes
        # each group 4n+1 frames long as required by the VAE's temporal stride.
        shot_groups = ast.literal_eval(row["shot_groups"])
        padded_shot_groups, save_shot_num_list = pad_shot_groups_to_4n_plus_1(shot_groups)
        latent_shot_groups = rgb_to_latent_shot_groups_list(padded_shot_groups)

        # read multi-shot captions: one global story caption plus one per shot.
        with open(row["gemini_caption"], encoding="utf-8") as user_file:
            caption_dict = json.load(user_file)
        num_shots = len(latent_shot_groups)
        global_caption = f"Story: {caption_dict['global_caption']} "
        # each shot's prompt = shared story context + that shot's caption
        now_multishot_video_caption_list = []
        for count in range(num_shots):
            now_shot_caption = global_caption + f"Now: {caption_dict[f'shot{count}']}"
            now_multishot_video_caption_list.append(now_shot_caption)
        multishot_negative_prompt = [negative_prompt] * num_shots

        # Text-to-video
        video = pipe(
            width=args.target_width,
            height=args.target_height,
            prompt=now_multishot_video_caption_list,
            negative_prompt=multishot_negative_prompt,
            seed=seed, tiled=True,
            num_frames=latent_shot_groups[-1][-1],  # presumably the total frame count — TODO confirm against pipeline API
            shot_groups=shot_groups,
            latent_shot_groups=latent_shot_groups,
        )

        # drop the padding frames, keeping only the frames the user asked for
        user_wanted_frames = get_user_wanted_frames(video, padded_shot_groups, save_shot_num_list)
        # save video without caption
        save_video(user_wanted_frames, f"{output_dir}/{index}.mp4", fps=15, quality=5)
        # save video with caption
        save_video_with_caption(num_shots, shot_groups, now_multishot_video_caption_list, user_wanted_frames, f"{output_dir}/{index}_with_caption.mp4", args.target_width)
        print("Enjoy the story")
        # free cached GPU memory between test cases to avoid OOM accumulation
        torch.cuda.empty_cache()