generate
理论部分在这:generate相关 ## generate参数
def generate(
self,
= None,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] int, torch.Tensor], List[int]]] = None,
prefix_allowed_tokens_fn: Optional[Callable[[bool] = None,
synced_gpus: Optional["PreTrainedModel"] = None,
assistant_model: Optional["BaseStreamer"] = None,
streamer: Optional[= None,
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] **kwargs,
-> Union[GenerateOutput, torch.LongTensor]: )
在代码中可以看到在函数入口显式的定义了很多参数。他们的具体含义如下
- inputs:tensor 形式的 token_id,通常先准备文本形式的提示词和输入,使用tokenizer转化为对应 id,这里维度通常为 [batch_size, seq_len]
- generation_config:一个用 GenerationConfig 类创建的对象,存储着模型生成的超参数,可以提前创建该对象并传入 .generate()
- logits_processor:高级功能,logits_processor 可以在每个 step 的输出概率计算完成后,对分数进行进一步的干预,改变输出的概率分布,从而影响生成的结果,例如最常见的,重复惩罚,就是使用 logits_processor 完成的。
- stopping_criteria:高级功能,允许用户通过 stopping_criteria 自定义生成停止条件
- prefix_allowed_tokens_fn:解码策略的一个超参数,用于前缀 token 约束
- synced_gpus:
- DeepSpeed ZeRO Stage-3 多GPU时使用(ZeRO-3包括优化器状态+梯度+权重并行优化,而推理阶段只使用权重并行),此时需要将 synced_gpus 设置成 Ture。.
- 否则,如果一个 GPU 在另一个 GPU 之前完成生成,整个系统就会挂起,因为其余 GPU 尚未从最先完成的 GPU 接收到权重分片。
- transformers>=4.28 在生成时检测到多个 GPU 会自动设置 synced_gpus=True,transformers<4.28 需要手动设置,本文代码环境transformers=4.41.1
- assistant_model:高级功能,辅助生成模型,另一个词表完全相同的小模型,有些token使用辅助模型生成更快
- streamer:流式输出控制器,现在的大模型平台都是一个字一个字显示出来的,这就是流式输出,否则的话会等所有生成完成再显示出来。这个可以自定义流式输出的方式
- negative_prompt_ids:负面提示,一些前沿研究会用到,不用管
- negative_prompt_attention_mask:负面提示的 attention_mask
- **kwargs
- 这里经常传入 temperature=0.7, top_k=20, max_new_tokens=512等参数,都是通过**kwargs传入进来的
- 其实传入的这些都是输入参数 generation_config 的属性(可以进入对应类中看一下有哪些属性,from transformers.generation.configuration_utils import GenerationConfig),你可以创建该对象并覆盖某些参数,也可以通过参数形式在调用.generate()时传进来
- 在后面会将传入的这些参数覆盖掉generation_config中对应的属性
下面只说明一些关键的地方 ## kwargs -> generation_config
就是将kwargs中传入的kwargs的参数变成config。
generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
准备logit处理器
= self._get_logits_processor(
prepared_logits_processor =generation_config,
generation_config=input_ids_length,
input_ids_seq_length=inputs_tensor,
encoder_input_ids=prefix_allowed_tokens_fn,
prefix_allowed_tokens_fn=logits_processor,
logits_processor=inputs_tensor.device,
device=model_kwargs,
model_kwargs=negative_prompt_ids,
negative_prompt_ids=negative_prompt_attention_mask,
negative_prompt_attention_mask )
就是将generation_config中的采样参数封装成logit-processor,还有自己定义的processor
准备stopping处理器
= self._get_stopping_criteria(
prepared_stopping_criteria =generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
generation_config )
同理。将一些与停止有关的参数封装成stopping处理器。
logits warper
logits warper
里面是采样时才需要运行的处理器logits processor
是通用的处理器,每种生成模式都需要用到的
= (
prepared_logits_warper self._get_logits_warper(generation_config) if generation_config.do_sample else None
)
正式生成
# 进入模型内部生成下一个token
= self(
outputs **model_inputs,
=True,
return_dict=output_attentions,
output_attentions=output_hidden_states,
output_hidden_states
)
if synced_gpus and this_peer_finished:
continue # don't waste resources running the code we don't need
# 取出最后一个token,.logits维度为(batch_size, seq_len, vocab_size)
= outputs.logits[:, -1, :]
next_token_logits
# 经过前面的处理器进行分数调整
= logits_processor(input_ids, next_token_logits)
next_token_scores if do_sample:
= logits_warper(input_ids, next_token_scores) next_token_scores
按照是否采样来生成下一个token:
if do_sample:
= nn.functional.softmax(next_token_scores, dim=-1)
probs # torch.multinomial:按照输入probs的每一行(每个batch)作为采样的概率,
# 每行不放回的取出num_samples个,随机采样每个batch按输入概率取出一个
= torch.multinomial(probs, num_samples=1).squeeze(1)
next_tokens else:
# torch.argmax取出输入next_token_scores中值最大的索引
= torch.argmax(next_token_scores, dim=-1) next_tokens
最后判断是否可以停止:
= unfinished_sequences & ~stopping_criteria(input_ids, scores)
unfinished_sequences = unfinished_sequences.max() == 0 this_peer_finished
参考
https://blog.csdn.net/qq_41496421/article/details/142346738?spm=1001.2014.3001.5502 https://blog.csdn.net/qq_41496421/article/details/142580960?spm=1001.2014.3001.5501