3 changes: 2 additions & 1 deletion .gitignore
@@ -150,4 +150,5 @@ OUT/
 examples/experiments/grounded_program_synthesis/dataset
 ckpts/
 
-ray_results/
+ray_result/
+examples/checkpoints/
64 changes: 64 additions & 0 deletions examples/hh/dpo_hh.py
@@ -0,0 +1,64 @@
import json
import sys

from datasets import load_dataset

import trlx
from trlx.data.default_configs import (
DPOConfig,
ModelConfig,
OptimizerConfig,
SchedulerConfig,
TokenizerConfig,
TrainConfig,
TRLConfig,
)

default_config = TRLConfig(
train=TrainConfig(
seq_length=1024,
epochs=100,
total_steps=1000,
batch_size=4,
checkpoint_interval=10000,
eval_interval=100,
pipeline="PromptPipeline",
trainer="AccelerateDPOTrainer",
checkpoint_dir="checkpoints/dpo_hh",
),
model=ModelConfig(model_path="gpt2", num_layers_unfrozen=-1),
tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"),
optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=1e-6, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)),
scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=1.0e-4)), # train.total_steps
method=DPOConfig(
name="DPOConfig",
gen_kwargs=dict(max_new_tokens=40, top_k=20, top_p=1.0, do_sample=True),
beta=0.1,
label_pad_token_id=-100,
padding_value=0,
),
)


def preprocess(sample):
sample["dpo"] = [sample["prompt"], sample["chosen"], sample["rejected"]]
return sample


def main(hparams={}):
config = TRLConfig.update(default_config, hparams)

dataset = load_dataset("Dahoas/full-hh-rlhf").map(preprocess)

trlx.train(
config=config,
samples=dataset["train"]["dpo"],
eval_prompts=dataset["test"]["prompt"][:280],
# metric_fn=lambda **kwargs: {"reward": reward_fn(**kwargs)},
stop_sequences=["Human:", "human:", "Assistant:", "assistant:"],
)


if __name__ == "__main__":
hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1])
main(hparams)
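Usage sketch: like the other trlx example scripts, the first CLI argument is a JSON blob of hyperparameter overrides, which `main` merges over `default_config` via `TRLConfig.update`. The override values below are illustrative:

```python
import json

# Equivalent to: python examples/hh/dpo_hh.py '{"train": {"total_steps": 200}, "method": {"beta": 0.3}}'
hparams = json.loads('{"train": {"total_steps": 200}, "method": {"beta": 0.3}}')
main(hparams)
```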
27 changes: 27 additions & 0 deletions trlx/data/default_configs.py
@@ -2,6 +2,7 @@

from trlx.models.modeling_ilql import ILQLConfig
from trlx.models.modeling_ppo import PPOConfig
from trlx.trainer.accelerate_dpo_trainer import DPOConfig
from trlx.trainer.accelerate_sft_trainer import SFTConfig

from .configs import (
@@ -146,3 +147,29 @@ def default_nemo_1_3b_config():

here = Path(__file__).parent
return OmegaConf.load(here.parent.parent / "configs" / "nemo_configs" / "megatron_1.3b.yaml")


def default_dpo_config():
return TRLConfig(
train=TrainConfig(
seq_length=1024,
epochs=100,
total_steps=1000,
batch_size=8,
checkpoint_interval=10000,
eval_interval=100,
pipeline="PromptPipeline",
trainer="DPOTrainer",
),
model=ModelConfig(model_path="gpt2", num_layers_unfrozen=-1),
tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"),
optimizer=OptimizerConfig(
name="adamw", kwargs=dict(lr=1.0e-4, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)
),
scheduler=SchedulerConfig(
name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=1.0e-4) # train.total_steps
),
method=DPOConfig(
name="DPOConfig", gen_kwargs=dict(max_new_tokens=40, top_k=0, top_p=1.0, do_sample=True), beta=0.1
),
)
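A short consumption sketch for this default, following the same `TRLConfig.update` pattern the example script uses (override values are illustrative):

```python
from trlx.data.default_configs import TRLConfig, default_dpo_config

# Start from the DPO defaults and override selected fields, as
# examples/hh/dpo_hh.py does with its own default_config.
config = TRLConfig.update(default_dpo_config(), {"train": {"batch_size": 4}, "method": {"beta": 0.2}})
```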
13 changes: 13 additions & 0 deletions trlx/data/dpo_types.py
@@ -0,0 +1,13 @@
from dataclasses import dataclass

from transformers import BatchEncoding


@dataclass
class DPOElement:
prompt_tokens: BatchEncoding
chosen_tokens: BatchEncoding
rejected_tokens: BatchEncoding


# TODO: Extend to include a concrete class for DPOPreferenceBatch
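To make the container concrete, a minimal construction sketch (the tokenizer and strings are illustrative; in the pipeline these fields are produced by `DPOStore.tokenize_preferences`):

```python
from transformers import AutoTokenizer

from trlx.data.dpo_types import DPOElement

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Each field holds a plain BatchEncoding (input_ids + attention_mask);
# no padding or EOS handling happens at this stage.
element = DPOElement(
    prompt_tokens=tokenizer("Human: How do I boil an egg? Assistant:", add_special_tokens=False),
    chosen_tokens=tokenizer(" Simmer it for about ten minutes.", add_special_tokens=False),
    rejected_tokens=tokenizer(" I refuse to answer.", add_special_tokens=False),
)
```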
4 changes: 2 additions & 2 deletions trlx/pipeline/__init__.py
@@ -166,8 +166,8 @@ def __next__(self):  # noqa: C901
                 minibatch = BatchEncoding(sliced_data)
             elif is_dataclass(batch):
                 minibatch = batch.__class__(**sliced_data)
-            # else:
-            #     minibatch = sliced_data
+            else:
+                minibatch = sliced_data
 
             minibatches.append(minibatch)

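One reading of why this fallback matters here: `DPOStore`'s collate function below returns a plain `dict`, which is neither a `BatchEncoding` nor a dataclass, so without the restored `else` branch the iterator would reach `minibatches.append(minibatch)` with `minibatch` never assigned for DPO batches. A minimal sketch of what now passes through unchanged:

```python
import torch

# A sliced DPO minibatch is just a dict of tensors, e.g.:
sliced_data = {
    "chosen_input_ids": torch.tensor([[1, 2, 3]]),
    "chosen_labels": torch.tensor([[-100, 2, 3]]),
}
minibatch = sliced_data  # the restored `else` branch forwards it as-is
```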
121 changes: 121 additions & 0 deletions trlx/pipeline/offline_pipeline.py
@@ -10,6 +10,7 @@
PreTrainedTokenizerFast,
)

from trlx.data.dpo_types import DPOElement
from trlx.data.ilql_types import (
ILQLBatch,
ILQLElement,
@@ -277,3 +278,123 @@ def create_loader(self, batch_size: int):
collate_fn=ilql_seq2seq_collate_fn,
drop_last=torch.distributed.is_initialized(),
)


class DPOStore(BaseRolloutStore):
# Adapted from TRL
def __init__(
self,
preferences: List[DPOElement],
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
label_pad_token_id: int,
padding_value: int,
):
super().__init__()
self.tokenizer = tokenizer
self.label_pad_token_id = label_pad_token_id
self.padding_value = padding_value

self.history = [
self._build_batch_from_preference_tokens(preference_element) for preference_element in preferences
]

@staticmethod
def tokenize_preferences(
sample: Iterable[str], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], max_length=2048
) -> DPOElement:
if isinstance(sample, Iterable):
if len(sample) != 3:
raise ValueError(
f"Expected iterable of length 3 (prompt, chosen response, rejected response). Got {len(sample)}"
)
prompt_tokens = tokenizer(sample[0], add_special_tokens=False)
chosen_tokens = tokenizer(sample[1], add_special_tokens=False)
rejected_tokens = tokenizer(sample[2], add_special_tokens=False)
else:
raise ValueError(f"{sample} is not an iterable")

chosen_tokens["input_ids"].append(tokenizer.eos_token_id)
chosen_tokens["attention_mask"].append(1)

rejected_tokens["input_ids"].append(tokenizer.eos_token_id)
rejected_tokens["attention_mask"].append(1)

longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))

        # if the combined sequence is too long, truncate the prompt first,
        # keeping just enough room for the longer of the two responses
        if len(prompt_tokens["input_ids"]) + longer_response_length > max_length:
            max_prompt_length = max(max_length - longer_response_length, 0)
            if tokenizer.truncation_side == "right":
                prompt_tokens = {k: v[:max_prompt_length] for k, v in prompt_tokens.items()}
            elif tokenizer.truncation_side == "left":
                prompt_tokens = {k: v[len(v) - max_prompt_length :] for k, v in prompt_tokens.items()}

        # if that's still too long, truncate the responses to the room the prompt leaves
        if len(prompt_tokens["input_ids"]) + longer_response_length > max_length:
            response_length = max_length - len(prompt_tokens["input_ids"])
            chosen_tokens = {k: v[:response_length] for k, v in chosen_tokens.items()}
            rejected_tokens = {k: v[:response_length] for k, v in rejected_tokens.items()}

return DPOElement(prompt_tokens=prompt_tokens, chosen_tokens=chosen_tokens, rejected_tokens=rejected_tokens)
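A small usage sketch for the method above (strings are illustrative): whatever the inputs, after truncation the prompt plus the longer response fits within `max_length`:

```python
from transformers import AutoTokenizer

from trlx.pipeline.offline_pipeline import DPOStore

tokenizer = AutoTokenizer.from_pretrained("gpt2")

element = DPOStore.tokenize_preferences(
    ["Human: Name a prime number. Assistant:", " Seven.", " Nine."], tokenizer, max_length=16
)
longest = max(len(element.chosen_tokens["input_ids"]), len(element.rejected_tokens["input_ids"]))
assert len(element.prompt_tokens["input_ids"]) + longest <= 16
```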

def _build_batch_from_preference_tokens(self, preference_tokens: DPOElement) -> Dict:
# Create labels
chosen_sequence_tokens = {
k: preference_tokens.prompt_tokens[k] + preference_tokens.chosen_tokens[k]
for k in preference_tokens.chosen_tokens
}
rejected_sequence_tokens = {
k: preference_tokens.prompt_tokens[k] + preference_tokens.rejected_tokens[k]
for k in preference_tokens.rejected_tokens
}
chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
chosen_sequence_tokens["labels"][: len(preference_tokens.prompt_tokens["input_ids"])] = [
self.label_pad_token_id
] * len(preference_tokens.prompt_tokens["input_ids"])
rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
rejected_sequence_tokens["labels"][: len(preference_tokens.prompt_tokens["input_ids"])] = [
self.label_pad_token_id
] * len(preference_tokens.prompt_tokens["input_ids"])

batch = {}

for k, toks in {
"chosen": chosen_sequence_tokens,
"rejected": rejected_sequence_tokens,
"prompt": preference_tokens.prompt_tokens,
}.items():
for type_key, tokens in toks.items():
if type_key == "token_type_ids":
continue
batch[f"{k}_{type_key}"] = tokens

return batch
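For orientation, the per-sample dict built above has this shape (token ids are made up; `50256` is GPT-2's EOS id appended by `tokenize_preferences`, and prompt positions in the labels are masked with `label_pad_token_id` so only response tokens are scored):

```python
# Illustrative output for a 2-token prompt, with label_pad_token_id = -100:
example_batch = {
    "prompt_input_ids": [11, 12],
    "prompt_attention_mask": [1, 1],
    "chosen_input_ids": [11, 12, 21, 22, 50256],  # prompt + chosen response + EOS
    "chosen_attention_mask": [1, 1, 1, 1, 1],
    "chosen_labels": [-100, -100, 21, 22, 50256],  # prompt masked out
    "rejected_input_ids": [11, 12, 31, 50256],  # prompt + rejected response + EOS
    "rejected_attention_mask": [1, 1, 1, 1],
    "rejected_labels": [-100, -100, 31, 50256],
}
```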

def create_loader(self, batch_size: int, shuffle=False) -> DataLoader:
def collate_fn(batch: Iterable[dict]):
# first, pad everything to the same length
padded_batch = {}
for k in batch[0].keys():
if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"):
# adapted from https://stackoverflow.com/questions/73256206
if "prompt" in k:
to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch]
else:
to_pad = [torch.LongTensor(ex[k]) for ex in batch]
if k.endswith("_input_ids"):
padding_value = self.tokenizer.pad_token_id
elif k.endswith("_labels"):
padding_value = self.label_pad_token_id
elif k.endswith("_attention_mask"):
padding_value = self.padding_value
else:
raise ValueError(f"Unexpected key in batch '{k}'")

padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
# for the prompt, flip back so padding is on left side
if "prompt" in k:
padded_batch[k] = padded_batch[k].flip(dims=[1])
else:
padded_batch[k] = [ex[k] for ex in batch]

return padded_batch

return DataLoader(self, batch_size=batch_size, collate_fn=collate_fn, shuffle=shuffle)
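Putting it together, a minimal end-to-end sketch (sample strings are illustrative; this assumes `BaseRolloutStore` serves `self.history` through `__getitem__`/`__len__`):

```python
from transformers import AutoTokenizer

from trlx.pipeline.offline_pipeline import DPOStore

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

samples = [
    ["Human: Hi! Assistant:", " Hello, how can I help?", " Go away."],
    ["Human: What is 2+2? Assistant:", " 4.", " 5."],
]
elements = [DPOStore.tokenize_preferences(s, tokenizer) for s in samples]
store = DPOStore(elements, tokenizer, label_pad_token_id=-100, padding_value=0)

loader = store.create_loader(batch_size=2, shuffle=True)
batch = next(iter(loader))
# chosen/rejected tensors are right-padded; prompt tensors end up left-padded
print(batch["chosen_input_ids"].shape, batch["prompt_input_ids"].shape)
```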