LangChain 是一个用于开发由语言模型驱动的应用程序的框架。它主要拥有 2 个能力:
LLM 模型:Large Language Model,大型语言模型
利用 LangChain 实现 RAG 应用
<?xml version="1.0" encoding="UTF-8"?>
<!-- Maven build descriptor for the "rag" artifact: a Spring Boot web application that uses LangChain4j. -->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<!-- Spring Boot parent POM: supplies the versions of the Spring Boot starters, Lombok and the Boot plugin below. -->
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>3.2.1</version>
<relativePath/> <!-- lookup parent from repository -->
</parent>
<modelVersion>4.0.0</modelVersion>
<!-- NOTE(review): no <groupId>/<version> is declared for this artifact, so Maven inherits both from the parent; confirm that is intended. -->
<artifactId>rag</artifactId>
<properties>
<java.version>17</java.version>
<!-- Single version shared by every dev.langchain4j dependency below. -->
<langchain4j.version>0.23.0</langchain4j.version>
</properties>
<dependencies>
<!-- Web layer (serves the page and the /ask endpoint). -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<!-- Template engine starter for the HTML view. -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-thymeleaf</artifactId>
</dependency>
<!-- Development-time tooling; runtime scope only. -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-devtools</artifactId>
<scope>runtime</scope>
</dependency>
<!-- LangChain4j core. -->
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<!-- LangChain4j OpenAI integration (chat model). -->
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<!-- LangChain4j embedding support. -->
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<!-- all-MiniLM-L6-v2 embedding model packaged for LangChain4j. -->
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-all-minilm-l6-v2</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<!-- Lombok is marked optional, so it is not passed on to dependants. -->
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<optional>true</optional>
</dependency>
<!-- Test-only dependencies. -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<!-- Spring Boot packaging plugin; Lombok is excluded from the packaged application. -->
<plugin>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
<configuration>
<excludes>
<exclude>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
</exclude>
</excludes>
</configuration>
</plugin>
</plugins>
</build>
</project>
package com.et.rag.controller;
import com.et.rag.service.SBotService;
import lombok.RequiredArgsConstructor;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
/**
 * Web entry point of the bot: resolves the landing page and answers questions posted to
 * {@code /ask}.
 */
@Controller
@RequiredArgsConstructor
public class SBotController {

    private final SBotService sBotService;

    /**
     * Handles GET requests for the landing page.
     *
     * @return the view name {@code index}
     */
    @GetMapping
    public String home() {
        return "index";
    }

    /**
     * Answers a question by delegating to {@link SBotService#askQuestion(String)}.
     *
     * @param question the raw request body, used as the question text
     * @return 200 with the answer; 400 when the question is missing or blank; 500 when the
     *         question could not be processed
     */
    @PostMapping("/ask")
    public ResponseEntity<String> ask(@RequestBody String question) {
        // A blank question is a client mistake: reject it without calling the service.
        if (question == null || question.isBlank()) {
            return ResponseEntity.badRequest().body("Please enter your question.");
        }
        try {
            return ResponseEntity.ok(sBotService.askQuestion(question));
        } catch (Exception e) {
            // A failure while answering is a server-side problem, so report 500 rather than
            // blaming the client with 400.
            // NOTE(review): the exception is not logged; this class has no logger yet.
            return ResponseEntity.internalServerError()
                    .body("Sorry, I can't process your question right now.");
        }
    }
}
package com.et.rag.service;
import dev.langchain4j.chain.ConversationalRetrievalChain;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
/**
 * Answers questions by delegating to the injected {@link ConversationalRetrievalChain}.
 */
@Service
@RequiredArgsConstructor
@Slf4j
public class SBotService {

    /** Visual separator written before and after each logged question/answer pair. */
    private static final String SEPARATOR = "======================================================";

    private final ConversationalRetrievalChain chain;

    /**
     * Runs the question through the chain and logs both question and answer at DEBUG level.
     *
     * @param question the question text to execute
     * @return the answer produced by the chain
     */
    public String askQuestion(String question) {
        log.debug(SEPARATOR);
        // Parameterized logging: no string is concatenated when DEBUG is disabled.
        log.debug("Question: {}", question);
        String answer = chain.execute(question);
        log.debug("Answer: {}", answer);
        log.debug(SEPARATOR);
        return answer;
    }
}
package com.et.rag.retriever;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.retriever.EmbeddingStoreRetriever;
import dev.langchain4j.retriever.Retriever;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import java.util.List;
/**
 * Logging decorator around an {@link EmbeddingStoreRetriever}.
 * <p>
 * Each call to {@link #findRelevant(String)} is delegated to the wrapped retriever, and every
 * {@link TextSegment} it returns is logged for improved transparency and debugging.
 * <p>
 * Logging happens at DEBUG level, one entry per relevant segment, after the wrapped retriever
 * has returned.
 */
@RequiredArgsConstructor
@Slf4j
public class EmbeddingStoreLoggingRetriever implements Retriever<TextSegment> {

    private final EmbeddingStoreRetriever retriever;

    /**
     * Finds the segments relevant to the given text and logs each of them.
     *
     * @param text the input text to search with
     * @return the list returned by the wrapped retriever, unchanged
     */
    @Override
    public List<TextSegment> findRelevant(String text) {
        List<TextSegment> relevant = retriever.findRelevant(text);
        // Skip the loop entirely when DEBUG output is switched off.
        if (log.isDebugEnabled()) {
            for (TextSegment segment : relevant) {
                log.debug("=======================================================");
                log.debug("Found relevant text segment: {}", segment);
            }
        }
        return relevant;
    }
}
package com.et.rag.configuration;
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.UrlDocumentLoader;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import java.util.List;
import static com.et.rag.constant.Constants.SPRING_BOOT_RESOURCES_LIST;
/**
 * Exposes the documents loaded from {@code SPRING_BOOT_RESOURCES_LIST} as a Spring bean.
 */
@Configuration
public class DocumentConfiguration {

    /**
     * Loads one {@link Document} per entry of {@code SPRING_BOOT_RESOURCES_LIST} through
     * {@link UrlDocumentLoader}.
     *
     * @return an unmodifiable list of the loaded documents
     * @throws RuntimeException if any entry fails to load; the original exception is the cause
     */
    @Bean
    public List<Document> documents() {
        List<Document> loadedDocuments = SPRING_BOOT_RESOURCES_LIST.stream()
                .map(resourceUrl -> {
                    try {
                        return UrlDocumentLoader.load(resourceUrl);
                    } catch (Exception loadFailure) {
                        // Name the failing entry so the startup error is actionable.
                        throw new RuntimeException("Failed to load document from " + resourceUrl, loadFailure);
                    }
                })
                .toList();
        return loadedDocuments;
    }
}
初始化 LangChain
package com.et.rag.configuration;
import com.et.rag.retriever.EmbeddingStoreLoggingRetriever;
import dev.langchain4j.chain.ConversationalRetrievalChain;
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.splitter.DocumentSplitters;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.retriever.EmbeddingStoreRetriever;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import java.time.Duration;
import java.util.List;
import static com.et.rag.constant.Constants.PROMPT_TEMPLATE_2;
/**
 * Builds the {@link ConversationalRetrievalChain} bean: ingests the injected documents into an
 * in-memory embedding store and connects a logging retriever, a prompt template and an OpenAI
 * chat model.
 */
@Configuration
@RequiredArgsConstructor
@Slf4j
public class LangChainConfiguration {

    /** Documents to ingest; injected through the Lombok-generated constructor. */
    private final List<Document> documents;

    /** API key handed to the OpenAI chat model. */
    @Value("${langchain.api.key}")
    private String apiKey;

    /** Chat model timeout, interpreted as seconds. */
    @Value("${langchain.timeout}")
    private Long timeout;

    /**
     * Ingests the documents, then assembles the retrieval chain on top of the populated store.
     *
     * @return the fully built chain
     */
    @Bean
    public ConversationalRetrievalChain chain() {
        EmbeddingModel embeddingModel = new AllMiniLmL6V2EmbeddingModel();
        EmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>();
        ingestDocuments(embeddingModel, embeddingStore);

        // Wrap the store-backed retriever so every retrieved segment is logged.
        EmbeddingStoreLoggingRetriever loggingRetriever = new EmbeddingStoreLoggingRetriever(
                EmbeddingStoreRetriever.from(embeddingStore, embeddingModel));

        log.info("Building ConversationalRetrievalChain ...");
        OpenAiChatModel chatModel = OpenAiChatModel.builder()
                .apiKey(apiKey)
                .timeout(Duration.ofSeconds(timeout))
                .build();
        // No explicit chat memory is passed to the builder.
        // NOTE(review): confirm the builder's default memory is what this bot should use.
        ConversationalRetrievalChain retrievalChain = ConversationalRetrievalChain.builder()
                .chatLanguageModel(chatModel)
                .promptTemplate(PromptTemplate.from(PROMPT_TEMPLATE_2))
                .retriever(loggingRetriever)
                .build();
        log.info("Spring Boot knowledge base is ready!");
        return retrievalChain;
    }

    /** Splits the injected documents, embeds the pieces and stores them in the given store. */
    private void ingestDocuments(EmbeddingModel embeddingModel, EmbeddingStore<TextSegment> embeddingStore) {
        EmbeddingStoreIngestor ingestor = EmbeddingStoreIngestor.builder()
                .documentSplitter(DocumentSplitters.recursive(500, 0))
                .embeddingModel(embeddingModel)
                .embeddingStore(embeddingStore)
                .build();
        log.info("Ingesting Spring Boot Resources ...");
        ingestor.ingest(documents);
        log.info("Ingested {} documents", documents.size());
    }
}
# Settings read by LangChainConfiguration via ${langchain.api.key} and ${langchain.timeout}.
langchain:
  api:
    # "demo" is a free API key for testing purposes only. Please replace it with your own API key.
    key: demo
    # key: OPEN_API_KEY
  # Timeout, in seconds, for an API call to complete before it is timed out.
  timeout: 30
<!DOCTYPE html>
<html lang="en"
xmlns="http://www.w3.org/1999/xhtml">
<head>
<!-- Page metadata. -->
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>Spring Boot Doc Bot</title>
<!-- Third-party stylesheets, both fetched from public CDNs: Bootstrap 5.3.2 and Font Awesome 5.15.3. -->
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css" rel="stylesheet">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.3/css/all.min.css">
</head>
<body>
<!-- Page header: dark banner with the logo image and the welcome title, centred horizontally. -->
<nav class="bg-dark text-white py-3">
<div class="text-center d-flex justify-content-center align-items-center">
<!-- The logo is requested from the site root (/logo.png). -->
<img src="/logo.png" alt="Logo" style="width:60px; margin-right: 10px;">
<h2 style="margin: 0;">Welcome to Spring Boot Documentation Bot</h2>
</div>
</nav>
<!-- Main content: the question form on top, the answer box below. -->
<div class="container mt-5">
    <div class="row">
        <!-- offset-md-2 (not offset-2): the offset applies only from the md breakpoint up, matching the answer column below. -->
        <div class="col-md-8 offset-md-2">
            <h3 class="text-center mb-3">Ask your Spring related queries here!</h3>
            <form>
                <div class="mb-3">
                    <label for="questionInput" class="form-label">Question</label>
                    <input type="text" class="form-control" id="questionInput" name="question" placeholder="Enter your question" required>
                </div>
                <div class="mb-3 text-center">
                    <!-- Both buttons are type="button"; their behaviour is attached by the page script. -->
                    <button id="submitBtn" type="button" class="btn btn-primary">Ask!</button>
                    <button id="clearBtn" type="button" class="btn btn-secondary">Clear</button>
                </div>
            </form>
        </div>
    </div>
    <div class="row my-5">
        <div class="col-md-8 offset-md-2">
            <label for="answerBox" class="form-label"><h5>Answer</h5></label>
            <div class="position-relative my-3">
                <!-- Output-only answer area with a copy link pinned to its top-right corner. -->
                <textarea class="form-control" rows="10" id="answerBox" disabled></textarea>
                <a href="#" class="position-absolute top-0 end-0 m-2" id="copyBtn">
                    <i class="far fa-copy"></i>
                </a>
            </div>
        </div>
    </div>
</div>
<script src="https://code.jquery.com/jquery-3.7.1.min.js"></script>
<script>
// Wires up the page once the DOM is ready: ask, clear and copy actions.
$(document).ready(function () {
    // Posts the current question to /ask and shows the reply in the answer box.
    function askQuestion() {
        const questionValue = $("#questionInput").val().trim();
        if (!questionValue) {
            alert('Please enter your question');
            return;
        }
        $("#answerBox").val('Please wait... fetching answer...');
        $.ajax({
            type: "POST",
            url: "/ask",
            // The endpoint reads the raw request body as the question, so send plain text
            // with an explicit content type instead of an untyped JSON string.
            data: questionValue,
            contentType: "text/plain; charset=utf-8",
            dataType: "text",
            success: function (data) {
                $("#answerBox").val(data);
            },
            error: function (jqXHR) {
                // Remove the "please wait" placeholder and show the server's message;
                // alerting the jqXHR object itself only prints "[object Object]".
                $("#answerBox").val('');
                alert(jqXHR.responseText || 'Request failed. Please try again.');
            }
        });
    }

    // Copies text through a temporary textarea, for browsers without the async Clipboard API.
    function legacyCopy(text) {
        const scratch = document.createElement("textarea");
        scratch.value = text;
        document.body.appendChild(scratch);
        scratch.select();
        const copied = document.execCommand("copy");
        document.body.removeChild(scratch);
        return copied;
    }

    $("#submitBtn").click(askQuestion);

    // Pressing Enter in the input submits the form; ask instead of reloading the page.
    $("form").on("submit", function (event) {
        event.preventDefault();
        askQuestion();
    });

    $("#clearBtn").click(function () {
        $("#questionInput").val('');
        $("#answerBox").val('');
    });

    document.getElementById("copyBtn").addEventListener("click", function (event) {
        // Stop the href="#" link from navigating and scrolling to the top.
        event.preventDefault();
        const answerText = document.getElementById("answerBox").value;
        if (!answerText) {
            return;
        }
        // The answer box is disabled, so copy its value directly instead of selecting it.
        if (navigator.clipboard && navigator.clipboard.writeText) {
            navigator.clipboard.writeText(answerText).then(
                function () {
                    alert("Copied: " + answerText);
                },
                function () {
                    alert("Could not copy the answer.");
                }
            );
        } else if (legacyCopy(answerText)) {
            alert("Copied: " + answerText);
        } else {
            alert("Could not copy the answer.");
        }
    });
});
</script>
</body>
</html>