Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion modules/dataLoader/mixin/DataLoaderMgdsMixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _create_mgds(
settings,
definition,
batch_size=config.batch_size, #local batch size
state=PipelineState(config.dataloader_threads),
state=PipelineState(config.caching_threads),
initial_epoch=train_progress.epoch,
initial_epoch_sample=train_progress.epoch_sample,
)
Expand Down
8 changes: 5 additions & 3 deletions modules/trainer/GenericTrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from modules.util.enum.ModelFormat import ModelFormat
from modules.util.enum.TimeUnit import TimeUnit
from modules.util.enum.TrainingMethod import TrainingMethod
from modules.util.PrefetchIterator import PrefetchIterator
from modules.util.profiling_util import TorchMemoryRecorder, TorchProfiler
from modules.util.time_util import get_string_timestamp
from modules.util.torch_util import torch_gc
Expand Down Expand Up @@ -679,11 +680,12 @@ def train(self):

current_epoch_length = self.data_loader.get_data_set().approximate_length()

batches = self.data_loader.get_data_loader()
if self.config.prefetch_next_batch:
batches = PrefetchIterator(batches)
if multi.is_master():
batches = step_tqdm = tqdm(self.data_loader.get_data_loader(), desc="step", total=current_epoch_length,
batches = step_tqdm = tqdm(batches, desc="step", total=current_epoch_length,
initial=train_progress.epoch_step)
else:
batches = self.data_loader.get_data_loader()
for batch in batches:
multi.sync_commands(self.commands)
if self.commands.get_stop_command():
Expand Down
50 changes: 28 additions & 22 deletions modules/ui/TrainUI.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,44 +303,40 @@ def create_general_tab(self, master):
components.time_entry(frame, 8, 3, self.ui_state, "validate_after", "validate_after_unit")

# device
components.label(frame, 10, 0, "Dataloader Threads",
tooltip="Number of threads used for the data loader. Increase if your GPU has room during caching, decrease if it's going out of memory during caching.")
components.entry(frame, 10, 1, self.ui_state, "dataloader_threads", required=True)

components.label(frame, 11, 0, "Train Device",
components.label(frame, 9, 0, "Train Device",
tooltip="The device used for training. Can be \"cuda\", \"cuda:0\", \"cuda:1\" etc. Default:\"cuda\". Must be \"cuda\" for multi-GPU training.")
components.entry(frame, 11, 1, self.ui_state, "train_device", required=True)
components.entry(frame, 9, 1, self.ui_state, "train_device", required=True)

components.label(frame, 12, 0, "Multi-GPU",
components.label(frame, 10, 0, "Multi-GPU",
tooltip="Enable multi-GPU training")
components.switch(frame, 12, 1, self.ui_state, "multi_gpu")
components.label(frame, 12, 2, "Device Indexes",
components.switch(frame, 10, 1, self.ui_state, "multi_gpu")
components.label(frame, 10, 2, "Device Indexes",
tooltip="Multi-GPU: A comma-separated list of device indexes. If empty, all your GPUs are used. With a list such as \"0,1,3,4\" you can omit a GPU, for example an on-board graphics GPU.")
components.entry(frame, 12, 3, self.ui_state, "device_indexes")
components.entry(frame, 10, 3, self.ui_state, "device_indexes")

components.label(frame, 13, 0, "Gradient Reduce Precision",
components.label(frame, 11, 0, "Gradient Reduce Precision",
tooltip="WEIGHT_DTYPE: Reduce gradients between GPUs in your weight data type; can be imprecise, but more efficient than float32\n"
"WEIGHT_DTYPE_STOCHASTIC: Sum up the gradients in your weight data type, but average them in float32 and stochastically round if your weight data type is bfloat16\n"
"FLOAT_32: Reduce gradients in float32\n"
"FLOAT_32_STOCHASTIC: Reduce gradients in float32; use stochastic rounding to bfloat16 if your weight data type is bfloat16",
wide_tooltip=True)
components.options(frame, 13, 1, [str(x) for x in list(GradientReducePrecision)], self.ui_state,
components.options(frame, 11, 1, [str(x) for x in list(GradientReducePrecision)], self.ui_state,
"gradient_reduce_precision")

components.label(frame, 13, 2, "Fused Gradient Reduce",
components.label(frame, 11, 2, "Fused Gradient Reduce",
tooltip="Multi-GPU: Gradient synchronisation during the backward pass. Can be more efficient, especially with Async Gradient Reduce")
components.switch(frame, 13, 3, self.ui_state, "fused_gradient_reduce")
components.switch(frame, 11, 3, self.ui_state, "fused_gradient_reduce")

components.label(frame, 14, 0, "Async Gradient Reduce",
components.label(frame, 12, 0, "Async Gradient Reduce",
tooltip="Multi-GPU: Asynchroniously start the gradient reduce operations during the backward pass. Can be more efficient, but requires some VRAM.")
components.switch(frame, 14, 1, self.ui_state, "async_gradient_reduce")
components.label(frame, 14, 2, "Buffer size (MB)",
components.switch(frame, 12, 1, self.ui_state, "async_gradient_reduce")
components.label(frame, 12, 2, "Buffer size (MB)",
tooltip="Multi-GPU: Maximum VRAM for \"Async Gradient Reduce\", in megabytes. A multiple of this value can be needed if combined with \"Fused Back Pass\" and/or \"Layer offload fraction\"")
components.entry(frame, 14, 3, self.ui_state, "async_gradient_reduce_buffer")
components.entry(frame, 12, 3, self.ui_state, "async_gradient_reduce_buffer")

components.label(frame, 15, 0, "Temp Device",
components.label(frame, 13, 0, "Temp Device",
tooltip="The device used to temporarily offload models while they are not used. Default:\"cpu\"")
components.entry(frame, 15, 1, self.ui_state, "temp_device")
components.entry(frame, 13, 1, self.ui_state, "temp_device")

frame.pack(fill="both", expand=1)
return frame
Expand All @@ -366,10 +362,20 @@ def create_data_tab(self, master):
tooltip="Caching of intermediate training data that can be re-used between epochs")
components.switch(frame, 1, 1, self.ui_state, "latent_caching")

# caching threads
components.label(frame, 2, 0, "Caching Threads",
tooltip="Number of threads used while building the latent and text caches. Increase if your GPU has room during caching, decrease if it's going out of memory during caching. Only affects performance while the cache is being built.")
components.entry(frame, 2, 1, self.ui_state, "caching_threads", width=100, sticky="nw", required=True)

# prefetch next batch
components.label(frame, 3, 0, "Prefetch Next Batch",
tooltip="Load the next batch on a background thread, overlapping disk reads with the current training step. Most beneficial when caching is enabled, since the prefetch thread then only does disk reads. With caching disabled, the text encoder / VAE forward passes run concurrently with training, increasing peak VRAM.")
components.switch(frame, 3, 1, self.ui_state, "prefetch_next_batch")

# clear cache before training
components.label(frame, 2, 0, "Clear cache before training",
components.label(frame, 4, 0, "Clear cache before training",
tooltip="Clears the cache directory before starting to train. Only disable this if you want to continue using the same cached data. Disabling this can lead to errors, if other settings are changed during a restart")
components.switch(frame, 2, 1, self.ui_state, "clear_cache_before_training")
components.switch(frame, 4, 1, self.ui_state, "clear_cache_before_training")

frame.pack(fill="both", expand=1)
return frame
Expand Down
70 changes: 70 additions & 0 deletions modules/util/PrefetchIterator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import queue
import threading
from collections.abc import Iterable, Iterator
from contextlib import nullcontext, suppress

import torch


class PrefetchIterator:
"""Iterable wrapper that prefetches items ahead on a single background thread.

Wrapping an iterable in PrefetchIterator lets the producer-side work
(e.g. disk reads, decoding, encoding) overlap with whatever the consumer
is doing between iterations.

The producer runs on a dedicated CUDA stream so tensor uploads to the GPU
don't have to wait for in-flight training work on the default stream.
"""

def __init__(self, iterable: Iterable, queue_size: int = 1, stop_poll_interval: float = 0.1):
self._iterable = iterable
self._queue_size = queue_size
# How often the producer checks the stop signal while blocked on put.
self._stop_poll_interval = stop_poll_interval

def __iter__(self) -> Iterator:
q: queue.Queue = queue.Queue(maxsize=self._queue_size)
stop_event = threading.Event()

stream_ctx = torch.cuda.stream(torch.cuda.Stream()) if torch.cuda.is_available() else nullcontext()

def put_or_stop(value) -> bool:
# Block on put, but periodically wake to check the stop signal so
# we can exit if the consumer has gone away.
while not stop_event.is_set():
with suppress(queue.Full):
q.put(value, timeout=self._stop_poll_interval)
return True
return False

def producer():
with stream_ctx:
try:
for item in self._iterable:
if not put_or_stop(item):
return
except BaseException as e:
put_or_stop(e)
return
put_or_stop(StopIteration())

t = threading.Thread(target=producer, daemon=True)
t.start()

try:
while True:
item = q.get()
if isinstance(item, StopIteration):
return
if isinstance(item, BaseException):
raise item
yield item
finally:
# Signal the producer to stop and drain anything pending so it
# can wake from a blocked put and observe the stop signal.
stop_event.set()
with suppress(queue.Empty):
while True:
q.get_nowait()
t.join()
17 changes: 14 additions & 3 deletions modules/util/config/TrainConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,8 @@ class TrainConfig(BaseConfig):
ema: EMAMode
ema_decay: float
ema_update_step_interval: int
dataloader_threads: int
caching_threads: int
prefetch_next_batch: bool
train_device: str
temp_device: str
train_dtype: DataType
Expand Down Expand Up @@ -569,7 +570,7 @@ class TrainConfig(BaseConfig):
def __init__(self, data: list[(str, Any, type, bool)]):
super().__init__(
data,
config_version=10,
config_version=11,
config_migrations={
0: self.__migration_0,
1: self.__migration_1,
Expand All @@ -581,6 +582,7 @@ def __init__(self, data: list[(str, Any, type, bool)]):
7: self.__migration_7,
8: self.__migration_8,
9: self.__migration_9,
10: self.__migration_10,
}
)

Expand Down Expand Up @@ -800,6 +802,14 @@ def replace_dtype(part: str):

return migrated_data

def __migration_10(self, data: dict) -> dict:
migrated_data = data.copy()

if "dataloader_threads" in migrated_data:
migrated_data["caching_threads"] = migrated_data.pop("dataloader_threads")

return migrated_data

def weight_dtypes(self) -> ModelWeightDtypes:
return ModelWeightDtypes(
self.train_dtype,
Expand Down Expand Up @@ -997,7 +1007,8 @@ def default_values() -> 'TrainConfig':
data.append(("ema", EMAMode.OFF, EMAMode, False))
data.append(("ema_decay", 0.999, float, False))
data.append(("ema_update_step_interval", 5, int, False))
data.append(("dataloader_threads", 2, int, False))
data.append(("caching_threads", 2, int, False))
data.append(("prefetch_next_batch", True, bool, False))
data.append(("train_device", default_device.type, str, False))
data.append(("temp_device", "cpu", str, False))
data.append(("train_dtype", DataType.FLOAT_16, DataType, False))
Expand Down
4 changes: 2 additions & 2 deletions modules/util/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ def create_data_loader(
train_progress: TrainProgress | None = None,
is_validation: bool = False
) -> BaseDataLoader | None:
if config.gradient_checkpointing.offload() and config.layer_offload_fraction > 0 and config.dataloader_threads > 1:
raise RuntimeError('layer offloading can not be activated if "dataloader_threads" > 1')
if config.gradient_checkpointing.offload() and config.layer_offload_fraction > 0 and config.caching_threads > 1:
raise RuntimeError('layer offloading can not be activated if "caching_threads" > 1')

if train_progress is None:
train_progress = TrainProgress()
Expand Down
2 changes: 1 addition & 1 deletion training_presets/#chroma Finetune 16GB.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"resolution": "512",
"gradient_checkpointing": "CPU_OFFLOADED",
"layer_offload_fraction": 0.4,
"dataloader_threads": 1,
"caching_threads": 1,
"transformer": {
"train": true,
"weight_dtype": "BFLOAT_16"
Expand Down
2 changes: 1 addition & 1 deletion training_presets/#chroma Finetune 8GB.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"resolution": "512",
"gradient_checkpointing": "CPU_OFFLOADED",
"layer_offload_fraction": 0.85,
"dataloader_threads": 1,
"caching_threads": 1,
"transformer": {
"train": true,
"weight_dtype": "BFLOAT_16"
Expand Down
2 changes: 1 addition & 1 deletion training_presets/#chroma LoRA 8GB.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"resolution": "512",
"gradient_checkpointing": "CPU_OFFLOADED",
"layer_offload_fraction": 0.6,
"dataloader_threads": 1,
"caching_threads": 1,
"transformer": {
"train": true,
"weight_dtype": "FLOAT_8"
Expand Down
2 changes: 1 addition & 1 deletion training_presets/#ernie LoRA 16GB.json
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,5 @@
"layer_filter_preset": "blocks"
},
"timestep_distribution": "LOGIT_NORMAL",
"dataloader_threads": 1
"caching_threads": 1
}
2 changes: 1 addition & 1 deletion training_presets/#ernie LoRA 8GB.json
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"layer_filter_preset": "blocks"
},
"timestep_distribution": "LOGIT_NORMAL",
"dataloader_threads": 1,
"caching_threads": 1,
"gradient_checkpointing": "CPU_OFFLOADED",
"layer_offload_fraction": 0.7
}
2 changes: 1 addition & 1 deletion training_presets/#flux2 Finetune 16GB.json
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
},
"timestep_distribution": "LOGIT_NORMAL",
"dynamic_timestep_shifting": true,
"dataloader_threads": 1,
"caching_threads": 1,
"gradient_checkpointing": "CPU_OFFLOADED",
"layer_offload_fraction": 0.6,
"optimizer": {
Expand Down
2 changes: 1 addition & 1 deletion training_presets/#flux2 Finetune 24GB.json
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
},
"timestep_distribution": "LOGIT_NORMAL",
"dynamic_timestep_shifting": true,
"dataloader_threads": 1,
"caching_threads": 1,
"optimizer": {
"optimizer": "ADAFACTOR"
},
Expand Down
2 changes: 1 addition & 1 deletion training_presets/#flux2 LoRA 16GB.json
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,5 @@
},
"timestep_distribution": "LOGIT_NORMAL",
"dynamic_timestep_shifting": true,
"dataloader_threads": 1
"caching_threads": 1
}
2 changes: 1 addition & 1 deletion training_presets/#flux2 LoRA 8GB.json
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
},
"timestep_distribution": "LOGIT_NORMAL",
"dynamic_timestep_shifting": true,
"dataloader_threads": 1,
"caching_threads": 1,
"gradient_checkpointing": "CPU_OFFLOADED",
"layer_offload_fraction": 0.7
}
2 changes: 1 addition & 1 deletion training_presets/#hidream LoRA.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"batch_size": 4,
"gradient_checkpointing": "CPU_OFFLOADED",
"layer_offload_fraction": 0.5,
"dataloader_threads": 1,
"caching_threads": 1,
"learning_rate": 0.0003,
"model_type": "HI_DREAM_FULL",
"output_model_destination": "models/lora.safetensors",
Expand Down
2 changes: 1 addition & 1 deletion training_presets/#hunyuan video LoRA.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"batch_size": 4,
"gradient_checkpointing": "CPU_OFFLOADED",
"layer_offload_fraction": 0.5,
"dataloader_threads": 1,
"caching_threads": 1,
"learning_rate": 0.0003,
"model_type": "HUNYUAN_VIDEO",
"output_model_destination": "models/lora.safetensors",
Expand Down
2 changes: 1 addition & 1 deletion training_presets/#qwen Finetune 16GB.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"resolution": "512",
"gradient_checkpointing": "CPU_OFFLOADED",
"layer_offload_fraction": 0.75,
"dataloader_threads": 1,
"caching_threads": 1,
"transformer": {
"train": true,
"weight_dtype": "BFLOAT_16"
Expand Down
2 changes: 1 addition & 1 deletion training_presets/#qwen Finetune 24GB.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"resolution": "512",
"gradient_checkpointing": "CPU_OFFLOADED",
"layer_offload_fraction": 0.55,
"dataloader_threads": 1,
"caching_threads": 1,
"transformer": {
"train": true,
"weight_dtype": "BFLOAT_16"
Expand Down
2 changes: 1 addition & 1 deletion training_presets/#qwen LoRA 16GB.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"resolution": "512",
"gradient_checkpointing": "CPU_OFFLOADED",
"layer_offload_fraction": 0.5,
"dataloader_threads": 1,
"caching_threads": 1,
"transformer": {
"train": true,
"weight_dtype": "FLOAT_8"
Expand Down
2 changes: 1 addition & 1 deletion training_presets/#qwen LoRA 24GB.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"resolution": "512",
"gradient_checkpointing": "CPU_OFFLOADED",
"layer_offload_fraction": 0.1,
"dataloader_threads": 1,
"caching_threads": 1,
"transformer": {
"train": true,
"weight_dtype": "FLOAT_8"
Expand Down
2 changes: 1 addition & 1 deletion training_presets/#z-image DeTurbo LoRA 16GB.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,6 @@
"layer_filter": "layers",
"layer_filter_preset": "blocks"
},
"dataloader_threads": 1,
"caching_threads": 1,
"timestep_distribution": "LOGIT_NORMAL"
}
2 changes: 1 addition & 1 deletion training_presets/#z-image DeTurbo LoRA 8GB.json
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,6 @@
"layer_filter": "layers",
"layer_filter_preset": "blocks"
},
"dataloader_threads": 1,
"caching_threads": 1,
"timestep_distribution": "LOGIT_NORMAL"
}
Loading