当谈到文本生成时,Transformer API是目前最受欢迎的NLP工具之一。 它提供了各种解码策略和参数,使用户可以自定义生成的文本。在本文中,我们将学习如何使用Transformer API生成文本。
基本使用
在使用Transformer API之前,需要安装PyTorch和Transformers包:
1 | $ pip install torch transformers |
完成安装后,可以使用以下代码导入所需的模块:
1 | from transformers import pipeline, set_seed |
其中pipeline
模块提供了生成文本所需的所有功能,而set_seed
允许我们设置随机种子以获得可重复的结果。
以下是一段文本生成的例子:
1 | # 设置随机种子以获得可重复的结果 |
在上述代码中,set_seed
函数设置了随机种子为42
以获得可重复的结果。pipeline
模块加载了一个文本生成器,并指定使用的模型为GPT-2
。调用generator
的方法生成文本,指定了一个起始的文本"The quick brown fox"
,限制了生成文本的最大长度为50个字符,同时指定了生成1个文本序列。最后,打印了生成的文本。
需要注意的是,文本生成是一项计算密集型任务,因此需要具有一定的计算资源。生成更长的文本,或者生成更多的文本序列,可能需要更强大的计算资源。
解码策略
Hugging Face的Transformer API提供了多种解码策略来满足不同的生成需求。
Greedy Decoding
Greedy Decoding (贪心解码) 是最简单的解码策略之一。 它在每个时间步选择概率最高的标记作为生成的标记。 可以通过在generate
函数中设置参数num_beams = 1
和do_sample = False
来使用此策略。 以下是示例代码:
1 | generator = pipeline('text-generation', model='your-model-name') |
Multinomial Sampling
Multinomial Sampling(多项式采样)解码策略是一种随机策略。 它在每个时间步根据标记的概率分布随机采样一个标记作为生成的标记。 可以通过在generate
函数中设置参数num_beams = 1
和do_sample = True
来使用此策略。 以下是示例代码:
1 | generator = pipeline('text-generation', model='your-model-name') |
Beam Search Decoding
Beam Search(束搜索)解码策略是一种广泛使用的解码策略。 它在每个时间步选择最高的k个标记,并计算每个候选标记的概率分布。 然后,它选择概率最高的k个标记作为生成的标记,并将它们作为下一个时间步的候选标记。 可以通过在generate
函数中设置参数num_beams > 1
和do_sample = False
来使用此策略。 以下是示例代码:
1 | generator = pipeline('text-generation', model='your-model-name') |
Beam Search with Multinomial Sampling
Beam Search with Multinomial Sampling(束搜索多项式采样)解码策略结合了束搜索和多项式采样两种解码策略的优点。 它在每个时间步选择最高的k个标记,并从这些标记中根据它们的概率分布随机采样一个标记作为生成的标记。 可以通过在generate
函数中设置参数num_beams > 1
和do_sample = True
来使用此策略。 以下是示例代码:
1 | generator = pipeline('text-generation', model='your-model-name') |
Contrastive Decoding
Contrastive Decoding(对比搜索)解码策略是一种在生成过程中考虑全局最优解的策略。 它在每个时间步选择概率分布最高的k个标记,并根据其频率分布计算每个候选标记的分数,考虑所有以前生成的标记。然后,它选择分数最高的标记作为生成的标记,并将其添加到先前生成的标记中。可以通过在generate
函数中设置参数penalty_alpha > 0
和top_k > 1
来使用此策略。 以下是示例代码:
1 | generator = pipeline('text-generation', model='your-model-name') |
Group Beam Search
Group Beam Search(多样束搜索)解码策略是一种使用多个束搜索进行生成的策略。 它将所有的束搜索分成多个束组,并在所有束搜索中轮流采样。可以通过在generate
函数中设置参数num_beams > 1
和num_beam_groups > 1
来使用此策略。 以下是示例代码:
1 | generator = pipeline('text-generation', model='your-model-name') |
Constrained Decoding
Constrained Decoding(约束搜索)解码策略是一种基于约束条件的生成策略。 它允许用户设置一个约束集合,这些约束集合可以是必须包含的单词或者不能包含的单词。 约束搜索可以使用beam search策略进行生成,也可以与多项式采样策略结合使用。可以通过在generate
函数中设置参数constraints != None
或force_words_ids != None
来使用此策略。 以下是示例代码:
1 | generator = pipeline('text-generation', model='your-model-name') |
解码参数
transformers.generation.GenerationConfig
用于生成文本的任务配置,用户可以根据具体的生成任务灵活配置参数,例如生成文本的最大长度、生成文本的最小长度、生成文本的随机程度、采样方式、beam搜索宽度等等。参数包括以下几种:
- 控制输出长度的参数
这些参数可以控制生成的文本或序列的长度。例如,可以设置生成文本的最大长度或最小长度。 - 控制生成策略的参数
这些参数可以控制生成文本或序列的策略,例如生成的温度或者采样方法。 - 操纵模型输出logits的参数
这些参数可以控制生成的文本或序列的质量,例如在生成过程中惩罚重复出现的单词或者降低生成文本的噪声。 - 定义
generate
的输出变量的参数
这些参数可以定义生成文本或序列的输出变量,例如生成的文本的格式或者生成的序列的标识符。 - 可以在生成时使用的特殊标记
这些参数可以在生成文本或序列时使用特殊的标记,例如起始标记或结束标记。 - 仅适用于编码器-解码器模型的生成参数
这些参数可以控制编码器-解码器模型的生成过程,例如beam search的宽度或者长度惩罚。 - 通配符
这些参数可以使用通配符来代替一些特定的值,例如使用*
代替一个单词或一个字符。
可以根据需求选择不同的参数组合来实现不同的解码策略。例如,设置 do_sample=True
、temperature=0.7
和 top_k=0
可以使用 top-p sampling 策略,生成更多的多样性文本;设置 num_beams=5
和 length_penalty=0.8
可以使用 beam search 策略,生成更流畅的文本。各解码策略与参数设置关系如下:
模式 | num_beams: int | num_beam_groups: int | do_sample: bool | temperature: float | top_k: int | top_p: float | penalty_alpha: float | length_penalty: float | repetition_penalty: float |
---|---|---|---|---|---|---|---|---|---|
greedy | 1 | 1 | F | - | - | - | - | - | - |
sample | 1 | 1 | T | > 0 | > 0 | > 0 | - | - | > 0 |
beam | > 1 | 1 | F | - | > 0 | - | - | > 0 | > 0 |
beam sample | > 1 | 1 | T | > 0 | > 0 | > 0 | - | > 0 | > 0 |
group beam | > 1 | > 1 | F | - | > 0 | - | > 0 | > 0 | > 0 |
其中,-
表示该参数在该解码策略中不适用,> 0
表示该参数必须为大于0的值。需要注意的是,表格中列出的参数不是所有可能的参数,而只是最常用的参数。如果需要使用其他参数,可以查阅相关文档。
高阶用法
LogitsProcessor
LogitsProcessor
是用于在生成文本之前处理模型生成的 logits 的基类。LogitsProcessor
可以在生成过程中修改模型的输出,以产生更好的生成结果。
在 generate
函数中,可以使用 LogitsProcessorList
类来实例化多个 LogitsProcessor
对象,以便在生成文本之前对 logits 进行多个处理;可以将 LogitsProcessorList
对象传递给 logits_processor
参数,以便在生成文本之前对 logits 进行多个处理。
以下是 LogitsProcessor
子类:
MinLengthLogitsProcessor
: 用于确保生成的文本长度达到指定的最小值。RepetitionPenaltyLogitsProcessor
: 通过对之前生成的 token 进行惩罚来减少重复的 token。NoRepeatNGramLogitsProcessor
: 用于确保生成的文本中不包含指定长度的 n-gram 重复。EncoderNoRepeatNGramLogitsProcessor
: 与NoRepeatNGramLogitsProcessor
类似,但是只考虑编码器生成的 token。NoBadWordsLogitsProcessor
: 用于过滤生成的文本中包含不良词汇的情况。PrefixConstrainedLogitsProcessor
: 用于确保生成的文本以指定的前缀开头。HammingDiversityLogitsProcessor
: 通过对生成的 token 序列之间的哈明距离进行惩罚,以增加文本的多样性。ForcedBOSTokenLogitsProcessor
: 用于确保生成的文本以指定的起始标记(例如<s>
)开头。ForcedEOSTokenLogitsProcessor
: 用于确保生成的文本以指定的结束标记(例如</s>
)结尾。InfNanRemoveLogitsProcessor
: 用于过滤生成的文本中包含NaN
或Inf
值的情况。
每个 LogitsProcessor
子类必须实现 __call__
方法,该方法接受两个参数:input_ids
和 logits。input_ids
是用于生成文本的输入序列,而 logits 是模型输出的 logits 张量。__call__
方法必须返回一个元组,其中第一个元素是修改后的 logits 张量,第二个元素是一个布尔值,指示是否应中断生成过程。如果 should_stop
为 True
,则生成过程将提前结束。
这些 LogitsProcessor
子类可以单独使用,也可以与其他 LogitsProcessor
子类一起使用。在使用 LogitsProcessor
时,需要根据生成任务和需求选择适当的子类来处理 logits,以获得更好的生成结果。
StoppingCriteria
StoppingCriteria
是一个用于控制生成过程停止的类。在文本生成任务中,由于生成文本长度不确定,因此需要设定一些停止条件,以避免生成无限长的文本,常用属性和方法为:
max_length
: 最大文本长度,超过该长度后停止生成。max_time
: 最大生成时间,超过该时间后停止生成。stop
: 布尔值,指示是否停止生成。is_done
: 布尔值,指示生成是否已完成。update
: 更新生成状态,包括生成长度和时间,并检查是否需要停止生成。
在使用 StoppingCriteria
时,可以根据生成任务和需求设定适当的停止条件。例如,在生成摘要时,可以根据原始文本的长度和要求的摘要长度来设定最大文本长度;在生成对话时,可以根据时间或者回合数来设定最大生成时间。通过合理设置停止条件,可以有效地控制生成的结果,避免无限生成或生成不满足需求的文本。
以下是各类文本生成任务中停止条件的具体实现:
MaxLengthCriteria
:根据设定的最大文本长度,在生成文本的过程中,当生成的文本长度超过设定的最大文本长度时,停止生成。MaxNewTokensCriteria
:根据设定的最大新增 token 数量,在生成文本的过程中,当生成的文本新增的 token 数量超过设定的最大新增 token 数量时,停止生成。这个停止条件更适合生成任务中需要控制每次迭代生成的长度,而不是总长度的情况。MaxTimeCriteria
:根据设定的最大生成时间,在生成文本的过程中,当生成文本的用时超过设定的最大生成时间时,停止生成。
LogitsWarper
LogitsWarper
是一个用于修正模型预测结果的类,可以在模型输出 logits 后对其进行操作,以达到一定的效果。如,可以实现以下一些常见的操作:
top_k_warp
: 对 logits 进行 top-k 截断,只保留前 k 个最大值,并将其他值设为负无穷。top_p_warp
: 对 logits 进行 top-p 截断,只保留累计概率大于等于 p 的 tokens,将其他值设为负无穷。temperature_warp
: 对 logits 进行温度缩放,调整模型的生成多样性,即通过降低温度(temperature)来减少随机性,提高预测的准确性;或者通过提高温度来增加随机性,增加生成的多样性。
在使用 LogitsWarper
时,需要根据生成任务和需求选择适当的操作方法,并设置合适的参数,以达到期望的效果。例如,在生成文本时,可以通过 top-k 截断或者 top-p 截断来控制生成的多样性和准确性;或者通过温度缩放来调整生成的多样性。
TemperatureLogitsWarper
、TopPLogitsWarper
和 TopKLogitsWarper
都是 LogitsWarper
的具体实现,分别实现了不同的操作方法。
TemperatureLogitsWarper
: 对 logits 进行温度缩放操作。温度缩放是通过调整 softmax 分布的温度参数来控制生成的多样性。当温度较高时,生成的样本将更加随机,具有更大的多样性,但可能会出现较多的错误;当温度较低时,生成的样本将更加准确,但可能缺乏多样性。TemperatureLogitsWarper 通过对 logits 进行温度缩放来实现多样性和准确性之间的平衡。TopPLogitsWarper
: 对 logits 进行 top-p 截断操作。top-p 截断是指在 softmax 分布中,保留累计概率大于等于 p 的 tokens,将其他值设为负无穷。通过调整 p 的值,可以控制生成样本的多样性和准确性。当 p 较大时,生成的样本具有更多的多样性,但可能出现较多的错误;当 p 较小时,生成的样本更加准确,但可能缺乏多样性。TopPLogitsWarper 通过对 logits 进行 top-p 截断来实现多样性和准确性之间的平衡。
TopKLogitsWarper
: 对 logits 进行 top-k 截断操作。top-k 截断是指在 softmax 分布中,保留前 k 个最大值,并将其他值设为负无穷。通过调整 k 的值,可以控制生成样本的多样性和准确性。当 k 较大时,生成的样本具有更多的多样性,但可能出现较多的错误;当 k 较小时,生成的样本更加准确,但可能缺乏多样性。TopKLogitsWarper 通过对 logits 进行 top-k 截断来实现多样性和准确性之间的平衡。
接口详情
~GenerateMixin.generate()
方法用于生成文本。它的输入参数包括:
input_ids
:一个形状为[batch_size, sequence_length]
的整数张量,表示输入序列。attention_mask
:一个形状为[batch_size, sequence_length]
的浮点数张量,表示输入序列中哪些位置是有效的。**kwargs
:其他参数,例如decoder_input_ids
、past
等,具体取决于所使用的模型。
该方法的输出为:
output
:一个形状为[batch_size, sequence_length, vocabulary_size]
的浮点数张量,表示生成的文本的概率分布。
~GenerateMixin.contrastive_search()
方法用于执行对比搜索(contrastive search)。它的输入参数包括:
input_ids
:一个形状为[batch_size, sequence_length]
的整数张量,表示输入序列。attention_mask
:一个形状为[batch_size, sequence_length]
的浮点数张量,表示输入序列中哪些位置是有效的。num_return_sequences
:一个整数,表示要返回的生成序列的数量。**kwargs
:其他参数,例如decoder_input_ids
、past
等,具体取决于所使用的模型。
该方法的输出为:
output
:一个形状为[num_return_sequences, sequence_length]
的整数张量,表示生成的文本序列。
~GenerateMixin.greedy_search()
方法用于执行贪心搜索(greedy search)。它的输入参数包括:
-
input_ids
:一个形状为[batch_size, sequence_length]
的整数张量,表示输入序列。 -
attention_mask
:一个形状为[batch_size, sequence_length]
的浮点数张量,表示输入序列中哪些位置是有效的。 -
num_return_sequences
:一个整数,表示要返回的生成序列的数量。 -
**kwargs
:其他参数,例如decoder_input_ids
、past
等,具体取决于所使用的模型。
该方法的输出为: -
output
:一个形状为[num_return_sequences, sequence_length]
的整数张量,表示生成的文本序列。
~GenerateMixin.sample()
方法用于执行随机采样(random sampling)。它的输入参数包括:
input_ids
:一个形状为[batch_size, sequence_length]
的整数张量,表示输入序列。attention_mask
:一个形状为[batch_size, sequence_length]
的浮点数张量,表示输入序列中哪些位置是有效的。num_return_sequences
:一个整数,表示要返回的生成序列的数量。**kwargs
:其他参数,例如decoder_input_ids
、past
等,具体取决于所使用的模型。
该方法的输出为:
output
:一个形状为[num_return_sequences, sequence_length]
的整数张量,表示生成的文本序列
~GenerateMixin.beam_search()
方法用于执行束搜索(beam search)。它的输入参数包括:
input_ids
:一个形状为[batch_size, sequence_length]
的整数张量,表示输入序列。attention_mask
:一个形状为[batch_size, sequence_length]
的浮点数张量,表示输入序列中哪些位置是有效的。num_return_sequences
:一个整数,表示要返回的生成序列的数量。**kwargs
:其他参数,例如decoder_input_ids
、past
等,具体取决于所使用的模型。
该方法的输出为:
output
:一个形状为[num_return_sequences, sequence_length]
的整数张量,表示生成的文本序列。
~GenerateMixin.beam_sample()
方法用于执行束采样(beam sampling)。它的输入参数包括:
input_ids
:一个形状为[batch_size, sequence_length]
的整数张量,表示输入序列。attention_mask
:一个形状为[batch_size, sequence_length]
的浮点数张量,表示输入序列中哪些位置是有效的。num_return_sequences
:一个整数,表示要返回的生成序列的数量。**kwargs
:其他参数,例如decoder_input_ids
、past
等,具体取决于所使用的模型。
该方法的输出为:
output
:一个形状为[num_return_sequences, sequence_length]
的整数张量,表示生成的文本序列。
~GenerateMixin.group_beam_search()
方法用于执行分组束搜索(group beam search)。它的输入参数包括:
input_ids
:一个形状为[batch_size, sequence_length]
的整数张量,表示输入序列。attention_mask
:一个形状为[batch_size, sequence_length]
的浮点数张量,表示输入序列中哪些位置是有效的。num_return_sequences
:一个整数,表示要返回的生成序列的数量。**kwargs
:其他参数,例如decoder_input_ids
、past
等,具体取决于所使用的模型。
该方法的输出为:
output
:一个形状为[num_return_sequences, sequence_length]
的整数张量,表示生成的文本序列。
~GenerateMixin.constrained_beam_search()
方法用于执行约束束搜索(constrained beam search)。它的输入参数包括:
input_ids
:一个形状为[batch_size, sequence_length]
的整数张量,表示输入序列。attention_mask
:一个形状为[batch_size, sequence_length]
的浮点数张量,表示输入序列中哪些位置是有效的。constraints
:一个列表,其中每个元素都是一个形状为[batch_size, sequence_length]
的整数张量,表示相应位置的限制条件。num_return_sequences
:一个整数,表示要返回的生成序列的数量。**kwargs
:其他参数,例如decoder_input_ids
、past
等,具体取决于所使用的模型。
该方法的输出为:
output
:一个形状为[num_return_sequences, sequence_length]
的整数张量,表示生成的文本序列。