Dataset
This section walks through RLHFDataset in verl. It inherits from torch's Dataset, so it must implement __getitem__ to return one sample of data.
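As a quick reminder of that interface, here is a generic sketch of the torch map-style Dataset protocol (not verl code):

from torch.utils.data import Dataset

class MinimalDataset(Dataset):
    """Any map-style dataset only needs __len__ and __getitem__;
    RLHFDataset implements the same protocol on top of Parquet data."""
    def __init__(self, rows):
        self.rows = rows

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        return self.rows[idx]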
Initialization
class RLHFDataset(Dataset):
    """
    Load and preprocess RLHF data from Parquet files.

    - Caches files locally.
    - Reads into a HuggingFace Dataset and tokenizes prompts.
    - Optionally handles images/videos via a ProcessorMixin.
    - Filters prompts over a max length.
    - Supports resuming from checkpoints.

    Args:
        data_files (str or list): Path(s) to Parquet file(s).
        tokenizer (PreTrainedTokenizer): For the tokenization of text to token IDs.
        config (DictConfig): Options like cache_dir, prompt_key, max_prompt_length, truncation, etc.
        processor (ProcessorMixin, optional): Multimodal preprocessor for images/videos.
    """

    def __init__(
        self,
        data_files: str | list[str],
        tokenizer: PreTrainedTokenizer,
        config: DictConfig,
        processor: Optional[ProcessorMixin] = None,
    ):
        if not isinstance(data_files, list | ListConfig):
            data_files = [data_files]

        self.data_files = copy.deepcopy(data_files)
        self.original_data_files = copy.deepcopy(data_files)  # use for resume
        self.tokenizer = tokenizer
        self.processor = processor
        self.config = config

        self.cache_dir = os.path.expanduser(config.get("cache_dir", "~/.cache/verl/rlhf"))
        self.prompt_key = config.get("prompt_key", "prompt")
        self.image_key = config.get("image_key", "images")
        self.video_key = config.get("video_key", "videos")
        self.max_prompt_length = config.get("max_prompt_length", 1024)
        self.return_raw_chat = config.get("return_raw_chat", False)
        self.return_full_prompt = config.get("return_full_prompt", False)
        self.truncation = config.get("truncation", "error")
        self.filter_overlong_prompts = config.get("filter_overlong_prompts", True)

        self.num_workers = config.get("filter_overlong_prompts_workers", max(1, os.cpu_count() // 4))
        self.num_workers = min(self.num_workers, os.cpu_count())
        self.use_shm = config.get("use_shm", False)
        self.chat_template_func = config.get("chat_template_func", None)
        self.need_tools_kwargs = config.get("need_tools_kwargs", False)
        self.filter_prompts = config.get("filter_prompts", True)
        self.serialize_dataset = False
        self.return_multi_modal_inputs = config.get("return_multi_modal_inputs", True)

        self._download()
        self._read_files_and_tokenize()
(from deepwiki)
The _download method simply caches the HDFS or local data files under the cache path.
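For intuition, here is a hypothetical sketch of that caching step. It only handles local files and omits the HDFS and shared-memory cases that verl's real _download supports; the function name and structure are illustrative, not taken from verl's source:

import os
import shutil

def _download_sketch(data_files, cache_dir):
    """Hypothetical: copy each source file into cache_dir and return the local paths."""
    os.makedirs(cache_dir, exist_ok=True)
    local_files = []
    for src in data_files:
        dst = os.path.join(cache_dir, os.path.basename(src))
        if not os.path.exists(dst):
            shutil.copy(src, dst)  # an HDFS path would need an HDFS client instead
        local_files.append(dst)
    return local_files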
Next, let's focus on the _read_files_and_tokenize method:
def _read_files_and_tokenize(self):
    dataframes = []
    for parquet_file in self.data_files:
        # read parquet files and cache
        dataframe = datasets.load_dataset("parquet", data_files=parquet_file)["train"]
        dataframes.append(dataframe)
    self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes)

    print(f"dataset len: {len(self.dataframe)}")

    self.dataframe = self.maybe_filter_out_long_prompts(self.dataframe)

def maybe_filter_out_long_prompts(self, dataframe: datasets.Dataset = None):
    # filter out too long prompts
    if self.filter_overlong_prompts:
        tokenizer = self.tokenizer
        processor = self.processor
        prompt_key = self.prompt_key
        image_key = self.image_key
        video_key = self.video_key

        if processor is not None:
            from verl.utils.dataset.vision_utils import process_image, process_video

            def doc2len(doc) -> int:
                messages = self._build_messages(doc)
                raw_prompt = self.processor.apply_chat_template(
                    messages, add_generation_prompt=True, tokenize=False
                )
                images = [process_image(image) for image in doc[image_key]] if image_key in doc else None
                videos = [process_video(video) for video in doc[video_key]] if video_key in doc else None

                return len(processor(text=[raw_prompt], images=images, videos=videos)["input_ids"][0])

        else:

            def doc2len(doc) -> int:
                return len(tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True))

        dataframe = dataframe.filter(
            lambda doc: doc2len(doc) <= self.max_prompt_length,
            num_proc=self.num_workers,
            desc=f"Filtering prompts longer than {self.max_prompt_length} tokens",
        )

        print(f"filter dataset len: {len(dataframe)}")
    return dataframe
This reads the Parquet files, then filters out samples whose tokenized prompt length exceeds max_prompt_length, and returns the filtered dataframe.
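With the reading and filtering in place, a hypothetical usage sketch looks like this. The parquet path and model name are placeholders; the config keys mirror the ones read in __init__, and the import path assumes the class lives in verl.utils.dataset.rl_dataset:

from omegaconf import OmegaConf
from transformers import AutoTokenizer
from verl.utils.dataset.rl_dataset import RLHFDataset  # assumed module path

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # placeholder model
config = OmegaConf.create({
    "prompt_key": "prompt",
    "max_prompt_length": 1024,
    "truncation": "error",
    "filter_overlong_prompts": True,
})

dataset = RLHFDataset(
    data_files="/path/to/train.parquet",  # placeholder path
    tokenizer=tokenizer,
    config=config,
    processor=None,  # text-only
)
print(len(dataset))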
Then comes the most important method, __getitem__, which constructs the data we actually need:
def __getitem__(self, item):
    """
    Note that we also return the raw_input_ids so that it can be combined with other chat template
    """
    row_dict: dict = self.dataframe[item]
    messages = self._build_messages(row_dict)
    model_inputs = {}

    if self.processor is not None:
        from verl.utils.dataset.vision_utils import process_image, process_video

        raw_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        multi_modal_data = {}

        images = None
        if self.image_key in row_dict and row_dict.get(self.image_key, None) is not None:
            images = [process_image(image) for image in row_dict.pop(self.image_key)]

            # due to the image key is "image" instead of "images" in vllm, we need to use "image" here
            # link: https://github.com/vllm-project/vllm/blob/3c545c0c3b98ee642373a308197d750d0e449403/vllm/multimodal/parse.py#L205
            multi_modal_data["image"] = images

        videos = None
        if self.video_key in row_dict and row_dict.get(self.video_key, None) is not None:
            videos = [process_video(video) for video in row_dict.pop(self.video_key)]

            # due to the video key is "video" instead of "videos" in vllm, we need to use "video" here
            # link: https://github.com/vllm-project/vllm/blob/3c545c0c3b98ee642373a308197d750d0e449403/vllm/multimodal/parse.py#L205
            multi_modal_data["video"] = [video.numpy() for video in videos]

        model_inputs = self.processor(text=[raw_prompt], images=images, videos=videos, return_tensors="pt")

        input_ids = model_inputs.pop("input_ids")
        attention_mask = model_inputs.pop("attention_mask")

        if "second_per_grid_ts" in model_inputs:
            model_inputs.pop("second_per_grid_ts")

        # There's a trap here, multi_modal_inputs has to be a dict, not BatchFeature
        row_dict["multi_modal_data"] = multi_modal_data

        # We will do batch.union() in the trainer,
        # so we cannot have "multi_modal_inputs" in row_dict if rollout generates new multi_modal_inputs
        if self.return_multi_modal_inputs:
            row_dict["multi_modal_inputs"] = dict(model_inputs)

            # second_per_grid_ts isn't used for training, just for mrope
            row_dict["multi_modal_inputs"].pop("second_per_grid_ts", None)

    else:
        raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        model_inputs = self.tokenizer(raw_prompt, return_tensors="pt", add_special_tokens=False)
        input_ids = model_inputs.pop("input_ids")
        attention_mask = model_inputs.pop("attention_mask")

    input_ids, attention_mask = verl_F.postprocess_data(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=self.max_prompt_length,
        pad_token_id=self.tokenizer.pad_token_id,
        left_pad=True,
        truncation=self.truncation,
    )

    if self.processor is not None and "Qwen2VLImageProcessor" in self.processor.image_processor.__class__.__name__:
        from verl.models.transformers.qwen2_vl import get_rope_index

        position_ids = [
            get_rope_index(
                self.processor,
                input_ids=input_ids[0],
                image_grid_thw=model_inputs.get("image_grid_thw"),
                video_grid_thw=model_inputs.get("video_grid_thw"),
                second_per_grid_ts=model_inputs.get("second_per_grid_ts"),
                attention_mask=attention_mask[0],
            )
        ]  # (1, 3, seq_len)
    else:
        position_ids = compute_position_id_with_mask(attention_mask)

    row_dict["input_ids"] = input_ids[0]
    row_dict["attention_mask"] = attention_mask[0]
    row_dict["position_ids"] = position_ids[0]

    # encode prompts without chat template
    raw_prompt_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False)
    if len(raw_prompt_ids) > self.max_prompt_length:
        if self.truncation == "left":
            raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length :]
        elif self.truncation == "right":
            raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length]
        elif self.truncation == "middle":
            left_half = self.max_prompt_length // 2
            right_half = self.max_prompt_length - left_half
            raw_prompt_ids = raw_prompt_ids[:left_half] + raw_prompt_ids[-right_half:]
        elif self.truncation == "error":
            raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.")

    row_dict["raw_prompt_ids"] = raw_prompt_ids
    if self.return_raw_chat:
        row_dict["raw_prompt"] = messages

    # get prompts with chat template
    if self.return_full_prompt:
        row_dict["full_prompts"] = raw_prompt  # array of strings

    # add index for each prompt
    index = row_dict.get("extra_info", {}).get("index", 0)
    tools_kwargs = row_dict.get("extra_info", {}).get("tools_kwargs", {})
    interaction_kwargs = row_dict.get("extra_info", {}).get("interaction_kwargs", {})
    need_tools_kwargs = row_dict.get("extra_info", {}).get("need_tools_kwargs", self.need_tools_kwargs)
    if need_tools_kwargs and not tools_kwargs:
        logger.warning("tools_kwargs is empty for index {}, data source: {}", index, row_dict["data_source"])
    row_dict["index"] = index
    row_dict["tools_kwargs"] = tools_kwargs
    row_dict["interaction_kwargs"] = interaction_kwargs
    return row_dict
For pure-text data, _build_messages simply fetches the content stored under prompt_key; the multimodal path is handled by the processor and is not covered here. Jumping to the else branch: apply_chat_template is called on the messages, which means the data files must already be preprocessed into this chat-message format (see the scripts under verl/examples/data_preprocess). The resulting raw_prompt is kept, then tokenized, and postprocess_data pads and truncates the result.
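To make that expected format concrete, here is a sketch of a single preprocessed row and its templated prompt. The field contents are illustrative, not taken from a real dataset, and the model name is a placeholder:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # placeholder model

# One row as the preprocessing scripts are expected to produce it:
# prompt_key ("prompt") maps to a list of chat messages.
row = {
    "prompt": [{"role": "user", "content": "What is 2 + 2?"}],
    "data_source": "example",      # illustrative extra field
    "extra_info": {"index": 0},
}

# _build_messages returns this message list for text-only data;
# apply_chat_template then renders it into the model's prompt string,
# e.g. "<|im_start|>user\nWhat is 2 + 2?<|im_end|>\n<|im_start|>assistant\n" for a Qwen-style template.
raw_prompt = tokenizer.apply_chat_template(row["prompt"], add_generation_prompt=True, tokenize=False)
print(raw_prompt)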
def postprocess_data(
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    max_length: int,
    pad_token_id: int,
    left_pad=True,
    truncation="error",
):
    """Process tokenizer outputs to consistent shapes via padding/truncation.

    Args:
        input_ids: Token indices [batch_size, seq_len]
        attention_mask: Mask [batch_size, seq_len]
        max_length: Target sequence length
        pad_token_id: Padding token ID
        left_pad: Pad left if True
        truncation: "left", "right", "middle" or "error"

    Returns:
        (input_ids, attention_mask) padded/truncated to max_length
    """
    assert truncation in ["left", "right", "middle", "error"]
    assert input_ids.ndim == 2

    sequence_length = input_ids.shape[-1]
    if sequence_length < max_length:
        input_ids = pad_sequence_to_length(
            input_ids, max_seq_len=max_length, pad_token_id=pad_token_id, left_pad=left_pad
        )
        attention_mask = pad_sequence_to_length(
            attention_mask, max_seq_len=max_length, pad_token_id=0, left_pad=left_pad
        )
    elif sequence_length > max_length:
        if truncation == "left":
            # actually, left truncation may not be reasonable
            input_ids = input_ids[:, -max_length:]
            attention_mask = attention_mask[:, -max_length:]
        elif truncation == "right":
            input_ids = input_ids[:, :max_length]
            attention_mask = attention_mask[:, :max_length]
        elif truncation == "middle":
            left_half = max_length // 2
            right_half = max_length - left_half
            input_ids = torch.cat([input_ids[:, :left_half], input_ids[:, -right_half:]], dim=-1)
            attention_mask = torch.cat([attention_mask[:, :left_half], attention_mask[:, -right_half:]], dim=-1)
        elif truncation == "error":
            raise NotImplementedError(f"{sequence_length=} is larger than {max_length=}")
        else:
            raise NotImplementedError(f"Unknown truncation method {truncation}")

    return input_ids, attention_mask
Each truncation mode is handled differently: "left" keeps the tail, "right" keeps the head, "middle" keeps both ends and drops the middle, and "error" raises if the prompt is too long.
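A tiny self-contained illustration of the three truncation modes, on a length-6 sequence with max_length = 4:

import torch

ids = torch.arange(1, 7).unsqueeze(0)   # [[1, 2, 3, 4, 5, 6]], length 6
max_length = 4

left = ids[:, -max_length:]              # [[3, 4, 5, 6]]
right = ids[:, :max_length]              # [[1, 2, 3, 4]]
left_half = max_length // 2
right_half = max_length - left_half
middle = torch.cat([ids[:, :left_half], ids[:, -right_half:]], dim=-1)  # [[1, 2, 5, 6]]
# "error" would instead raise, because 6 > 4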
Then position_ids are computed from the padded/truncated attention mask.
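For the non-Qwen2-VL branch, the position ids come from the mask alone. Here is a sketch of what compute_position_id_with_mask is expected to do (a plausible implementation, not copied from verl's source): positions count only real tokens, so left padding stays at position 0:

import torch

def position_ids_from_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    """Cumulative count of real tokens, clamped so pad positions stay at 0."""
    return torch.clamp(torch.cumsum(attention_mask, dim=-1) - 1, min=0)

mask = torch.tensor([[0, 0, 1, 1, 1]])   # left-padded prompt
print(position_ids_from_mask(mask))      # tensor([[0, 0, 0, 1, 2]])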
The remaining steps mainly store raw_prompt and raw_prompt_ids: raw_prompt is the string obtained by applying the chat template to the messages, and raw_prompt_ids is its tokenization without special tokens. Finally, a few fields from extra_info (index, tools_kwargs, interaction_kwargs) are copied into row_dict, which is returned.
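Continuing the usage sketch above, a quick look at one returned sample in the text-only case. The key names come from __getitem__; the shapes follow from left-padding to max_prompt_length = 1024 in that sketch:

sample = dataset[0]
print(sample["input_ids"].shape)       # torch.Size([1024]) after left-padding
print(sample["attention_mask"].shape)  # torch.Size([1024])
print(sample["position_ids"].shape)    # torch.Size([1024])
print(len(sample["raw_prompt_ids"]))   # unpadded prompt length (a plain Python list)
print(sample["index"], sample["tools_kwargs"], sample["interaction_kwargs"])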