← 返回首页
手搓一个RAG
发表时间:2025-05-25 02:23:27
手搓一个RAG

1.手搓一个RAG

包含以下步骤:

# 1. Load the source document (UTF-8 text) into memory.
with open("黑悟空.txt", encoding="utf-8") as source_file:
    raw_text = source_file.read()
# print(raw_text)

import re
# 2、分割
def splitter(text, chunk_size=200, chunk_overlap=20, separator=""):
    """Split ``text`` into chunks of at most ``chunk_size`` characters.

    The text is first cut on ``separator`` (the default "" splits between
    every character), then the pieces are greedily merged into chunks.
    Consecutive chunks share up to ``chunk_overlap`` trailing characters.

    Returns:
        list[str]: the finished chunks, in document order.
    """
    # Escape the separator so it is matched literally, not as a regex
    # (fixes separators like "." or "+"); re.escape("") == "", so the
    # default per-character split is unchanged.
    splites = re.split(re.escape(separator), text)
    total = 0          # accumulated length of the pieces in current_doc
    current_doc = []   # pieces of the chunk currently being built
    docs = []          # finished chunks
    separator_len = len(separator)
    for s in splites:
        _len = len(s)
        # Would adding this piece (plus a joining separator) overflow the chunk?
        if total + _len + (separator_len if current_doc else 0) > chunk_size:
            if current_doc:
                # Close off the chunk under construction.
                docs.append(separator.join(current_doc))
            # Drop pieces from the front until only the desired overlap
            # remains and the new piece fits.
            while total > chunk_overlap or (
                total + _len + (separator_len if current_doc else 0) > chunk_size
                and total > 0
            ):
                total -= len(current_doc[0])
                current_doc.pop(0)

        # Append the current piece and account for its length.
        current_doc.append(s)
        total += _len

    # Flush the last, possibly short, chunk.
    if current_doc:
        docs.append(separator.join(current_doc))
    return docs

# 2b. Chunk the whole document with the defaults
#     (chunk_size=200, chunk_overlap=20, per-character split).
texts = splitter(raw_text)
# print(texts)


# 3. Vectorise: embed every chunk with a locally stored BGE
#    Chinese sentence-embedding model (path under models/).
from sentence_transformers import SentenceTransformer
embedding = SentenceTransformer("models/AI-ModelScope/bge-large-zh-v1___5")
# encode() returns one dense vector per chunk, aligned with `texts`.
embeddings = embedding.encode(texts)

import os
import json
# 4、放入数据库,只会执行一次
# {texts: embeddings}
def persist(texts, embeddings, path="database"):
    """Persist chunks and their vectors to ``path/text_embeddings.json``.

    Args:
        texts: chunk strings, used as the JSON keys. NOTE: duplicate
            chunk texts overwrite each other because the text is the key.
        embeddings: per-chunk vectors; numpy rows or plain sequences.
        path: target directory, created if it does not exist.
    """
    os.makedirs(path, exist_ok=True)  # idempotent, no pre-check race
    # Accept both numpy rows (.tolist()) and plain Python sequences.
    data = {
        text: vec.tolist() if hasattr(vec, "tolist") else list(vec)
        for text, vec in zip(texts, embeddings)
    }
    with open(f"{path}/text_embeddings.json", "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=4)

#persist(texts, embeddings)


import numpy as np
# 5、检索
def search(query, database="database/text_embeddings.json", k=5):
    """Retrieve the ``k`` stored chunks most similar to ``query``.

    Embeds the query with the module-level ``embedding`` model, loads the
    persisted {text: vector} store, and ranks chunks by cosine similarity.

    Returns:
        list[tuple[str, float]]: (text, score) pairs, best match first.
    """
    query_vector = np.asarray(embedding.encode([query])[0], dtype=float)
    # Guard against a zero vector so the division below is always safe.
    query_norm = np.linalg.norm(query_vector) or 1.0

    with open(database, "r", encoding="utf-8") as f:
        loaded_data = json.load(f)

    # Use cosine similarity: encode() does not normalise vectors by
    # default, so a raw dot product would favour longer vectors over
    # genuinely closer ones.
    similarities = []
    for text, vector in loaded_data.items():
        vec = np.asarray(vector, dtype=float)
        denom = query_norm * (np.linalg.norm(vec) or 1.0)
        similarities.append((text, float(np.dot(query_vector, vec) / denom)))

    # Best matches first; keep only the top k.
    similarities.sort(key=lambda pair: pair[1], reverse=True)
    return similarities[:k]

# 5b. Retrieve the chunks most similar to the question and concatenate
#     their texts into a single context string for the prompt.
contexts = search("黑熊精自称为什么?")
# print(contexts)
context = "".join(c[0] for c in contexts)

# 6. Configure the LLM client (OpenAI-compatible SiliconFlow endpoint).
from langchain_openai import ChatOpenAI

# SECURITY: never commit a real API key in source control — read it from
# the environment; the original placeholder remains only as a fallback.
chat_model = ChatOpenAI(
    api_key=os.getenv("SILICONFLOW_API_KEY",
                      "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"),
    base_url="https://api.siliconflow.cn/v1",
    model="Qwen/Qwen2.5-7B-Instruct",
)

# 7. Assemble the chat prompt: the retrieved context grounds the system
#    turn, the original question is the user turn.
system_message = {
    "role": "system",
    "content": f"根据已知内容:{context},回答用户问题。",
}
user_message = {
    "role": "user",
    "content": f"问题:黑熊精自称为什么?",
}
messages = [system_message, user_message]

# 8. Ask the model and print the answer text.
response = chat_model.invoke(messages)
output = response.content

print(output)

运行效果:

黑熊精自称黑风大王。