原创

Spring Boot集成LangChain来实现Rag应用

1.什么是rag?

检索增强生成(RAG)是指对大型语言模型输出进行优化,使其能够在生成响应之前引用训练数据来源之外的权威知识库。大型语言模型(LLM)用海量数据进行训练,使用数十亿个参数为回答问题、翻译语言和完成句子等任务生成原始输出。在 LLM 本就强大的功能基础上,RAG 将其扩展为能访问特定领域或组织的内部知识库,所有这些都无需重新训练模型。这是一种经济高效地改进 LLM 输出的方法,让它在各种情境下都能保持相关性、准确性和实用性。

为什么检索增强生成很重要?

LLM 是一项关键的人工智能(AI)技术,为智能聊天机器人和其他自然语言处理(NLP)应用程序提供支持。目标是通过交叉引用权威知识来源,创建能够在各种环境中回答用户问题的机器人。不幸的是,LLM 技术的本质在 LLM 响应中引入了不可预测性。此外,LLM 训练数据是静态的,并引入了其所掌握知识的截止日期。 LLM 面临的已知挑战包括:
  • 在没有答案的情况下提供虚假信息。
  • 当用户需要特定的当前响应时,提供过时或通用的信息。
  • 从非权威来源创建响应。
  • 由于术语混淆,不同的培训来源使用相同的术语来谈论不同的事情,因此会产生不准确的响应。
您可以将大型语言模型看作是一个过于热情的新员工,他拒绝随时了解时事,但总是会绝对自信地回答每一个问题。不幸的是,这种态度会对用户的信任产生负面影响,这是您不希望聊天机器人效仿的! RAG 是解决其中一些挑战的一种方法。它会重定向 LLM,从权威的、预先确定的知识来源中检索相关信息。组织可以更好地控制生成的文本输出,并且用户可以深入了解 LLM 如何生成响应。

检索增强生成的工作原理是什么?

如果没有 RAG,LLM 会接受用户输入,并根据它所接受训练的信息或它已经知道的信息创建响应。RAG 引入了一个信息检索组件,该组件利用用户输入首先从新数据源提取信息。用户查询和相关信息都提供给 LLM。LLM 使用新知识及其训练数据来创建更好的响应。以下各部分概述了该过程。

创建外部数据

LLM 原始训练数据集之外的新数据称为外部数据。它可以来自多个数据来源,例如 API、数据库或文档存储库。数据可能以各种格式存在,例如文件、数据库记录或长篇文本。另一种称为嵌入语言模型的 AI 技术将数据转换为数字表示形式并将其存储在向量数据库中。这个过程会创建一个生成式人工智能模型可以理解的知识库。

检索相关信息

下一步是执行相关性搜索。用户查询将转换为向量表示形式,并与向量数据库匹配。例如,考虑一个可以回答组织的人力资源问题的智能聊天机器人。如果员工搜索:“我有多少年假?”,系统将检索年假政策文件以及员工个人过去的休假记录。这些特定文件将被退回,因为它们与员工输入的内容高度相关。相关性是使用数学向量计算和表示法计算和建立的。

增强 LLM 提示

接下来,RAG 模型通过在上下文中添加检索到的相关数据来增强用户输入(或提示)。此步骤使用提示工程技术与 LLM 进行有效沟通。增强提示允许大型语言模型为用户查询生成准确的答案。

更新外部数据

下一个问题可能是——如果外部数据过时了怎么办? 要维护当前信息以供检索,请异步更新文档并更新文档的嵌入表示形式。您可以通过自动化实时流程或定期批处理来执行此操作。这是数据分析中常见的挑战——可以使用不同的数据科学方法进行变更管理。 下图显示了将 RAG 与 LLM 配合使用的概念流程。
jumpstart-fm-rag

2.什么是LangChain?

LangChain 是一个用于开发由语言模型驱动的应用程序的框架。他主要拥有 2 个能力:

  1. 可以将 LLM 模型与外部数据源进行连接
  2. 允许与 LLM 模型进行交互

LLM 模型:Large Language Model,大型语言模型

3.代码工程

实验目的

利用LangChain实现rag应用

pom.xml

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>3.2.1</version>
        <relativePath/> <!-- lookup parent from repository -->
    </parent>
    <modelVersion>4.0.0</modelVersion>

    <artifactId>rag</artifactId>


    <properties>
        <java.version>17</java.version>
        <langchain4j.version>0.23.0</langchain4j.version>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-thymeleaf</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-devtools</artifactId>
            <scope>runtime</scope>
        </dependency>
        <dependency>
            <groupId>dev.langchain4j</groupId>
            <artifactId>langchain4j</artifactId>
            <version>${langchain4j.version}</version>
        </dependency>
        <dependency>
            <groupId>dev.langchain4j</groupId>
            <artifactId>langchain4j-open-ai</artifactId>
            <version>${langchain4j.version}</version>
        </dependency>
        <dependency>
            <groupId>dev.langchain4j</groupId>
            <artifactId>langchain4j-embeddings</artifactId>
            <version>${langchain4j.version}</version>
        </dependency>
        <dependency>
            <groupId>dev.langchain4j</groupId>
            <artifactId>langchain4j-embeddings-all-minilm-l6-v2</artifactId>
            <version>${langchain4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <optional>true</optional>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
                <configuration>
                    <excludes>
                        <exclude>
                            <groupId>org.projectlombok</groupId>
                            <artifactId>lombok</artifactId>
                        </exclude>
                    </excludes>
                </configuration>
            </plugin>
        </plugins>
    </build>

</project>

controller

package com.et.rag.controller;

import com.et.rag.service.SBotService;
import lombok.RequiredArgsConstructor;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;

@Controller
@RequiredArgsConstructor
public class SBotController {

    private final SBotService sBotService;

    @GetMapping
    public String home() {
        return "index";
    }

    @PostMapping("/ask")
    public ResponseEntity<String> ask(@RequestBody String question) {
        try {
            return ResponseEntity.ok(sBotService.askQuestion(question));
        } catch (Exception e) {
            return ResponseEntity.badRequest().body("Sorry, I can't process your question right now.");
        }
    }
}

service

package com.et.rag.service;

import dev.langchain4j.chain.ConversationalRetrievalChain;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;

@Service
@RequiredArgsConstructor
@Slf4j
public class SBotService {

    private final ConversationalRetrievalChain chain;

    public String askQuestion(String question) {
        log.debug("======================================================");
        log.debug("Question: " + question);
        String answer = chain.execute(question);
        log.debug("Answer: " + answer);
        log.debug("======================================================");
        return answer;
    }
}

EmbeddingStoreLoggingRetriever

package com.et.rag.retriever;

import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.retriever.EmbeddingStoreRetriever;
import dev.langchain4j.retriever.Retriever;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;

import java.util.List;

/**
 * EmbeddingStoreLoggingRetriever is a logging-enhanced for an EmbeddingStoreRetriever.
 * <p>
 * This class logs the relevant TextSegments discovered by the supplied
 * EmbeddingStoreRetriever for improved transparency and debugging.
 * <p>
 * Logging happens at INFO level, printing each relevant TextSegment found
 * for a given input text once the findRelevant method is called.
 */

@RequiredArgsConstructor
@Slf4j
public class EmbeddingStoreLoggingRetriever implements Retriever<TextSegment> {

    private final EmbeddingStoreRetriever retriever;

    @Override
    public List<TextSegment> findRelevant(String text) {
        List<TextSegment> relevant = retriever.findRelevant(text);
        relevant.forEach(segment -> {
            log.debug("=======================================================");
            log.debug("Found relevant text segment: {}", segment);
        });
        return relevant;
    }
}

components

初始化documents
package com.et.rag.configuration;

import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.UrlDocumentLoader;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import java.util.List;

import static com.et.rag.constant.Constants.SPRING_BOOT_RESOURCES_LIST;


@Configuration
public class DocumentConfiguration {
    @Bean
    public List<Document> documents() {
        return SPRING_BOOT_RESOURCES_LIST.stream()
                .map(url -> {
                    try {
                        return UrlDocumentLoader.load(url);
                    } catch (Exception e) {
                        throw new RuntimeException("Failed to load document from " + url, e);
                    }
                })
                .toList();
    }
}
初始化langchain
package com.et.rag.configuration;

import com.et.rag.retriever.EmbeddingStoreLoggingRetriever;
import dev.langchain4j.chain.ConversationalRetrievalChain;
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.splitter.DocumentSplitters;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.retriever.EmbeddingStoreRetriever;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import java.time.Duration;
import java.util.List;

import static com.et.rag.constant.Constants.PROMPT_TEMPLATE_2;


@Configuration
@RequiredArgsConstructor
@Slf4j
public class LangChainConfiguration {

    @Value("${langchain.api.key}")
    private String apiKey;

    @Value("${langchain.timeout}")
    private Long timeout;

    private final List<Document> documents;

    @Bean
    public ConversationalRetrievalChain chain() {

        EmbeddingModel embeddingModel = new AllMiniLmL6V2EmbeddingModel();

        EmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>();

        EmbeddingStoreIngestor ingestor = EmbeddingStoreIngestor.builder()
                .documentSplitter(DocumentSplitters.recursive(500, 0))
                .embeddingModel(embeddingModel)
                .embeddingStore(embeddingStore)
                .build();

        log.info("Ingesting Spring Boot Resources ...");

        ingestor.ingest(documents);

        log.info("Ingested {} documents", documents.size());

        EmbeddingStoreRetriever retriever = EmbeddingStoreRetriever.from(embeddingStore, embeddingModel);
        EmbeddingStoreLoggingRetriever loggingRetriever = new EmbeddingStoreLoggingRetriever(retriever);

        /*MessageWindowChatMemory chatMemory = MessageWindowChatMemory.builder()
                .maxMessages(10)
                .build();*/

        log.info("Building ConversationalRetrievalChain ...");

        ConversationalRetrievalChain chain = ConversationalRetrievalChain.builder()
                .chatLanguageModel(OpenAiChatModel.builder()
                        .apiKey(apiKey)
                        .timeout(Duration.ofSeconds(timeout))
                        .build()
                )
                .promptTemplate(PromptTemplate.from(PROMPT_TEMPLATE_2))
                //.chatMemory(chatMemory)
                .retriever(loggingRetriever)
                .build();

        log.info("Spring Boot knowledge base is ready!");
        return chain;
    }
}
langchain:
  api:
    # "demo" is a free API key for testing purposes only. Please replace it with your own API key.
    key: demo
    # key: OPEN_API_KEY
  # API call to complete before it is timed out.
  timeout: 30
<!DOCTYPE html>
<html lang="en"
      xmlns="http://www.w3.org/1999/xhtml">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <title>Spring Boot Doc Bot</title>
    <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css" rel="stylesheet">
    <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.3/css/all.min.css">
</head>
<body>
<nav class="bg-dark text-white py-3">
    <div class="text-center d-flex justify-content-center align-items-center">
        <img src="/logo.png" alt="Logo" style="width:60px; margin-right: 10px;">
        <h2 style="margin: 0;">Welcome to Spring Boot Documentation Bot</h2>
    </div>
</nav>
<div class="container mt-5">
    <div class="row">
        <div class="col-md-8 offset-2">
            <h3 class="text-center mb-3">Ask your Spring related queries here!</h3>
            <form>
                <div class="mb-3">
                    <label for="questionInput" class="form-label">Question</label>
                    <input type="text" class="form-control" id="questionInput" name="question" placeholder="Enter your question" required>
                </div>
                <div class="mb-3 text-center">
                    <button id="submitBtn" type="button" class="btn btn-primary">Ask!</button>
                    <button id="clearBtn" type="button" class="btn btn-secondary">Clear</button>
                </div>
            </form>
        </div>
    </div>
    <div class="row my-5">
        <div class="col-md-8 offset-md-2">
            <label for="answerBox" class="form-label"><h5>Answer</h5></label>
            <div class="position-relative my-3">
                <textarea class="form-control" rows="10" id="answerBox" disabled></textarea>
                <a href="#" class="position-absolute top-0 end-0 m-2" id="copyBtn">
                    <i class="far fa-copy"></i>
                </a>
            </div>
        </div>
    </div>
</div>
<script src="https://code.jquery.com/jquery-3.7.1.min.js"></script>
<script>
    $(document).ready(function () {
        $("#submitBtn").click(function () {
            let questionValue = $("#questionInput").val();
            if (!questionValue) {
                alert('Please enter your question');
                return;
            }

            $("#answerBox").val('Please wait... fetching answer...');

            $.ajax({
                type: "POST",
                url: "/ask",
                data: JSON.stringify({ question: $("#questionInput").val() }),
                //contentType: "application/json; charset=utf-8",
                dataType: "text",
                success: function (data) {
                    //console.log(typeof data);
                    //console.log(data);
                    $("#answerBox").val(data);
                },
                error: function (errMsg) {
                    alert(errMsg);
                }
            });
        });

        $("#clearBtn").click(function () {
            $("#questionInput").val('');
            $("#answerBox").val('');
        });

        document.getElementById("copyBtn").addEventListener("click", function() {
            var copyText = document.getElementById("answerBox");

            copyText.select();
            copyText.setSelectionRange(0, 99999);

            document.execCommand("copy");

            alert("Copied: " + copyText.value);
        });
    });
</script>
</body>
</html>

4.测试

启动Spring Boot应用,访问http://127.0.0.1:8080/ rag

5.引用

正文到此结束
Loading...