MDQA 知識圖譜提示用于多文檔問答
論文閱讀
該論文提出了一種知識圖譜提示(KGP)方法,以構(gòu)建正確的上下文來提示LLMs進行MD-QA,該方法包括一個圖構(gòu)建模塊和一個圖遍歷模塊。在圖構(gòu)建方面,創(chuàng)建了一個跨越多文檔的知識圖譜(KG),邊表示段落或文檔結(jié)構(gòu)之間的語義/詞匯相似性。在圖遍歷方面,我們設(shè)計了一個基于LLMs的圖遍歷代理,該代理在節(jié)點間導(dǎo)航并收集支持性段落,以幫助LLMs進行MD-QA。所構(gòu)建的圖作為全局規(guī)則器,調(diào)節(jié)段落之間的過渡空間并減少檢索延遲。同時,圖遍歷代理充當(dāng)一個本地導(dǎo)航器,收集相關(guān)上下文以逐步接近問題并保證檢索質(zhì)量。
我們平常做RAG文本召回的時候,也不會只針對一個文檔做召回,本質(zhì)上也是多文檔的召回。該文章在傳統(tǒng)的RAG召回的基礎(chǔ)之上,增加了文章、段落節(jié)點。在每個段落之間添加了邊,從而實現(xiàn)一種遞歸的文本召回(找到一個與問題相似的段落節(jié)點后,在該段落節(jié)點的鄰接的節(jié)點,也進行相似查找)。如下圖右側(cè)所示,一篇文章上面所有內(nèi)容,包括表格、段落等都掛在到一個文章節(jié)點上。(以前我也有過這樣的想法,也做了文章結(jié)構(gòu)的知識圖譜,但沒有找到可以講故事的地方)。下圖右側(cè)的段落節(jié)點之間的邊,代表這兩個節(jié)點很相似。
段落之間用相似度構(gòu)建邊,做成可視化,呈現(xiàn)給用戶一種直觀的感覺是可以的。但是他們把這種加入到召回文本中,讓大模型去回答,我個人認為這里不一定能夠提升效果。因為他們對文本召回的檢索器進行了微調(diào),所以模型的效果肯定好,他們應(yīng)該要做一個段落臨接節(jié)點的消融實驗,證明在段落節(jié)點之間添加相似邊是有效的。
實驗部分:
在這篇文章的源碼中,可以學(xué)到數(shù)據(jù)集的構(gòu)建,KNN、TF-IDF、BM25等這些檢索器的使用。
該論文沒有給出召回率方面的評估結(jié)果,直接給出最終的結(jié)果。他們評估大模型回答問題答案的效果,采用的是大模型打分的方法,提示詞如下:
def prompt_eval():
eval_prompt = """You are an expert professor specialized in grading whether the prediction to the question is correct or not according to the real answer.
==================
For example:
==================
Question: What company owns the property of Marvel Comics?
Answer: The Walt Disney Company
Prediction: The Walt Disney Company
Return: 1
==================
Question: Which constituent college of the University of Oxford endows four professorial fellowships for sciences including chemistry and pure mathematics?
Answer: Magdalen College
Prediction: Magdalen College.
Return: 1
==================
Question: Which year was Marvel started?
Answer: 1939
Prediction: 1200
Return: 0
==================
You are grading the following question:
Question: {question}
Answer: {answer}
Prediction: {prediction}
If the prediction is correct according to answer, return 1. Otherwise, return 0.
Return: your reply can only be one number '0' or '1'
"""
return eval_prompt
If the prediction is correct according to answer, return 1. Otherwise, return 0.
把大模型生成的答案與真實的答案一起提交給評估的模型,如果預(yù)測的結(jié)果是對的返回1,預(yù)測結(jié)果不對返回0。
評估結(jié)果的測試腳本 ??Pipeline/evaluation/eval.ipynb?
?:
代碼解析
圖譜構(gòu)建
??Data-Collect/graph_construct.py?
?
def knn_graph(i_d, k_knn, embs, strategy='cos'):
idx, d = i_d
emb = embs[idx]
# build a knn Graph
if strategy == 'cos':
sim = cosine_similarity(emb, emb)
elif strategy == 'dp':
sim = np.matmul(emb, emb.transpose(1, 0))
# topk
top_idx = np.argsort(-sim, axis=1)[:, 1:k_knn + 1]
tail_nodes = np.arange(top_idx.shape[0]).repeat(k_knn) # flatten
head_nodes = top_idx.reshape(-1)
edges = [(node1, node2) for node1, node2 in zip(tail_nodes, head_nodes)]
G = nx.DiGraph()
G.add_edges_from(edges)
return idx, G
上述代碼實現(xiàn)了,兩個節(jié)點根據(jù)它倆之間向量相似度構(gòu)建邊。
檢索器微調(diào)
主要關(guān)注 橋接問題,因為比較問題不需要關(guān)注順序,先召回哪一個文本都行。針對橋接問題首先需要能夠?qū)召回S1,然后再對 Q+S1 能夠召回S2。相對傳統(tǒng)的檢索器微調(diào)需要增加Q+S1能夠?qū)W會召回S2的過程。所以這一點,在下述的數(shù)據(jù)集構(gòu)造中多了??q1_c1_enc?
??,在損失值的計算中多了 ??loss_fct(scores_2, target_2)?
?。
數(shù)據(jù)集:
- q_enc: 問題的嵌入向量
- q_c1: 問題+第一個文本的嵌入向量
- c1_enc、c2_enc:真實的第一個文本與第二個文本
- n1_enc、n2_enc:從負樣本中隨機篩選出的兩個負樣本
損失函數(shù):
def mp_loss(model, batch):
embs = model(batch)
loss_fct = CrossEntropyLoss(ignore_index = -1)
c_embs = torch.cat([embs["c1_emb"], embs["c2_emb"]], dim = 0) # 2B x d
n_embs = torch.cat([embs["n1_emb"].unsqueeze(1), embs["n2_emb"].unsqueeze(1)], dim = 1) # B*2*M*h
scores_1 = torch.mm(embs["q_emb"], c_embs.t()) # B x 2B
n_scores_1 = torch.bmm(embs["q_emb"].unsqueeze(1), n_embs.permute(0, 2, 1)).squeeze(1) # B x 2B
scores_2 = torch.mm(embs["q_c1_emb"], c_embs.t()) # B x 2B
n_scores_2 = torch.bmm(embs["q_c1_emb"].unsqueeze(1), n_embs.permute(0, 2, 1)).squeeze(1) # B x 2B
# mask the 1st hop
bsize = embs["q_emb"].size(0)
scores_1_mask = torch.cat([torch.zeros(bsize, bsize), torch.eye(bsize)], dim=1).to(embs["q_emb"].device)
scores_1 = scores_1.float().masked_fill(scores_1_mask.bool(), float('-inf')).type_as(scores_1)
scores_1 = torch.cat([scores_1, n_scores_1], dim=1)
scores_2 = torch.cat([scores_2, n_scores_2], dim=1)
target_1 = torch.arange(embs["q_emb"].size(0)).to(embs["q_emb"].device)
target_2 = torch.arange(embs["q_emb"].size(0)).to(embs["q_emb"].device) + embs["q_emb"].size(0)
loss = loss_fct(scores_1, target_1) + loss_fct(scores_2, target_2)
return loss
- loss_fct(scores_1, target_1): 模型學(xué)會通過 Q 召回S1;
- loss_fct(scores_2, target_2):模型學(xué)會通過 Q+S1 能夠召回S2;
上述的損失函數(shù)寫的挺復(fù)雜的,如果第一次看到這種檢索器的損失函數(shù),應(yīng)該會有很多同學(xué)看不懂。
關(guān)于檢索器微調(diào)損失值:這里的損失函數(shù)是 CrossEntropyLoss 與分類挺像的,把問題的向量與相關(guān)文本做乘法,得到的是問題的向量與相關(guān)文本的相似度的值。兩個向量做乘法得到的是這兩個向量相似度。 這個損失函數(shù)的就是讓正確文本對應(yīng)的相似度的值足夠大,損失值才會小。
如果BGE檢索器的微調(diào)還不熟悉的話,也不用硬看上述代碼,時間充裕的話,可以先看懂BGE檢索器微調(diào)。transformers二次開發(fā)——(定義自己的數(shù)據(jù)加載器 模型 訓(xùn)練器)bge模型微調(diào)流程 這是一個B站的視頻講解的BGE微調(diào)的,但是該視頻有一點遺憾的地方,在關(guān)鍵的損失值計算部分,該UP主講解錯,后來他也在評論區(qū)進行了回應(yīng)。如果大家想深入了解BGE微調(diào),進入 https://github.com/FlagOpen/FlagEmbedding 倉庫,找到23年10月的版本(新版本代碼太多了,舊版本代碼很簡潔),一步一步debug,后面自然就會懂。
為了防止我以后忘記,簡單寫幾句:
??scores_1 = torch.mm(embs["q_emb"], c_embs.t())?
? 把問題的向量與所有候選文本的向量做一個乘法。
??scores_1_mask = torch.cat([torch.zeros(bsize, bsize), torch.eye(bsize)], dim=1).to(embs["q_emb"].device)?
?? 這里使用了mask,把??c2_emd?
? 給遮罩掉。(在看懂代碼前,我就想到了要遮罩c2_emb,然后發(fā)現(xiàn)他果然做了遮罩)
因為通過 q_emb 學(xué)會召回 c1_emb。通過 q_c1_emb 才應(yīng)該學(xué)會召回c2_emb。
對于scores_1的損失函數(shù)而言,正確的 label 給了c1_emb,c2_emb自然就是錯誤。c2_emb會成為負樣本,這是不允許的,這樣會把 q_emb 與 c2_emb 的相似程度給拉遠了,這樣不行,最好的做法還是把 c2_emb 給遮罩掉。
對于 target_2 ??torch.arange(embs["q_emb"].size(0)).to(embs["q_emb"].device) + embs["q_emb"].size(0)?
? 在label數(shù)值加的embs["q_emb"].size(0)是batch_size。
??score_1?
?的shape是 (batch_size, 2 x batch_size) 針對最后一個維度有2 x batch_size而言,前面一個batch_size是score_1,后面一個batch_size是score_2,所有target_2 的值相比 target_1 要再加 batch_size。
檢索器使用
??KG-LLM-MDQA\Pipeline\retriever.py?
? 大家可以看一下這個腳本中,在做向量召回的時候,使用的召回方法絕大多數(shù)都是TF-IDF,那這個言外之意就是前面檢索器的微調(diào)效果不好。那豈不是前面微調(diào)了半天的檢索器,白微調(diào)了。論文的實驗結(jié)果中,效果比較好的KGP_T5方法使用的檢索器 ??llm_retriever_KG_T5?
? 也是用的 TF-IDF。
class KG_retriever(object):
def __init__(self, k):
self.k = k
def retrieve(self, data, G):
corpus = [c for _, c in data['title_chunks']]
candidates_idx = list(range(len(corpus)))
seed = data['question']
retrieve_idxs = []
prev_length = 0
count = 0
retrieve_num = [10, 5, 5, 5, 3, 2, 2, 2, 2, 2, 2]
while len(retrieve_idxs) < self.k:
idxs = tf_idf(seed, candidates_idx, corpus, k = retrieve_num[count], visited = retrieve_idxs)
retrieve_idxs.extend(idxs[:max(0, self.k - len(retrieve_idxs))])
candidates_idx = set(chain(*[list(G.neighbors(node)) for node in idxs]))
candidates_idx = list(candidates_idx.difference(retrieve_idxs))
if len(retrieve_idxs) == prev_length:
break
else:
prev_length = len(retrieve_idxs)
count += 1
return [corpus[idx] for idx in retrieve_idxs], None, None, None
candidates_idx 候選的節(jié)點,利用 tf_idf 算法從候選節(jié)點中,找出新的候選節(jié)點。visited 表示已經(jīng)訪問過的節(jié)點,已經(jīng)訪問過的節(jié)點不再加入到新的候選節(jié)點中。如果新的候選節(jié)點為空,則停止節(jié)點召回。類似廣度優(yōu)先搜索,一層一層地往下搜索。retrieve_num 表示每一層要篩選的節(jié)點數(shù)量,第一層多取一點,下面的幾層少選一點。
大模型檢索微調(diào)
通過閱讀上述的提示詞,在微調(diào)大模型讓其學(xué)會根據(jù)問題生成相關(guān)支撐文本,再用生成的支撐文本做文本檢索召回。
論文名:Knowledge Graph Prompting for Multi-Document Question Answering
論文地址:https://arxiv.org/abs/2308.11730
源碼:https://github.com/YuWVandy/KG-LLM-MDQA
本文轉(zhuǎn)載自????AI悠閑區(qū)????,作者:jieshenai
