# Retrieval-Augmented Generation (RAG) System Implementation
Zhongjun Qiu 元婴开发者

This article implements a Retrieval-Augmented Generation (RAG) system built from scratch, combining BM25 keyword retrieval, dense vector retrieval, Cross-Encoder reranking, and LLM answer generation into a complete pipeline.

The goal is to mimic a real production RAG pipeline and provide a reusable, lightweight, local RAG framework.

## 🎯 System Goals

The system implements a reproducible, end-to-end knowledge-augmented question-answering pipeline consisting of:

  1. Document ingestion
    • Automatically read the .txt files under a given directory;
    • Split each document into fixed-length text chunks with a sentence tokenizer.
  2. Multi-channel index construction
    • Build a BM25 keyword inverted index;
    • Encode each chunk into a dense vector with a SentenceTransformer;
    • Build a FAISS vector index;
    • Fuse BM25 and dense recall into a hybrid search (a minimal scoring sketch follows this list).
  3. Query and multi-hop retrieval
    • The first recall pass uses hybrid retrieval;
    • An LLM (e.g. Qwen) generates the next retrieval query;
    • Retrieval is iterated over multiple hops to gather richer context.
  4. Reranking and answer generation
    • Rerank the candidate passages by relevance with a Cross-Encoder;
    • Concatenate the contexts and the user question into a prompt;
    • Call a local or OpenAI-compatible LLM to generate the final answer.
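The hybrid recall in step 2 is a weighted fusion of the two channels' normalized scores. A minimal sketch of that fusion, assuming each retriever returns a dict of chunk index → raw score (the alpha weight and max-normalization mirror the hybrid_retrieve method in the full code below; the function name here is illustrative):

```python
import numpy as np

def fuse_scores(bm25_hits: dict, dense_hits: dict, alpha: float = 0.5):
    """Fuse BM25 and dense scores: score = alpha * bm25_norm + (1 - alpha) * dense_norm."""
    ids = sorted(set(bm25_hits) | set(dense_hits))
    bm = np.array([bm25_hits.get(i, 0.0) for i in ids])
    de = np.array([dense_hits.get(i, 0.0) for i in ids])
    # Max-normalize each channel so the two score ranges are comparable
    if bm.max() > 0:
        bm = bm / bm.max()
    if de.max() > 0:
        de = de / de.max()
    fused = alpha * bm + (1 - alpha) * de
    return sorted(zip(ids, fused.tolist()), key=lambda x: x[1], reverse=True)
```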

## 🧩 System Flowchart

```mermaid
flowchart TD

A[📂 Document directory data/] --> B[🔹 Read & split into sentence chunks]
B --> C1[🔸 Build BM25 index]
B --> C2[🔸 Dense embedding with SentenceTransformer]
C2 --> D[🧮 FAISS vector index]
C1 --> E[⚡ Hybrid retrieval BM25 + FAISS]
D --> E
E --> F[🧠 Cross-Encoder reranking]
F --> G[🔁 LLM generates next-hop query Multi-Hop]
G --> E
F --> H[📚 Concatenate contexts & question → Prompt]
H --> I[🤖 LLM generates the final answer]
I --> J[✅ Output Answer]
```

---

## Running and Configuration

### 🔧 Dependencies

Make sure the following main libraries are installed:

```bash
pip install -r requirements.txt
```

If network access to Hugging Face is slow, you can switch to a mirror endpoint:

```python
import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
```

### ⚙️ Main Configuration Parameters (the Config class)

| Parameter | Meaning | Default |
| --- | --- | --- |
| data_dir | Document directory path | ./data |
| chunk_size_sentences | Number of sentences per chunk | 5 |
| dense_model_name | Dense encoder model | sentence-transformers/all-MiniLM-L6-v2 |
| reranker_model_name | Cross-Encoder model | cross-encoder/ms-marco-MiniLM-L-6-v2 |
| hf_llm_model | Answer generation model | Qwen/Qwen2.5-7B-Instruct |
| top_k_recall | Number of candidates kept at the recall stage | 50 |
| top_k_rerank | Number of passages kept after reranking | 10 |
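
Config is a plain dataclass (see the full code below), so any of these defaults can be overridden at construction time. A minimal example; the values here are illustrative:

```python
cfg = Config(
    data_dir="./data",
    chunk_size_sentences=5,
    top_k_recall=50,
    top_k_rerank=10,
)
```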

### 🚀 Usage

  1. Prepare the text data:

    • Put one or more .txt documents into the data/ folder.
  2. Run the main program by calling `main()` (a condensed version of the call sequence is sketched after this list).
  3. The system will then:

    • build the indexes automatically;
    • run multi-hop retrieval;
    • concatenate the contexts and generate the answer;
    • print the final answer.
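
Condensed, the call sequence inside main() (shown in full below) looks like this; the question string is a placeholder:

```python
docs = read_texts_from_dir(cfg.data_dir)        # 1. read .txt files
idx = Indexer(cfg)
idx.ingest(docs)                                # 2. build BM25 + FAISS indexes

multi = MultiHopRAG(idx, cfg)
pipeline = RAGPipeline(idx, cfg)

contexts = multi.retrieve("your question", hops=2)             # 3. multi-hop retrieval
answer = pipeline.generate_answer("your question", contexts)   # 4. final answer
print(answer)
```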

💡 Tips:

  • If GPU memory is limited, switch to device="cpu" or enable the 4-bit quantization config;
  • To use an OpenAI-compatible API instead of a local model, set the SILICON_API_KEY environment variable;
  • The number of multi-hop retrieval rounds can be adjusted via multi.retrieve(query, hops=N) (see the snippet after this list).
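
For example (a sketch only; it assumes the `multi` and `query` objects from the usage snippet above, and the key value is a placeholder):

```python
# Use the OpenAI-compatible backend by passing the key explicitly
# (or export SILICON_API_KEY in the shell before launching the script).
cfg = Config(openai_api_key="<your-key>")

# Deeper multi-hop retrieval: three rounds instead of the default two
retrieved = multi.retrieve(query, hops=3)
```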

## Full Code

```python
import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

import glob
import json
import math
from typing import List, Tuple, Dict, Any
from dataclasses import dataclass

# Text processing
import nltk
from nltk.tokenize import sent_tokenize
nltk.download('punkt')
nltk.download('punkt_tab')

# TF-IDF / BM25
from rank_bm25 import BM25Okapi

# Dense embeddings
from sentence_transformers import SentenceTransformer, CrossEncoder
import numpy as np
import faiss

# LLM (two backends are supported: OpenAI-compatible API or HuggingFace)
# import openai
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

CURRENT_PATH = "..."

# ------------- Configuration -------------

@dataclass
class Config:
    data_dir: str = f"{CURRENT_PATH}/data"
    chunk_size_sentences: int = 5
    chunk_overlap: int = 1
    dense_model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
    reranker_model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
    hf_llm_model: str = "Qwen/Qwen2.5-7B-Instruct"
    openai_api_key: str = os.getenv('SILICON_API_KEY', '')
    top_k_recall: int = 5
    top_k_rerank: int = 2
    faiss_index_path: str = "faiss.index"

cfg = Config()

# ------------- Utilities: document ingestion and chunking -------------

def read_texts_from_dir(data_dir: str) -> List[Tuple[str, str]]:
    """Read all txt files under a directory and return a list of (doc_id, text)."""
    files = glob.glob(os.path.join(data_dir, "*.txt"))
    docs = []
    for p in files:
        doc_id = os.path.basename(p)
        with open(p, 'r', encoding='utf-8') as f:
            text = f.read()
        docs.append((doc_id, text))
    return docs


def chunk_document(doc_id: str, text: str, chunk_size_sentences=5, overlap=1) -> List[Dict[str, Any]]:
    """Split a document into overlapping chunks of `chunk_size_sentences` sentences."""
    sents = sent_tokenize(text)
    chunks = []
    i = 0
    chunk_id = 0
    while i < len(sents):
        end = min(i + chunk_size_sentences, len(sents))
        chunk_text = " ".join(sents[i:end])
        chunks.append({
            'doc_id': doc_id,
            'chunk_id': f"{doc_id}_chunk{chunk_id}",
            'text': chunk_text,
            'start_sent': i,
            'end_sent': end
        })
        chunk_id += 1
        i += chunk_size_sentences - overlap
    return chunks

# ------------- Index construction: BM25 + Dense (FAISS) -------------

class Indexer:
    def __init__(self, cfg: Config):
        self.cfg = cfg
        self.dense_model = SentenceTransformer(cfg.dense_model_name, device="cpu")
        self.reranker = CrossEncoder(cfg.reranker_model_name, device="cpu")

        # List of all chunks
        self.chunks: List[Dict[str, Any]] = []

        # BM25 structures
        self.bm25 = None
        self.bm25_tokenized = []

        # FAISS
        self.faiss_index = None
        self.faiss_id_map = []  # FAISS idx -> chunk idx

    def ingest(self, docs: List[Tuple[str, str]]):
        # Chunk every document and store the chunks
        for doc_id, text in docs:
            chs = chunk_document(doc_id, text, self.cfg.chunk_size_sentences, self.cfg.chunk_overlap)
            self.chunks.extend(chs)

        print(f"Total chunks: {len(self.chunks)}")
        print("The first 5 chunks:")
        for c in self.chunks[:5]:
            print(f"- {c['chunk_id']}: {c['text']}...")

        # Build the BM25 index
        tokenized = [self._tokenize_for_bm25(c['text']) for c in self.chunks]
        self.bm25_tokenized = tokenized
        self.bm25 = BM25Okapi(tokenized)

        # Build the dense embeddings
        texts = [c['text'] for c in self.chunks]
        embeddings = self.dense_model.encode(texts, convert_to_numpy=True, show_progress_bar=True)

        d = embeddings.shape[1]
        index = faiss.IndexFlatIP(d)   # inner product on L2-normalized vectors = cosine similarity
        faiss.normalize_L2(embeddings)
        index.add(embeddings)  # type: ignore
        self.faiss_index = index
        self.faiss_id_map = list(range(len(self.chunks)))

    def _tokenize_for_bm25(self, text: str) -> List[str]:
        # Naive tokenization: whitespace split + lowercase; swap in a better tokenizer if needed
        return [t.lower() for t in text.split()]

    # BM25 recall
    def bm25_retrieve(self, query: str, top_k: int) -> List[Tuple[int, float]]:
        q_tokens = self._tokenize_for_bm25(query)
        scores = self.bm25.get_scores(q_tokens) if self.bm25 is not None else np.array([])
        idxs = np.argsort(scores)[::-1][:top_k]
        return [(int(i), float(scores[i])) for i in idxs]

    # Dense recall
    def dense_retrieve(self, query: str, top_k: int) -> List[Tuple[int, float]]:
        q_emb = self.dense_model.encode([query], convert_to_numpy=True)
        faiss.normalize_L2(q_emb)
        D, I = self.faiss_index.search(q_emb, top_k)  # type: ignore
        return [(int(I[0][i]), float(D[0][i])) for i in range(len(I[0]))]

    # Hybrid: fuse BM25 + dense scores (simple weighted sum)
    def hybrid_retrieve(self, query: str, top_k: int, alpha=0.5) -> List[Tuple[int, float]]:
        bm = dict(self.bm25_retrieve(query, top_k * 2))
        de = dict(self.dense_retrieve(query, top_k * 2))
        # Max-normalize each channel before fusing
        all_ids = list(set(bm.keys()) | set(de.keys()))
        bm_vals = np.array([bm.get(i, 0.0) for i in all_ids])
        de_vals = np.array([de.get(i, 0.0) for i in all_ids])
        if bm_vals.max() > 0:
            bm_vals = bm_vals / bm_vals.max()
        if de_vals.max() > 0:
            de_vals = de_vals / de_vals.max()
        scores = {}
        for idx, b_norm, d_norm in zip(all_ids, bm_vals, de_vals):
            scores[idx] = alpha * b_norm + (1 - alpha) * d_norm
        top = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:top_k]
        return top

    # Rerank with a cross-encoder (a stronger model that scores query + passage jointly)
    def rerank(self, query: str, candidates: List[int], top_k: int) -> List[Tuple[int, float]]:
        pairs = [(query, self.chunks[c]['text']) for c in candidates]
        scores = self.reranker.predict(pairs)
        scored = list(zip(candidates, scores))
        scored_sorted = sorted(scored, key=lambda x: x[1], reverse=True)[:top_k]
        return scored_sorted

# ------------- Multi-hop retrieval (simple example) -------------

class MultiHopRAG:
    def __init__(self, indexer: Indexer, cfg: Config):
        self.indexer = indexer
        self.cfg = cfg
        # LLM used for query rewriting / next-hop query generation
        # A plain HuggingFace generator is used here as an example
        self.llm_tokenizer = AutoTokenizer.from_pretrained(cfg.hf_llm_model)
        self.llm_model = AutoModelForCausalLM.from_pretrained(cfg.hf_llm_model)
        # self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = 'cpu'
        self.llm_model.to(self.device)  # type: ignore

    def generate_followup_query(self, query: str, context_chunks: List[str]) -> str:
        # Simple prompt: ask the LLM to produce the next search query from the current evidence
        prompt = "Given the user question and retrieved contexts, generate a concise follow-up search query to fetch more evidence.\n\nQuestion:\n" + query + "\n\nContexts:\n"
        for c in context_chunks:
            prompt += "- " + c + "\n"
        prompt += "\nFollow-up query:"
        inputs = self.llm_tokenizer(prompt, return_tensors='pt').to(self.device)
        outputs = self.llm_model.generate(**inputs, max_new_tokens=32)
        ans = self.llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Take the last line of the decoded output as the follow-up query
        print()
        print(prompt)
        print(ans)
        print()
        return ans.split('\n')[-1].strip()

    def retrieve(self, query: str, hops=2) -> List[Dict[str, Any]]:
        # First hop: hybrid recall
        hybrid = self.indexer.hybrid_retrieve(query, self.cfg.top_k_recall, alpha=0.5)
        cand_ids = [c for c, s in hybrid]
        # Rerank and keep top_k_rerank passages
        reranked = self.indexer.rerank(query, cand_ids, self.cfg.top_k_rerank)
        top_chunks = [self.indexer.chunks[c] for c, s in reranked]

        all_results = top_chunks.copy()

        # Iterate over the remaining hops
        current_query = query
        for hop in range(1, hops):
            contexts = [c['text'] for c in top_chunks]
            followup = self.generate_followup_query(current_query, contexts)
            hybrid2 = self.indexer.hybrid_retrieve(followup, self.cfg.top_k_recall, alpha=0.5)
            cand_ids2 = [c for c, s in hybrid2]
            reranked2 = self.indexer.rerank(followup, cand_ids2, self.cfg.top_k_rerank)
            top_chunks2 = [self.indexer.chunks[c] for c, s in reranked2]
            all_results.extend(top_chunks2)
            # Focus the next hop on the newest evidence
            top_chunks = top_chunks2
            current_query = followup

        return all_results

# ------------- Prompt construction and final generation -------------

class RAGPipeline:
    def __init__(self, indexer: Indexer, cfg: Config):
        self.indexer = indexer
        self.cfg = cfg
        # LLM backend: OpenAI-compatible API or HuggingFace
        self.use_openai = cfg.openai_api_key != ''
        if not self.use_openai:
            quant_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
            self.hf_tokenizer = AutoTokenizer.from_pretrained(cfg.hf_llm_model)
            self.hf_model = AutoModelForCausalLM.from_pretrained(
                cfg.hf_llm_model,
                device_map="auto",
                quantization_config=quant_config
            )  # type: ignore
            # 'cuda' if torch.cuda.is_available() else 'cpu'

    def build_prompt(self, question: str, contexts: List[Dict[str, Any]]) -> str:
        prompt = "You are an expert assistant. Use the provided contexts to answer the question. If the answer is not contained in the contexts, say you don't know.\n\n"
        prompt += "CONTEXTS:\n"
        for i, c in enumerate(contexts):
            prompt += f"[{i}] (source: {c['doc_id']}) {c['text']}\n---\n"
        prompt += "\nQuestion:\n" + question + "\n\nAnswer:"
        return prompt

    def generate_answer(self, question: str, contexts: List[Dict[str, Any]], max_tokens=256):
        prompt = self.build_prompt(question, contexts)
        # Optionally truncate the prompt to the model's maximum input length
        # max_model_len = 1024
        # if inputs.input_ids.shape[1] > max_model_len:
        #     inputs.input_ids = inputs.input_ids[:, -max_model_len:]
        #     inputs.attention_mask = inputs.attention_mask[:, -max_model_len:]
        if self.use_openai:
            import openai
            openai.api_key = self.cfg.openai_api_key
            res = openai.chat.completions.create(
                model='deepseek-ai/DeepSeek-OCR',
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens,
                temperature=0.0
            )
            return res
        else:
            inputs = self.hf_tokenizer(prompt, return_tensors='pt').to(self.hf_model.device)
            outputs = self.hf_model.generate(**inputs, max_new_tokens=max_tokens, do_sample=False)
            ans = self.hf_tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Keep only the text generated after the prompt
            print()
            print(ans)
            print()
            return ans[len(prompt):].strip()
```
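
As a quick sanity check of the retrieval layer alone (before loading any LLM), the three recall paths can be compared side by side. A minimal sketch, assuming the index has already been built with ingest(); the query string is illustrative:

```python
idx = Indexer(cfg)
idx.ingest(read_texts_from_dir(cfg.data_dir))

query = "example query"  # illustrative
for name, hits in [
    ("bm25",   idx.bm25_retrieve(query, top_k=5)),
    ("dense",  idx.dense_retrieve(query, top_k=5)),
    ("hybrid", idx.hybrid_retrieve(query, top_k=5, alpha=0.5)),
]:
    print(name)
    for chunk_idx, score in hits:
        print(f"  {score:.3f}  {idx.chunks[chunk_idx]['chunk_id']}")
```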

```python
# ------------- Demo main flow -------------

def main():
    docs = read_texts_from_dir(cfg.data_dir)
    idx = Indexer(cfg)
    print(f"Ingest {len(docs)} docs...")
    idx.ingest(docs)

    multi = MultiHopRAG(idx, cfg)
    pipeline = RAGPipeline(idx, cfg)

    # Example query
    query = "who is ljc?"
    retrieved = multi.retrieve(query, hops=2)
    # Deduplicate and cap the number of contexts
    seen = set()
    contexts = []
    for r in retrieved:
        if r['chunk_id'] in seen:
            continue
        seen.add(r['chunk_id'])
        contexts.append(r)
        if len(contexts) >= 4:
            break

    ans = pipeline.generate_answer(query, contexts)
    print("\n---- ANSWER ----\n")
    print(ans)
```
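
Config defines a faiss_index_path, but the code above never persists the vector index, so every run re-encodes all chunks. A minimal sketch of saving and reloading it with FAISS's standard serialization helpers; note the BM25 index and the chunk list still have to be rebuilt the same way on load:

```python
import os
import faiss

if os.path.exists(cfg.faiss_index_path):
    # Reuse a previously built vector index (chunks must be re-created identically)
    idx.faiss_index = faiss.read_index(cfg.faiss_index_path)
else:
    idx.ingest(docs)
    faiss.write_index(idx.faiss_index, cfg.faiss_index_path)
```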