Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions modules/dataLoader/ErnieBaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from modules.util import factory
from modules.util.config.TrainConfig import TrainConfig
from modules.util.enum.ModelType import ModelType
from modules.util.thread_safety import apply_thread_safe_forward
from modules.util.TrainProgress import TrainProgress

from mgds.pipelineModules.DecodeTokens import DecodeTokens
Expand All @@ -32,8 +31,6 @@ def _preparation_modules(self, config: TrainConfig, model: ErnieModel):
image_sample = SampleVAEDistribution(in_name='latent_image_distribution', out_name='latent_image', mode='mean')
downscale_mask = ScaleImage(in_name='mask', out_name='latent_mask', factor=0.125)
tokenize_prompt = Tokenize(in_name='prompt', tokens_out_name='tokens', mask_out_name='tokens_mask', tokenizer=model.tokenizer, max_token_length=PROMPT_MAX_LENGTH)
if config.dataloader_threads > 1:
apply_thread_safe_forward(model.text_encoder) # workaround for transformers#42673, unclear if Mistral is affected
encode_prompt = EncodeMistralText(tokens_name='tokens', tokens_attention_mask_in_name='tokens_mask', hidden_state_out_name='text_encoder_hidden_state', tokens_attention_mask_out_name='tokens_mask',
text_encoder=model.text_encoder, hidden_state_output_index=HIDDEN_STATES_LAYER, autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype())

Expand Down
5 changes: 0 additions & 5 deletions modules/dataLoader/Flux2BaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from modules.util import factory
from modules.util.config.TrainConfig import TrainConfig
from modules.util.enum.ModelType import ModelType
from modules.util.thread_safety import apply_thread_safe_forward
from modules.util.TrainProgress import TrainProgress

from mgds.pipelineModules.DecodeTokens import DecodeTokens
Expand Down Expand Up @@ -43,8 +42,6 @@ def _preparation_modules(self, config: TrainConfig, model: Flux2Model):
tokenize_prompt = Tokenize(in_name='prompt', tokens_out_name='tokens', mask_out_name='tokens_mask', tokenizer=model.tokenizer, max_token_length=config.text_encoder_sequence_length,
apply_chat_template = lambda caption: mistral_format_input([caption], MISTRAL_SYSTEM_MESSAGE), apply_chat_template_kwargs = {'add_generation_prompt': False},
)
if config.dataloader_threads > 1:
apply_thread_safe_forward(model.text_encoder) # workaround for transformers#42673
encode_prompt = EncodeMistralText(tokens_name='tokens', tokens_attention_mask_in_name='tokens_mask', hidden_state_out_name='text_encoder_hidden_state', tokens_attention_mask_out_name='tokens_mask',
text_encoder=model.text_encoder, autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype(),
hidden_state_output_index=MISTRAL_HIDDEN_STATES_LAYERS,
Expand All @@ -53,8 +50,6 @@ def _preparation_modules(self, config: TrainConfig, model: Flux2Model):
tokenize_prompt = Tokenize(in_name='prompt', tokens_out_name='tokens', mask_out_name='tokens_mask', tokenizer=model.tokenizer, max_token_length=config.text_encoder_sequence_length,
apply_chat_template = lambda caption: qwen3_format_input(caption), apply_chat_template_kwargs = {'add_generation_prompt': True, 'enable_thinking': False}
)
if config.dataloader_threads > 1:
apply_thread_safe_forward(model.text_encoder) # workaround for transformers#42673
encode_prompt = EncodeQwenText(tokens_name='tokens', tokens_attention_mask_in_name='tokens_mask', hidden_state_out_name='text_encoder_hidden_state', tokens_attention_mask_out_name='tokens_mask',
text_encoder=model.text_encoder, hidden_state_output_index=QWEN3_HIDDEN_STATES_LAYERS, autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype())

Expand Down
3 changes: 0 additions & 3 deletions modules/dataLoader/ZImageBaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from modules.util import factory
from modules.util.config.TrainConfig import TrainConfig
from modules.util.enum.ModelType import ModelType
from modules.util.thread_safety import apply_thread_safe_forward
from modules.util.TrainProgress import TrainProgress

from mgds.pipelineModules.DecodeTokens import DecodeTokens
Expand Down Expand Up @@ -38,8 +37,6 @@ def _preparation_modules(self, config: TrainConfig, model: ZImageModel):
tokenize_prompt = Tokenize(in_name='prompt', tokens_out_name='tokens', mask_out_name='tokens_mask', tokenizer=model.tokenizer, max_token_length=PROMPT_MAX_LENGTH,
apply_chat_template = lambda caption: format_input(caption), apply_chat_template_kwargs = {'add_generation_prompt': True, 'enable_thinking': True}
)
if config.dataloader_threads > 1:
apply_thread_safe_forward(model.text_encoder) # workaround for transformers#42673
encode_prompt = EncodeQwenText(tokens_name='tokens', tokens_attention_mask_in_name='tokens_mask', hidden_state_out_name='text_encoder_hidden_state', tokens_attention_mask_out_name='tokens_mask',
text_encoder=model.text_encoder, hidden_state_output_index=-2, autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype())
prune_masked_tokens = PruneMaskedTokens(tokens_name='tokens', tokens_mask_name='tokens_mask', hidden_state_name='text_encoder_hidden_state')
Expand Down
6 changes: 4 additions & 2 deletions modules/model/ChromaModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(
class ChromaModel(BaseModel):
# base model data
tokenizer: T5Tokenizer | None
orig_tokenizer: T5Tokenizer | None
noise_scheduler: FlowMatchEulerDiscreteScheduler | None
text_encoder: T5EncoderModel | None
vae: AutoencoderKL | None
Expand Down Expand Up @@ -72,6 +73,7 @@ def __init__(
)

self.tokenizer = None
self.orig_tokenizer = None
self.noise_scheduler = None
self.text_encoder = None
self.vae = None
Expand Down Expand Up @@ -141,13 +143,13 @@ def eval(self):
self.text_encoder.eval()
self.transformer.eval()

def create_pipeline(self) -> DiffusionPipeline:
def create_pipeline(self, use_original_tokenizers: bool = False) -> DiffusionPipeline:
return ChromaPipeline(
transformer=self.transformer,
scheduler=self.noise_scheduler,
vae=self.vae,
text_encoder=self.text_encoder,
tokenizer=self.tokenizer,
tokenizer=self.orig_tokenizer if use_original_tokenizers else self.tokenizer,
)

def add_text_encoder_embeddings_to_prompt(self, prompt: str) -> str:
Expand Down
10 changes: 7 additions & 3 deletions modules/model/FluxModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ def __init__(
class FluxModel(BaseModel):
# base model data
tokenizer_1: CLIPTokenizer | None
orig_tokenizer_1: CLIPTokenizer | None
tokenizer_2: T5Tokenizer | None
orig_tokenizer_2: T5Tokenizer | None
noise_scheduler: FlowMatchEulerDiscreteScheduler | None
text_encoder_1: CLIPTextModel | None
text_encoder_2: T5EncoderModel | None
Expand Down Expand Up @@ -86,7 +88,9 @@ def __init__(
)

self.tokenizer_1 = None
self.orig_tokenizer_1 = None
self.tokenizer_2 = None
self.orig_tokenizer_2 = None
self.noise_scheduler = None
self.text_encoder_1 = None
self.text_encoder_2 = None
Expand Down Expand Up @@ -177,15 +181,15 @@ def eval(self):
self.text_encoder_2.eval()
self.transformer.eval()

def create_pipeline(self) -> DiffusionPipeline:
def create_pipeline(self, use_original_tokenizers: bool = False) -> DiffusionPipeline:
return FluxPipeline(
transformer=self.transformer,
scheduler=self.noise_scheduler,
vae=self.vae,
text_encoder=self.text_encoder_1,
tokenizer=self.tokenizer_1,
tokenizer=self.orig_tokenizer_1 if use_original_tokenizers else self.tokenizer_1,
text_encoder_2=self.text_encoder_2,
tokenizer_2=self.tokenizer_2,
tokenizer_2=self.orig_tokenizer_2 if use_original_tokenizers else self.tokenizer_2,
)

def add_text_encoder_1_embeddings_to_prompt(self, prompt: str) -> str:
Expand Down
10 changes: 5 additions & 5 deletions modules/model/HiDreamModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,19 +267,19 @@ def eval(self):
self.text_encoder_4.eval()
self.transformer.eval()

def create_pipeline(self, use_original_modules: bool) -> DiffusionPipeline:
def create_pipeline(self, use_original_tokenizers: bool = False) -> DiffusionPipeline:
return HiDreamImagePipeline(
transformer=self.transformer,
scheduler=self.noise_scheduler,
vae=self.vae,
text_encoder=self.text_encoder_1,
tokenizer=self.orig_tokenizer_1 if use_original_modules else self.tokenizer_1,
tokenizer=self.orig_tokenizer_1 if use_original_tokenizers else self.tokenizer_1,
text_encoder_2=self.text_encoder_2,
tokenizer_2=self.orig_tokenizer_2 if use_original_modules else self.tokenizer_2,
tokenizer_2=self.orig_tokenizer_2 if use_original_tokenizers else self.tokenizer_2,
text_encoder_3=self.text_encoder_3,
tokenizer_3=self.orig_tokenizer_3 if use_original_modules else self.tokenizer_3,
tokenizer_3=self.orig_tokenizer_3 if use_original_tokenizers else self.tokenizer_3,
text_encoder_4=self.text_encoder_4,
tokenizer_4=self.orig_tokenizer_4 if use_original_modules else self.tokenizer_4,
tokenizer_4=self.orig_tokenizer_4 if use_original_tokenizers else self.tokenizer_4,
)

def add_text_encoder_1_embeddings_to_prompt(self, prompt: str) -> str:
Expand Down
6 changes: 3 additions & 3 deletions modules/model/HunyuanVideoModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,15 +196,15 @@ def eval(self):
self.text_encoder_2.eval()
self.transformer.eval()

def create_pipeline(self, use_original_modules: bool) -> DiffusionPipeline:
def create_pipeline(self, use_original_tokenizers: bool = False) -> DiffusionPipeline:
return HunyuanVideoPipeline(
transformer=self.transformer,
scheduler=self.noise_scheduler,
vae=self.vae,
text_encoder=self.text_encoder_1,
tokenizer=self.orig_tokenizer_1 if use_original_modules else self.tokenizer_1,
tokenizer=self.orig_tokenizer_1 if use_original_tokenizers else self.tokenizer_1,
text_encoder_2=self.text_encoder_2,
tokenizer_2=self.orig_tokenizer_2 if use_original_modules else self.tokenizer_2,
tokenizer_2=self.orig_tokenizer_2 if use_original_tokenizers else self.tokenizer_2,
)

def add_text_encoder_1_embeddings_to_prompt(self, prompt: str) -> str:
Expand Down
9 changes: 6 additions & 3 deletions modules/model/PixArtAlphaModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(
class PixArtAlphaModel(BaseModel):
# base model data
tokenizer: T5Tokenizer | None
orig_tokenizer: T5Tokenizer | None
noise_scheduler: DDIMScheduler | None
text_encoder: T5EncoderModel | None
vae: AutoencoderKL | None
Expand Down Expand Up @@ -74,6 +75,7 @@ def __init__(
)

self.tokenizer = None
self.orig_tokenizer = None
self.noise_scheduler = None
self.text_encoder = None
self.vae = None
Expand Down Expand Up @@ -141,19 +143,20 @@ def eval(self):
self.text_encoder.eval()
self.transformer.eval()

def create_pipeline(self) -> DiffusionPipeline:
def create_pipeline(self, use_original_tokenizers: bool = False) -> DiffusionPipeline:
tokenizer = self.orig_tokenizer if use_original_tokenizers else self.tokenizer
match self.model_type:
case ModelType.PIXART_ALPHA:
return PixArtAlphaPipeline(
tokenizer=self.tokenizer,
tokenizer=tokenizer,
text_encoder=self.text_encoder,
vae=self.vae,
transformer=self.transformer,
scheduler=self.noise_scheduler,
)
case ModelType.PIXART_SIGMA:
return PixArtSigmaPipeline(
tokenizer=self.tokenizer,
tokenizer=tokenizer,
text_encoder=self.text_encoder,
vae=self.vae,
transformer=self.transformer,
Expand Down
6 changes: 4 additions & 2 deletions modules/model/SanaModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(
class SanaModel(BaseModel):
# base model data
tokenizer: GemmaTokenizer | None
orig_tokenizer: GemmaTokenizer | None
noise_scheduler: DDIMScheduler | None
text_encoder: Gemma2Model | None
vae: AutoencoderDC | None
Expand Down Expand Up @@ -74,6 +75,7 @@ def __init__(
)

self.tokenizer = None
self.orig_tokenizer = None
self.noise_scheduler = None
self.text_encoder = None
self.vae = None
Expand Down Expand Up @@ -143,9 +145,9 @@ def eval(self):
self.text_encoder.eval()
self.transformer.eval()

def create_pipeline(self) -> DiffusionPipeline:
def create_pipeline(self, use_original_tokenizers: bool = False) -> DiffusionPipeline:
return SanaPipeline(
tokenizer=self.tokenizer,
tokenizer=self.orig_tokenizer if use_original_tokenizers else self.tokenizer,
text_encoder=self.text_encoder,
vae=self.vae,
transformer=self.transformer,
Expand Down
14 changes: 10 additions & 4 deletions modules/model/StableDiffusion3Model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,11 @@ def __init__(
class StableDiffusion3Model(BaseModel):
# base model data
tokenizer_1: CLIPTokenizer | None
orig_tokenizer_1: CLIPTokenizer | None
tokenizer_2: CLIPTokenizer | None
orig_tokenizer_2: CLIPTokenizer | None
tokenizer_3: T5Tokenizer | None
orig_tokenizer_3: T5Tokenizer | None
noise_scheduler: FlowMatchEulerDiscreteScheduler | None
text_encoder_1: CLIPTextModelWithProjection | None
text_encoder_2: CLIPTextModelWithProjection | None
Expand Down Expand Up @@ -98,8 +101,11 @@ def __init__(
)

self.tokenizer_1 = None
self.orig_tokenizer_1 = None
self.tokenizer_2 = None
self.orig_tokenizer_2 = None
self.tokenizer_3 = None
self.orig_tokenizer_3 = None
self.noise_scheduler = None
self.text_encoder_1 = None
self.text_encoder_2 = None
Expand Down Expand Up @@ -208,17 +214,17 @@ def eval(self):
self.text_encoder_3.eval()
self.transformer.eval()

def create_pipeline(self) -> DiffusionPipeline:
def create_pipeline(self, use_original_tokenizers: bool = False) -> DiffusionPipeline:
return StableDiffusion3Pipeline(
transformer=self.transformer,
scheduler=self.noise_scheduler,
vae=self.vae,
text_encoder=self.text_encoder_1,
tokenizer=self.tokenizer_1,
tokenizer=self.orig_tokenizer_1 if use_original_tokenizers else self.tokenizer_1,
text_encoder_2=self.text_encoder_2,
tokenizer_2=self.tokenizer_2,
tokenizer_2=self.orig_tokenizer_2 if use_original_tokenizers else self.tokenizer_2,
text_encoder_3=self.text_encoder_3,
tokenizer_3=self.tokenizer_3,
tokenizer_3=self.orig_tokenizer_3 if use_original_tokenizers else self.tokenizer_3,
)

def add_text_encoder_1_embeddings_to_prompt(self, prompt: str) -> str:
Expand Down
11 changes: 7 additions & 4 deletions modules/model/StableDiffusionModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(
class StableDiffusionModel(BaseModel):
# base model data
tokenizer: CLIPTokenizer | None
orig_tokenizer: CLIPTokenizer | None
noise_scheduler: DDIMScheduler | None
text_encoder: CLIPTextModel | None
vae: AutoencoderKL | None
Expand Down Expand Up @@ -72,6 +73,7 @@ def __init__(
)

self.tokenizer = None
self.orig_tokenizer = None
self.noise_scheduler = None
self.text_encoder = None
self.vae = None
Expand Down Expand Up @@ -136,12 +138,13 @@ def eval(self):
self.text_encoder.eval()
self.unet.eval()

def create_pipeline(self) -> DiffusionPipeline:
def create_pipeline(self, use_original_tokenizers: bool = False) -> DiffusionPipeline:
tokenizer = self.orig_tokenizer if use_original_tokenizers else self.tokenizer
if self.model_type.has_depth_input():
return StableDiffusionDepth2ImgPipeline(
vae=self.vae,
text_encoder=self.text_encoder,
tokenizer=self.tokenizer,
tokenizer=tokenizer,
unet=self.unet,
scheduler=self.noise_scheduler,
depth_estimator=self.depth_estimator,
Expand All @@ -151,7 +154,7 @@ def create_pipeline(self) -> DiffusionPipeline:
return StableDiffusionInpaintPipeline(
vae=self.vae,
text_encoder=self.text_encoder,
tokenizer=self.tokenizer,
tokenizer=tokenizer,
unet=self.unet,
scheduler=self.noise_scheduler,
safety_checker=None,
Expand All @@ -162,7 +165,7 @@ def create_pipeline(self) -> DiffusionPipeline:
return StableDiffusionPipeline(
vae=self.vae,
text_encoder=self.text_encoder,
tokenizer=self.tokenizer,
tokenizer=tokenizer,
unet=self.unet,
scheduler=self.noise_scheduler,
safety_checker=None,
Expand Down
10 changes: 7 additions & 3 deletions modules/model/StableDiffusionXLModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ def __init__(
class StableDiffusionXLModel(BaseModel):
# base model data
tokenizer_1: CLIPTokenizer | None
orig_tokenizer_1: CLIPTokenizer | None
tokenizer_2: CLIPTokenizer | None
orig_tokenizer_2: CLIPTokenizer | None
noise_scheduler: DDIMScheduler | None
text_encoder_1: CLIPTextModel | None
text_encoder_2: CLIPTextModelWithProjection | None
Expand Down Expand Up @@ -81,7 +83,9 @@ def __init__(
)

self.tokenizer_1 = None
self.orig_tokenizer_1 = None
self.tokenizer_2 = None
self.orig_tokenizer_2 = None
self.noise_scheduler = None
self.text_encoder_1 = None
self.text_encoder_2 = None
Expand Down Expand Up @@ -166,13 +170,13 @@ def eval(self):
self.text_encoder_2.eval()
self.unet.eval()

def create_pipeline(self) -> DiffusionPipeline:
def create_pipeline(self, use_original_tokenizers: bool = False) -> DiffusionPipeline:
return StableDiffusionXLPipeline(
vae=self.vae,
text_encoder=self.text_encoder_1,
text_encoder_2=self.text_encoder_2,
tokenizer=self.tokenizer_1,
tokenizer_2=self.tokenizer_2,
tokenizer=self.orig_tokenizer_1 if use_original_tokenizers else self.tokenizer_1,
tokenizer_2=self.orig_tokenizer_2 if use_original_tokenizers else self.tokenizer_2,
unet=self.unet,
scheduler=self.noise_scheduler,
)
Expand Down
Loading