Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions large_language_model_pretraining/nemo/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ def __init__(
micro_batch_size,
sequence_length,
init_global_step,
eval_every,
configs={}
):
mllogger.event(key=constants.CACHE_CLEAR, value=True)
Expand All @@ -169,6 +170,7 @@ def __init__(
self.gbs = global_batch_size
self.mbs = micro_batch_size
self.seq_len = sequence_length
self.eval_every = eval_every

self.is_target_reached = False
self.status = constants.ABORTED
Expand All @@ -185,7 +187,6 @@ def set_success_status(self):
def on_train_epoch_start(self, trainer, pl_module):
# Epoch-start hook: opens the EPOCH_START and BLOCK_START intervals in mllogger,
# then defers to the parent class implementation of the same hook.
# Both events carry self.consumed_samples(trainer) under the SAMPLES_COUNT metadata key.
mllogger.start(key=constants.EPOCH_START, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)})
mllogger.start(key=constants.BLOCK_START, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)})

return super().on_train_epoch_start(trainer, pl_module)

@rank_zero_only
Expand All @@ -201,9 +202,14 @@ def on_train_end(self, trainer, pl_module):
return super().on_train_end(trainer, pl_module)

@rank_zero_only
def on_validation_start(self, trainer, pl_module):
def log_eval_start(self, trainer, pl_module):
# Logs the transition from training to evaluation: ends the BLOCK_STOP interval and
# starts the EVAL_START interval, each tagged with self.consumed_samples(trainer)
# under the SAMPLES_COUNT metadata key. pl_module is accepted but not used here.
mllogger.end(key=constants.BLOCK_STOP, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)})
mllogger.start(key=constants.EVAL_START, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)})

def on_validation_start(self, trainer, pl_module):
# Validation-start hook: resets the trainer's validation cadence to self.eval_every,
# logs the eval-start events, then defers to the parent class hook.
# Both trainer attributes are overwritten on every call, so any interval the trainer
# was configured with beforehand presumably only governs when this hook first fires
# -- NOTE(review): confirm against the trainer configuration.
trainer.val_check_interval = self.eval_every
trainer.val_check_batch = self.eval_every
# Emit the BLOCK_STOP / EVAL_START log events before handing off to the parent.
self.log_eval_start(trainer, pl_module)
return super().on_validation_start(trainer, pl_module)

def on_validation_end(self, trainer, pl_module):
Expand Down
4 changes: 3 additions & 1 deletion large_language_model_pretraining/nemo/config.sh
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ export MBS=1
# If an empty string ("") is provided, training continues until the time limit is reached.
# This value must be set if a checkpoint is to be saved.
export MAX_STEPS=""
export EVAL_EVERY="18432"

# Experiment: starting steps
# This is the starting "offset" step from the checkpoint.
Expand All @@ -92,4 +93,5 @@ export NPAR=1
# Experiment manager: provides seeds to the launched experiments; use spaces as the delimiter, e.g. "1234 1235 1236".
# The training script discards any excess seeds, and generates additional seeds if fewer than NEXP seeds are given.
# To preserve randomness, we recommend leaving this value unset so that seeds are randomly generated on every run.
export SEEDS=""
export SEEDS=""

9 changes: 7 additions & 2 deletions large_language_model_pretraining/nemo/pretrain_llama31.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from nemo.collections.llm.gpt.data import build_pretraining_datamodule
from callbacks import PreemptiveStop, MLPerfCallback, MetricsLogger


def slurm_executor(
user: str,
host: str,
Expand Down Expand Up @@ -93,6 +94,7 @@ def get_pretrain(
nnodes: int,
ngpus_per_node: int,
data_module: run.Config,
start_eval_at: Optional[int]=None,
eval_every: Optional[int]=None,
eval_batches: Optional[int]=None,
) -> run.Partial:
Expand Down Expand Up @@ -180,7 +182,7 @@ def get_pretrain(
pretrain.trainer.max_steps = math.ceil(max_tokens / 8192 / gbs)

pretrain.data = data_module
pretrain.trainer.val_check_interval = eval_every
pretrain.trainer.val_check_interval = start_eval_at
pretrain.trainer.limit_val_batches = eval_batches
pretrain.trainer.limit_test_batches = eval_batches

Expand Down Expand Up @@ -300,7 +302,7 @@ def get_parser() -> argparse.ArgumentParser:

data_group.add_argument("--gbs", type=int, default=1152, help="Global batch size, should be divisible by PP")
data_group.add_argument("--mbs", type=int, default=1, help="Micro batch size")
data_group.add_argument("--eval_every", type=int, default=46080, help="Evaluate at least every N training sequences")
data_group.add_argument("--eval_every", type=int, default=18432, help="Evaluate at least every N training sequences")
data_group.add_argument("--eval_tokens", type=int, default=5760, help="Evaluate using at least N evaluation sequences")
data_group.add_argument('--max_steps', type=int, default=None, help="Maximum number of steps that each experiment partition will train on. None means no restriction on max steps. ")
data_group.add_argument("--use_full_dataset", action="store_true", help="If set, then we use the full dataset, instead of the last 256/1024 shards")
Expand Down Expand Up @@ -352,6 +354,7 @@ def get_parser() -> argparse.ArgumentParser:
use_full_dataset=args.use_full_dataset,
)

# Convert the sequence-based CLI settings into whole numbers of global batches.
#
# start_eval_at: batches to train before the first evaluation. It is eval_every
# scaled by floor(0.0026 * gbs + 12) and divided by the global batch size
# (the 0.0026 / 12 constants look like an empirical schedule -- TODO confirm).
# True division yields a float, and this value is later used as the trainer's
# val_check_interval, where a float means "fraction of an epoch" rather than
# "number of batches"; round up to an int so it is read as a batch count,
# consistent with the two ceil-based computations below.
start_eval_at = math.ceil(int(args.eval_every) * math.floor(0.0026 * args.gbs + 12) / args.gbs)
# eval_every_n_batches: batches between subsequent evaluations.
eval_every_n_batches = math.ceil(args.eval_every / (args.gbs))
# eval_batches: batches consumed by each evaluation pass.
eval_batches = math.ceil(args.eval_tokens / (args.gbs))

Expand All @@ -360,6 +363,7 @@ def get_parser() -> argparse.ArgumentParser:
nnodes=args.nodes,
ngpus_per_node=args.gpus_per_node,
data_module=data,
start_eval_at=start_eval_at,
eval_every=eval_every_n_batches,
eval_batches=eval_batches,
)
Expand Down Expand Up @@ -497,6 +501,7 @@ def get_parser() -> argparse.ArgumentParser:
micro_batch_size=args.mbs,
sequence_length=8192,
init_global_step=start_step,
eval_every=eval_every_n_batches,
configs=configs,
),
]
Expand Down
1 change: 1 addition & 0 deletions large_language_model_pretraining/nemo/run_llama31.sh
Original file line number Diff line number Diff line change
Expand Up @@ -139,4 +139,5 @@ python3 pretrain_llama31.py \
--step_time_atol $STEP_TIME_ATOL \
--ckpt_start_step $START_STEPS \
--max_retries $MAX_RETRIES \
--eval_every $EVAL_EVERY \
$CMD_SUFFIX