生成式AI文本密碼:Transformer參數(shù)全解碼? 原創(chuàng)
本文詳細(xì)介紹Transformer模型中控制文本生成的關(guān)鍵參數(shù),包括溫度、Top-K和Top-P采樣、重復(fù)懲罰等,并探討這些參數(shù)對(duì)生成文本質(zhì)量的影響及針對(duì)不同應(yīng)用的調(diào)整方法。
Transformer模型是當(dāng)今NLP任務(wù)的標(biāo)準(zhǔn)模型。幾乎所有NLP任務(wù)都涉及文本生成,但文本生成并非模型的直接輸出。你可能希望模型能夠幫助你生成連貫且與上下文相關(guān)的文本。雖然這在一定程度上與模型的質(zhì)量有關(guān),但生成參數(shù)也對(duì)生成文本的質(zhì)量起著至關(guān)重要的作用。
在本文中,讓我們來(lái)一起探索控制Transformer模型中文本生成的關(guān)鍵參數(shù)。你將了解這些參數(shù)如何影響生成文本的質(zhì)量,以及如何針對(duì)不同的應(yīng)用進(jìn)行調(diào)整。具體而言,你將學(xué)習(xí)到:
- Transformer模型中控制文本生成的核心參數(shù)?
- 不同的解碼策略?
- 如何控制生成文本的創(chuàng)造性和連貫性?
- 如何針對(duì)特定應(yīng)用微調(diào)生成參數(shù)?
讓我們開始吧!
概述
本文將劃分為七個(gè)部分進(jìn)行介紹,它們是:
- 核心文本生成參數(shù)?
- 溫度實(shí)驗(yàn)?
- Top-K和Top-P采樣?
- 控制重復(fù)?
- 貪婪解碼和采樣?
- 特定應(yīng)用的參數(shù)?
- 集束搜索和多序列生成
核心文本生成參數(shù)
我們以GPT-2模型為例。它是一個(gè)小型Transformer模型,不需要大量計(jì)算資源,但仍能生成高質(zhì)量的文本。使用GPT-2模型生成文本的一個(gè)簡(jiǎn)單示例如下:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
#創(chuàng)建模型和分詞器
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
#將輸入提示分詞為ID序列
prompt = "Artificial intelligence is"
inputs = tokenizer(prompt, return_tensors="pt")
# 將輸出作為一系列標(biāo)記ID生成
output = model.generate(
**inputs,
max_length=50,
num_return_sequences=1,
temperature=1.0,
top_k=50,
top_p=1.0,
repetition_penalty=1.0,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
)
#將標(biāo)記ID轉(zhuǎn)換為文本字符串
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(f"Prompt: {prompt}")
print("Generated Text:")
print(generated_text)
如果運(yùn)行此代碼,你可能會(huì)看到如下輸出內(nèi)容:
Prompt: Artificial intelligence is
Generated Text:
Artificial intelligence is used in the production of technology, the delivery of
which is determined by technological change. For example, an autonomous car can
change its steering wheel to help avoid driving traffic. In the case of artificial
intelligence, this can change what consumers
本例中,你只提供了三個(gè)單詞的提示,模型就生成了一段很長(zhǎng)的文本。這并非一次性生成,而是在迭代過(guò)程中多次調(diào)用模型。
你可以看到generate()函數(shù)中使用的眾多參數(shù)。你使用的第一個(gè)參數(shù)是max_length,它控制生成的文本的長(zhǎng)度(以標(biāo)記數(shù)量表示)。通常,模型使用提示作為上下文,一次生成一個(gè)標(biāo)記。然后,將新生成的標(biāo)記附加到提示中并生成下一個(gè)標(biāo)記。因此,你希望生成的文本越長(zhǎng),生成它所需的時(shí)間就越長(zhǎng)。請(qǐng)注意,這里關(guān)注的是標(biāo)記,而不是單詞,因?yàn)槟阍贕PT-2模型中使用了子詞標(biāo)記器。一個(gè)標(biāo)記可能只是一個(gè)子詞單元,而不是一個(gè)完整的單詞。
然而,該模型并非專門生成任何單個(gè)標(biāo)記。相反,它生成一個(gè)“l(fā)ogit”,即下一個(gè)標(biāo)記概率的向量。logit是一個(gè)長(zhǎng)向量,恰好與詞匯表的大小相同。鑒于它是所有可能的“下一個(gè)標(biāo)記”的概率分布,你可以選擇概率最高的標(biāo)記(當(dāng)設(shè)置do_sample=False時(shí)),或者任何其他概率非零的標(biāo)記(當(dāng)設(shè)置do_sample=True時(shí))。這就是所有其他參數(shù)的目的。
temperature參數(shù)會(huì)扭曲概率分布。較低的溫度會(huì)強(qiáng)調(diào)最可能的標(biāo)記,而較高的溫度會(huì)縮小可能的標(biāo)記和不太可能的標(biāo)記之間的差異。默認(rèn)溫度為1.0,并且應(yīng)為正值。然后,top_k參數(shù)僅選擇最靠前的標(biāo)記標(biāo)記,而不是整個(gè)標(biāo)記詞匯表。然后重新計(jì)算概率,總和為1。接下來(lái),如果設(shè)置了top_p,則這一組k個(gè)標(biāo)記的集合進(jìn)一步過(guò)濾,保留構(gòu)成總概率p的那些頂級(jí)標(biāo)記。然后使用這組最終的標(biāo)記來(lái)對(duì)下一個(gè)標(biāo)記進(jìn)行采樣,這個(gè)過(guò)程稱為核采樣。
請(qǐng)記住,你正在生成一個(gè)標(biāo)記序列,一次一個(gè)。你很可能會(huì)在每一步中重復(fù)看到相同的標(biāo)記,并且你可能會(huì)在序列中看到相同的標(biāo)記。這通常不是你想要的結(jié)果,因此你可能希望在再次看到這些標(biāo)記時(shí)降低其出現(xiàn)的概率。這就是上面repetition_penalty參數(shù)的作用所在。
溫度實(shí)驗(yàn)
假設(shè)到目前你已經(jīng)知道了各個(gè)參數(shù)的作用,那么接下來(lái),讓我們看看當(dāng)你調(diào)整其中一些參數(shù)時(shí)輸出如何變化。
溫度參數(shù)對(duì)生成文本的創(chuàng)造性和隨機(jī)性有顯著的影響。你可以通過(guò)以下示例看到其效果:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
prompt = "The future of artificial intelligence is"
inputs = tokenizer(prompt, return_tensors="pt")
# 生成不同溫度值的文本
temperatures = [0.2, 0.5, 1.0, 1.5]
print(f"Prompt: {prompt}")
for temp in temperatures:
print()
print(f"Temperature: {temp}")
output = model.generate(
**inputs,
max_length=100,
num_return_sequences=1,
temperature=temp,
top_k=50,
top_p=1.0,
repetition_penalty=1.0,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print("Generated Text:")
print(generated_text)
如果運(yùn)行此代碼,你可能會(huì)看到如下輸出內(nèi)容:
Prompt: The future of artificial intelligence is
Temperature: 0.2
Generated Text:
The future of artificial intelligence is uncertain. The future of artificial
intelligence is uncertain.
The future of artificial intelligence is uncertain. The future of artificial
intelligence is uncertain.
The future of artificial intelligence is uncertain. The future of artificial
intelligence is uncertain.
The future of artificial intelligence is uncertain. The future of artificial
intelligence is uncertain.
The future of artificial intelligence is uncertain. The future of artificial
intelligence is uncertain.
The future of artificial intelligence is uncertain. The future
Temperature: 0.5
Generated Text:
The future of artificial intelligence is uncertain.
"There is a lot of work to be done on this," said Eric Schmitt, a professor
of computer science and engineering at the University of California, Berkeley.
"We're looking for a way to make AI more like computers. We need to take a step
back and look at how we think about it and how we interact with it."
Schmitt said he's confident that artificial intelligence will eventually be
able to do more than
Temperature: 1.0
Generated Text:
The future of artificial intelligence is not yet clear, however."
"Is the process that we are trying to do through computer vision and the ability to
look at a person at multiple points without any loss of intelligence due to not
seeing a person at multiple points?" asked Richard. "I also think the people who
are doing this research are extremely interesting to me due to being able to see
humans at a range of different points in time. In particular, they've shown how
to do a pretty complex
Temperature: 1.5
Generated Text:
The future of artificial intelligence is an era to remember as much as Google in
search results, particularly ones not supported by much else for some years -- and
it might look like the search giant is now just as good without artificial
intelligence. [Graphic image from Shutterstock]
當(dāng)溫度較低(例如0.2)時(shí),文本會(huì)變得更加集中和確定,通常會(huì)堅(jiān)持使用常用短語(yǔ)和傳統(tǒng)觀點(diǎn)。你還會(huì)看到,由于概率集中在少數(shù)幾個(gè)標(biāo)記上,文本會(huì)不斷重復(fù)相同的句子,從而限制了多樣性。這個(gè)問(wèn)題可以通過(guò)使用重復(fù)懲罰參數(shù)來(lái)解決,該參數(shù)將在下一節(jié)中介紹。
中等溫度(例如0.5到1.0)的文本在連貫性和創(chuàng)造性之間取得了良好的平衡。生成的文本可能并非基于事實(shí),但語(yǔ)言自然。
當(dāng)溫度較高(例如1.5)時(shí),文本會(huì)變得更加隨意和富有創(chuàng)意,但也可能變得缺乏連貫性,有時(shí)甚至缺乏邏輯性。語(yǔ)言可能難以理解,就像上面的例子一樣。
選擇合適的溫度取決于你的應(yīng)用。如果你正在創(chuàng)建代碼補(bǔ)全或?qū)懽髦?,通常較低的溫度更佳。對(duì)于創(chuàng)意寫作或頭腦風(fēng)暴,較高的溫度可以產(chǎn)生更多樣化、更有趣的結(jié)果。
Top-K和Top-P采樣
核采樣參數(shù)控制著模型選擇下一個(gè)標(biāo)記的靈活性。你應(yīng)該調(diào)整top_k參數(shù)還是top_p參數(shù)?讓我們通過(guò)一個(gè)例子來(lái)看一下它們的效果:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
prompt = "The best way to learn programming is"
inputs = tokenizer(prompt, return_tensors="pt")
#使用不同top_k值生成文本
top_k_values = [5, 20, 50]
print(f"Prompt: {prompt}")
for top_k in top_k_values:
print()
print(f"Top-K = {top_k}")
output = model.generate(
**inputs,
max_length=100,
num_return_sequences=1,
temperature=1.0,
top_k=top_k,
top_p=1.0,
repetition_penalty=1.0,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print("Generated Text:")
print(generated_text)
# 使用不同top_p值生成文本
top_p_values = [0.5, 0.7, 0.9]
for top_p in top_p_values:
print()
print(f"Top-P = {top_p}")
output = model.generate(
**inputs,
max_length=100,
num_return_sequences=1,
temperature=1.0,
top_k=0,
top_p=top_p,
repetition_penalty=1.0,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print("Generated Text:")
print(generated_text)
如果運(yùn)行此代碼,你可能會(huì)看到如下輸出內(nèi)容:
Prompt: The best way to learn programming is
Top-K = 5
Generated Text:
The best way to learn programming is to be able to learn the basics in a very short
amount of time, and then learn to use them effectively and quickly.
If you want to be a successful programmer in this way, you should learn to use the
techniques in the above video to learn the basics of programming.
If you want to learn to code more effectively, you can also get more experienced
programmers by doing the following:
Learning to Code
Learning to code is very
Top-K = 20
Generated Text:
The best way to learn programming is to learn it.
In order to get started with Ruby you're going to have to make a few mistakes, some
of them can be fairly obvious.
First of all, you're going to have to write a function that takes in a value. What
this means is that you're going to make a new instance of the Ruby function. You can
read more about this in Part 1 of this course, or just try it out from the REPL.
Top-K = 50
Generated Text:
The best way to learn programming is to become familiar with the language and the
software. One of the first and most common forms of programming is to create,
modify, and distribute code.
However, there are very few programming libraries that can provide us with all
that we need.
The following sample programming program uses some of the above, but does not show
the best way to learn programming. It was written in Java and in C or C++.
The original source code is
Top-P = 0.5
Generated Text:
The best way to learn programming is to be able to create a tool for you. That's
what I do.
That's why I'm here today.
I'm here to talk about the basics of programming, and I'm going to tell you how to
learn programming.
I'm here to talk about learning programming.
It's easy to forget that you don't have to know how to program. It's easy to forget
that you don't have to know how
Top-P = 0.7
Generated Text:
The best way to learn programming is to practice programming. Learn the principles
of programming by observing and performing exercises.
I used to work in a world of knowledge which included all sorts of things, and was
able to catch up on them and understand them from their perspective. For instance, I
learned to sit up straight and do five squats. Then, I would have to practice some
type of overhead training. I would try to learn the best technique and add that to
my repertoire.
What
Top-P = 0.9
Generated Text:
The best way to learn programming is to become a good hacker. Don't use any
programming tools. Just a regular dot-com user, an occasional coding learner, and
stick with it.
— Victoria E. Nichols
你可以通過(guò)一個(gè)小的k值,例如5,看到模型可供選擇的選項(xiàng)較少,從而導(dǎo)致文本更可預(yù)測(cè)。在極端情況下,當(dāng)k=1時(shí),模型總是選擇概率最高的單個(gè)標(biāo)記,這是貪婪解碼,通常會(huì)產(chǎn)生較差的輸出。當(dāng)使用一個(gè)較大的k值,比如50,模型就有更多的選項(xiàng)可以選擇,從而產(chǎn)生更加多樣化的文本。
類似地,對(duì)于top_p參數(shù),較小的p值意味著模型從一組較小的高概率標(biāo)記中進(jìn)行選擇,從而產(chǎn)生更有針對(duì)性的文本。使用較大的p值,例如0.9,則模型的選擇范圍更廣,可能會(huì)產(chǎn)生更多樣化的文本。但是,對(duì)于給定的文本,你可以選擇多少個(gè)選項(xiàng)并非固定不變,它取決于模型預(yù)測(cè)的概率分布。當(dāng)模型對(duì)下一個(gè)標(biāo)記非常有信心時(shí)(例如受某些語(yǔ)法規(guī)則限制),只允許使用非常小的標(biāo)記集合。這種自適應(yīng)特性也是為什么top-p采樣通常比top-k采樣更受歡迎的原因。
控制重復(fù)
重復(fù)是文本生成中常見的問(wèn)題。repetition_penalty參數(shù)通過(guò)懲罰已在生成文本中出現(xiàn)過(guò)的標(biāo)記來(lái)幫助解決這個(gè)問(wèn)題。讓我們看看它是如何工作的:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
prompt = "Once upon a time, there was a"
inputs = tokenizer(prompt, return_tensors="pt")
# 使用不同的重復(fù)懲罰生成文本
penalties = [1.0, 1.2, 1.5, 2.0]
print(f"Prompt: {prompt}")
for penalty in penalties:
print()
print(f"Repetition penalty: {penalty}")
output = model.generate(
**inputs,
max_length=100,
num_return_sequences=1,
temperature=0.3,
top_k=50,
top_p=1.0,
repetition_penalty=penalty,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print("Generated Text:")
print(generated_text)
如果運(yùn)行此代碼,你可能會(huì)看到如下輸出內(nèi)容:
Prompt: Once upon a time, there was a
Repetition penalty: 1.0
Generated Text:
Once upon a time, there was a great deal of confusion about what was going on. The
first thing that came to mind was the fact that the government had already been in
place for a long time, and that the government had been in place for a long time.
And it was clear that the government had been in place for a long time. And it was
clear that the government had been in place for a long time. And it was clear that
the government had been in place for a long
Repetition penalty: 1.2
Generated Text:
Once upon a time, there was a great deal of talk about the possibility that this
would be an opportunity for us to see more and better things in our lives. We had
been talking on Facebook all day long with people who were interested in what we
could do next or how they might help others find their own way out."
"We've always wanted to make sure everyone has access," he continued; "but it's not
like you can just go into your room at night looking around without seeing
Repetition penalty: 1.5
Generated Text:
Once upon a time, there was a man who had been called to the service of God. He
came and said: "I am an apostle from Jerusalem." And he answered him with great joy,
saying that it is not possible for me now in this life without having received
Jesus Christ as our Lord; but I will be saved through Him alone because my Father
has sent Me into all things by His Holy Spirit (John 1).
The Christian Church teaches us how much more than any other religion can
Repetition penalty: 2.0
Generated Text:
Once upon a time, there was a man who had been sent to the city of Nausicaa by his
father. The king's son and brother were killed in battle at that place; but when
he returned with them they found him dead on their way back from war-time.[1]
The King gave orders for an expedition against this strange creature called "the
Gorgon," which came out into space during one night after it attacked Earth[2]. It
is said that these creatures
在上面的代碼中,為了強(qiáng)調(diào)重復(fù)懲罰的效果,我們將溫度設(shè)置為0.3。當(dāng)懲罰值較低(例如1.0)時(shí),你可以看到模型一遍又一遍地重復(fù)同一個(gè)短語(yǔ)。當(dāng)其他設(shè)置將候選標(biāo)記限制在較小的子集時(shí),模型很容易陷入循環(huán)。但是,當(dāng)懲罰值較高(例如2.0或更高)時(shí),模型會(huì)強(qiáng)烈避免重復(fù),這有時(shí)會(huì)導(dǎo)致文本的自然性降低。中等懲罰值(例如1.2到1.5)通常是保持連貫性的良好折衷方案。
畢竟,generate()函數(shù)中設(shè)置的參數(shù)是為了保持文本自然流暢。你可能需要通過(guò)實(shí)驗(yàn)來(lái)調(diào)整這些參數(shù),以找到最適合你特定應(yīng)用的參數(shù)。請(qǐng)注意,這些參數(shù)可能取決于你使用的模型,因?yàn)槊總€(gè)模型生成的標(biāo)記可能具有不同的分布。
貪婪解碼和采樣
do_sample參數(shù)控制模型是使用采樣(基于概率選擇標(biāo)記)還是貪婪解碼(始終選擇最可能的標(biāo)記)。讓我們比較一下這兩種方法:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
prompt = "The secret to happiness is"
inputs = tokenizer(prompt, return_tensors="pt")
# 使用貪婪解碼與采樣生成文本
print(f"Prompt: {prompt}\n")
print("Greedy Decoding (do_sample=False):")
output = model.generate(
**inputs,
max_length=100,
num_return_sequences=1,
temperature=1.0,
top_k=50,
top_p=1.0,
repetition_penalty=1.0,
do_sample=False,
pad_token_id=tokenizer.eos_token_id,
)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print("Generated Text:")
print(generated_text)
print()
print("Sampling (do_sample=True):")
output = model.generate(
**inputs,
max_length=100,
num_return_sequences=1,
temperature=1.0,
top_k=50,
top_p=1.0,
repetition_penalty=1.0,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print("Generated Text:")
print(generated_text)
嘗試多次運(yùn)行此代碼并觀察輸出結(jié)果。你會(huì)注意到,貪婪解碼的輸出始終相同,而采樣的輸出每次都不同。對(duì)于固定的提示,貪婪解碼是確定性的。該模型生成概率分布,并選擇最可能的標(biāo)記,不涉及隨機(jī)性,輸出更有可能重復(fù)且無(wú)用。
采樣輸出是隨機(jī)的,因?yàn)檩敵鰳?biāo)記是根據(jù)模型預(yù)測(cè)的概率分布選擇的。這種隨機(jī)性使模型能夠生成更加多樣化和富有創(chuàng)意的文本;同時(shí),只要其他生成參數(shù)設(shè)置得當(dāng),輸出仍然保持一致。在采樣輸出的情況下,你可以將num_return_sequences設(shè)置為大于1的數(shù)字,以便為同一提示并行生成多個(gè)序列。此參數(shù)對(duì)于貪婪解碼毫無(wú)意義。
特定應(yīng)用的參數(shù)
對(duì)于特定的應(yīng)用,應(yīng)該設(shè)置哪些參數(shù)值?并沒有明確的答案。你肯定需要進(jìn)行一些實(shí)驗(yàn)來(lái)找到最佳組合。但是,你可以參考以下建議:
- 事實(shí)生成:
?A.提供更低的temperature參數(shù)值(0.2至0.4)以獲得更確定的輸出
B.使用中等大小的top_p參數(shù)值(0.8到0.9),過(guò)濾掉不太可能的標(biāo)記?
C.使用更高的repetition_penalty參數(shù)值(1.2至1.5),以避免重復(fù)陳述?
- 創(chuàng)意寫作:
?A.提供更高一些的temperature參數(shù)值(1.0到1.3),可實(shí)現(xiàn)更具創(chuàng)意和多樣化的輸出
B.提供更高的top_p參數(shù)值(0.9到0.95),以提供更多可能性
C.提供較低的repetition_penalty參數(shù)值(1.0到1.1),以允許一些風(fēng)格重復(fù)?
- 代碼生成:
?A.提供更低的temperature參數(shù)值(0.1到0.3),可獲得更精確、更正確的代碼
B.提供較低的top_p參數(shù)值(0.7至0.8),以關(guān)注最可能的標(biāo)記?
C.提供更高的repetition_penalty參數(shù)值(1.3到1.5),以避免冗余代碼?
- 對(duì)話生成:?
A.提供中等大小的temperature參數(shù)值(0.6至0.8),反應(yīng)自然但集中
B.提供中等大小的top_p參數(shù)值(0.9),創(chuàng)造力和連貫性達(dá)到良好平衡
C.提供中等大小的repetition_penalty參數(shù)值(1.2),避免重復(fù)的短語(yǔ)
請(qǐng)記住,語(yǔ)言模型并非完美的預(yù)言機(jī),它也可能會(huì)出錯(cuò)。上述參數(shù)旨在幫助你將生成過(guò)程與預(yù)期的輸出風(fēng)格相匹配,但并不能保證其正確性。你得到的輸出可能包含錯(cuò)誤。
集束搜索和多序列生成
在上面的例子中,生成過(guò)程是自回歸的。它是一個(gè)迭代過(guò)程,每次生成一個(gè)標(biāo)記。
由于每個(gè)步驟都會(huì)通過(guò)采樣生成一個(gè)標(biāo)記,因此你可以同時(shí)生成多個(gè)標(biāo)記。這樣一來(lái),你將為一個(gè)輸入提示生成多個(gè)輸出序列。理論上,如果你每一步生成k個(gè)標(biāo)記,并且設(shè)置返回的長(zhǎng)度為n,你將生成kn個(gè)序列。這個(gè)數(shù)字可能很大,你可能希望將其限制為幾個(gè)。
生成多個(gè)序列的第一種方法是設(shè)置num_return_sequences為數(shù)字k。你在第一步中生成k個(gè)標(biāo)記。然后,完成每個(gè)標(biāo)記的序列。這基本上確定了在生成中復(fù)制了提示k次。
第二種方法是使用集束搜索。這是一種生成多個(gè)序列的更復(fù)雜的方法。它會(huì)跟蹤最有希望的序列并并行探索它們。它不是生成kn個(gè)序列以淹沒記憶,它只保留每一步的最佳序列。每個(gè)標(biāo)記生成步驟都會(huì)暫時(shí)擴(kuò)展這個(gè)集合,然后將其修剪回最佳序列。
要使用集束搜索,你需要設(shè)置num_beams為一個(gè)數(shù)字k。每一步都會(huì)擴(kuò)大k個(gè)序列以再添加一個(gè)標(biāo)記,結(jié)果生成k2個(gè)序列,然后選擇最佳k個(gè)序列繼續(xù)下一步。你還可以通過(guò)設(shè)置early_stopping=True,以便在到達(dá)序列末尾時(shí)停止生成。你還應(yīng)該設(shè)置num_return_sequences在輸出時(shí)限制最終選擇。
序列的選擇通常基于序列中標(biāo)記的累積概率。但你也可以通過(guò)其他標(biāo)準(zhǔn)來(lái)調(diào)整選擇,例如添加長(zhǎng)度懲罰或避免重復(fù)n-grams。以下是使用集束搜索的示例:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
prompt = "The key to successful machine learning is"
inputs = tokenizer(prompt, return_tensors="pt")
#使用貪婪解碼與采樣生成文本
print(f"Prompt: {prompt}\n")
outputs = model.generate(
**inputs,
num_beams=5, # 要使用的光束數(shù)量
early_stopping=True, # 當(dāng)所有光束都完成時(shí)停止
no_repeat_ngram_size=2, # 避免重復(fù)n-gram
num_return_sequences=3, # 返回多個(gè)序列
max_length=100,
temperature=1.5,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
)
for idx, output in enumerate(outputs):
generated_text = tokenizer.decode(output, skip_special_tokens=True)
print(f"Generated Text ({idx+1}):")
print(generated_text)
你可以添加更多生成參數(shù)(例如length_penalty)來(lái)控制生成過(guò)程。上面的示例設(shè)置了更高的溫度,以突出集束搜索的輸出。運(yùn)行此代碼,你可能會(huì)看到:
Prompt: The key to successful machine learning is
Generated Text (1):
The key to successful machine learning is to be able to learn from the world around
you. It is our job to make sure that we are learning from people, rather than just
from machines.
So, let's take a step back and look at how we can learn. Here's a list of the tools
we use to help us do that. We're going to go over a few of them here and give you
a general idea of what they are and how you can use them to create
Generated Text (2):
The key to successful machine learning is to be able to learn from the world around
you. It is our job to make sure that we are learning from people, rather than just
from machines.
So, let's take a step back and look at how we can learn. Here's a list of the tools
we use to help us do that. We're going to go over a few of them here and give you
a general idea of what they are and how you can use them and what
Generated Text (3):
The key to successful machine learning is to be able to learn from the world around
you. It is our job to make sure that we are learning from people, rather than just
from machines.
So, let's take a step back and look at how we can learn. Here's a list of the tools
we use to help us do that. We're going to go over a few of them here and give you
a general idea of what they are and how they work. You can use
輸出序列的數(shù)量仍然受num_return_sequences控制,但生成序列的過(guò)程使用了集束搜索算法。不過(guò),從輸出結(jié)果很難判斷是否使用了集束搜索。一個(gè)跡象是,集束搜索的輸出不像單純的設(shè)置num_return_sequences那樣具有多樣性,因?yàn)樯傻男蛄懈嗖⑶疫x擇了累積概率更高的序列。這種過(guò)濾確實(shí)降低了輸出的多樣性。
進(jìn)一步閱讀
以下是一些你可能覺得有用的補(bǔ)充閱讀材料:
總結(jié)
在本文中,你了解了如何使用generate()函數(shù)中的眾多參數(shù)來(lái)控制生成過(guò)程。你可以調(diào)整這些參數(shù),使輸出符合你應(yīng)用程序的預(yù)期樣式。具體來(lái)說(shuō),你學(xué)習(xí)了:
- 如何利用溫度來(lái)控制輸出的概率分布?
- 如何使用top-k和top-p來(lái)控制輸出的多樣性?
- 如何使用重復(fù)懲罰、集束搜索和貪婪解碼來(lái)控制輸出?
通過(guò)理解和調(diào)整這些參數(shù),你可以優(yōu)化不同應(yīng)用的文本生成,從事實(shí)寫作到創(chuàng)意敘事、代碼生成和對(duì)話系統(tǒng)等各個(gè)領(lǐng)域。
譯者介紹
朱先忠,51CTO社區(qū)編輯,51CTO專家博客、講師,濰坊一所高校計(jì)算機(jī)教師,自由編程界老兵一枚。
原文標(biāo)題:??Understanding Text Generation Parameters in Transformers??,作者:Muhammad Asad Iqbal Khan
