第二章:advisor 增强
- 作者:影子, Spring AI Alibaba Committer
- 本文档基于 Spring AI 1.0.0 版本,Spring AI Alibaba 1.0.0.2 版本
- 本章包含快速上手(基于内存、sqlite、mysql、redis的历史消息存储)+ 源码解读(advisor基础、BaseChatMemoryAdvisor解读、AdvisorChain链)
基于内存的消息存储快速上手
Advisor 用于在 AI 模型的请求和响应流程中插入自定义逻辑。实战代码可见:https://github.com/GTyingzi/spring-ai-tutorial 下的 advisor 目录
以下示例借助 advisor 实现了基于内存存储历史消息的 chat 交互
pom 文件
<dependencies> <dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter-web</artifactId> </dependency>
<dependency> <groupId>org.springframework.ai</groupId> <artifactId>spring-ai-autoconfigure-model-openai</artifactId> </dependency>
<dependency> <groupId>org.springframework.ai</groupId> <artifactId>spring-ai-autoconfigure-model-chat-client</artifactId> </dependency>
</dependencies>
application.yml
server:
  port: 8080

spring:
  application:
    name: advisor-base

  ai:
    openai:
      api-key: ${DASHSCOPEAPIKEY}
      base-url: https://dashscope.aliyuncs.com/compatible-mode
      chat:
        options:
          model: qwen-max
controller
MemoryMessageAdvisorController
package com.spring.ai.tutorial.advisor.controller;
import org.springframework.ai.chat.client.ChatClient;import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;import org.springframework.ai.chat.memory.InMemoryChatMemoryRepository;import org.springframework.ai.chat.memory.MessageWindowChatMemory;import org.springframework.ai.chat.messages.Message;import org.springframework.web.bind.annotation.GetMapping;import org.springframework.web.bind.annotation.RequestMapping;import org.springframework.web.bind.annotation.RequestParam;import org.springframework.web.bind.annotation.RestController;
import java.util.List;
import static org.springframework.ai.chat.memory.ChatMemory.CONVERSATION_ID;

/**
 * Demo controller: chat with history kept in an {@link InMemoryChatMemoryRepository} and
 * injected into each request by a {@link MessageChatMemoryAdvisor}.
 */
@RestController
@RequestMapping("/advisor/memory/message")
public class MemoryMessageAdvisorController {

    /** Value handed to {@code MessageWindowChatMemory.Builder#maxMessages}. */
    private static final int MAX_MESSAGES = 100;

    private final ChatClient chatClient;

    private final InMemoryChatMemoryRepository chatMemoryRepository = new InMemoryChatMemoryRepository();

    // Declared after chatMemoryRepository on purpose: field initializers run in declaration order.
    private final MessageWindowChatMemory messageWindowChatMemory = MessageWindowChatMemory.builder()
            .chatMemoryRepository(chatMemoryRepository)
            .maxMessages(MAX_MESSAGES)
            .build();

    /**
     * Builds a {@link ChatClient} whose default advisor is a {@link MessageChatMemoryAdvisor}
     * backed by {@link #messageWindowChatMemory}.
     *
     * @param builder the chat client builder to configure
     */
    public MemoryMessageAdvisorController(ChatClient.Builder builder) {
        this.chatClient = builder
                .defaultAdvisors(MessageChatMemoryAdvisor.builder(messageWindowChatMemory).build())
                .build();
    }

    /**
     * Sends {@code query} to the model, passing the conversation id to the advisors under
     * {@code ChatMemory.CONVERSATION_ID}.
     *
     * @param query          the user prompt
     * @param conversationId the conversation whose history is used
     * @return the content of the model response
     */
    @GetMapping("/call")
    public String call(
            @RequestParam(value = "query", defaultValue = "你好,我的外号是影子,请记住呀") String query,
            @RequestParam(value = "conversationid", defaultValue = "yingzi") String conversationId) {
        return chatClient.prompt(query)
                .advisors(a -> a.param(CONVERSATION_ID, conversationId))
                .call()
                .content();
    }

    /**
     * Returns what the chat memory currently holds for a conversation.
     *
     * @param conversationId the conversation to look up
     * @return the messages returned by {@code messageWindowChatMemory.get(conversationId)}
     */
    @GetMapping("/messages")
    public List<Message> messages(
            @RequestParam(value = "conversationid", defaultValue = "yingzi") String conversationId) {
        return messageWindowChatMemory.get(conversationId);
    }
}
效果
以会话 Id=“yingzi”,先告知模型我的名字
再以同一个会话 Id=“yingzi”提问,模型能根据以往的消息记住我的名字
调用 /messages 接口,即可获取该会话的历史消息记录
MemoryPromptAdvisorController
package com.spring.ai.tutorial.advisor.controller;
import org.springframework.ai.chat.client.ChatClient;import org.springframework.ai.chat.client.advisor.PromptChatMemoryAdvisor;import org.springframework.ai.chat.memory.InMemoryChatMemoryRepository;import org.springframework.ai.chat.memory.MessageWindowChatMemory;import org.springframework.ai.chat.messages.Message;import org.springframework.web.bind.annotation.GetMapping;import org.springframework.web.bind.annotation.RequestMapping;import org.springframework.web.bind.annotation.RequestParam;import org.springframework.web.bind.annotation.RestController;
import java.util.List;
import static org.springframework.ai.chat.memory.ChatMemory.CONVERSATION_ID;

/**
 * Demo controller: chat with history kept in an {@link InMemoryChatMemoryRepository} and
 * applied to each request by a {@link PromptChatMemoryAdvisor}.
 */
@RestController
@RequestMapping("/advisor/memory/prompt")
public class MemoryPromptAdvisorController {

    /** Value handed to {@code MessageWindowChatMemory.Builder#maxMessages}. */
    private static final int MAX_MESSAGES = 100;

    private final ChatClient chatClient;

    private final InMemoryChatMemoryRepository chatMemoryRepository = new InMemoryChatMemoryRepository();

    // Declared after chatMemoryRepository on purpose: field initializers run in declaration order.
    private final MessageWindowChatMemory messageWindowChatMemory = MessageWindowChatMemory.builder()
            .chatMemoryRepository(chatMemoryRepository)
            .maxMessages(MAX_MESSAGES)
            .build();

    /**
     * Builds a {@link ChatClient} whose default advisor is a {@link PromptChatMemoryAdvisor}
     * backed by {@link #messageWindowChatMemory}.
     *
     * @param builder the chat client builder to configure
     */
    public MemoryPromptAdvisorController(ChatClient.Builder builder) {
        this.chatClient = builder
                .defaultAdvisors(PromptChatMemoryAdvisor.builder(messageWindowChatMemory).build())
                .build();
    }

    /**
     * Sends {@code query} to the model, passing the conversation id to the advisors under
     * {@code ChatMemory.CONVERSATION_ID}.
     *
     * @param query          the user prompt
     * @param conversationId the conversation whose history is used
     * @return the content of the model response
     */
    @GetMapping("/call")
    public String call(
            @RequestParam(value = "query", defaultValue = "你好,我的外号是影子,请记住呀") String query,
            @RequestParam(value = "conversationid", defaultValue = "yingzi") String conversationId) {
        return chatClient.prompt(query)
                .advisors(a -> a.param(CONVERSATION_ID, conversationId))
                .call()
                .content();
    }

    /**
     * Returns what the chat memory currently holds for a conversation.
     *
     * @param conversationId the conversation to look up
     * @return the messages returned by {@code messageWindowChatMemory.get(conversationId)}
     */
    @GetMapping("/messages")
    public List<Message> messages(
            @RequestParam(value = "conversationid", defaultValue = "yingzi") String conversationId) {
        return messageWindowChatMemory.get(conversationId);
    }
}
效果
以会话 Id=“yingzi”,先告知模型我的名字
再以同一个会话 Id=“yingzi”提问,模型能根据以往的消息记住我的名字
调用 /messages 接口,即可获取该会话的历史消息记录
(增强)基于 sqlite、mysql、redis 的消息存储
实现了基于 sqlite、mysql、redis 的消息存储
pom 文件
<properties> <sqlite.version>3.49.1.0</sqlite.version> <mysql.version>8.0.32</mysql.version> <jedis.version>5.2.0</jedis.version></properties>
<dependencies> <dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter-web</artifactId> </dependency>
<dependency> <groupId>org.springframework.ai</groupId> <artifactId>spring-ai-autoconfigure-model-openai</artifactId> </dependency>
<dependency> <groupId>org.springframework.ai</groupId> <artifactId>spring-ai-autoconfigure-model-chat-client</artifactId> </dependency>
<dependency> <groupId>com.alibaba.cloud.ai</groupId> <artifactId>spring-ai-alibaba-starter-memory</artifactId> </dependency>
<dependency> <groupId>com.alibaba.cloud.ai</groupId> <artifactId>spring-ai-alibaba-starter-memory-jdbc</artifactId> </dependency>
<dependency> <groupId>com.alibaba.cloud.ai</groupId> <artifactId>spring-ai-alibaba-starter-memory-redis</artifactId> </dependency>
<dependency> <groupId>org.xerial</groupId> <artifactId>sqlite-jdbc</artifactId> <version>${sqlite.version}</version> </dependency>
<dependency> <groupId>mysql</groupId> <artifactId>mysql-connector-java</artifactId> <version>${mysql.version}</version> </dependency>
<dependency> <groupId>redis.clients</groupId> <artifactId>jedis</artifactId> <version>${jedis.version}</version> </dependency>
</dependencies>
application.yml
server:
  port: 8080

spring:
  application:
    name: advisor-memory-mysql

  ai:
    openai:
      api-key: ${DASHSCOPEAPIKEY}
      base-url: https://dashscope.aliyuncs.com/compatible-mode
      chat:
        options:
          model: qwen-max

    chat:
      memory:
        repository:
          jdbc:
            mysql:
              jdbc-url: jdbc:mysql://localhost:3306/spring_ai_alibaba_mysql?useUnicode=true&characterEncoding=utf-8&useSSL=false&allowPublicKeyRetrieval=true&zeroDateTimeBehavior=convertToNull&transformedBitIsBoolean=true&allowMultiQueries=true&tinyInt1isBit=false&allowLoadLocalInfile=true&allowLocalInfile=true&allowUrl
              username: root
              password: root
              driver-class-name: com.mysql.cj.jdbc.Driver
              enabled: true

    memory:
      redis:
        host: localhost
        port: 6379
        timeout: 5000
        password:
Sqlite
SqliteMemoryConfig
package com.spring.ai.tutorial.advisor.memory.config;
import com.alibaba.cloud.ai.memory.jdbc.SQLiteChatMemoryRepository;import org.springframework.context.annotation.Bean;import org.springframework.context.annotation.Configuration;import org.springframework.jdbc.core.JdbcTemplate;import org.springframework.jdbc.datasource.DriverManagerDataSource;
/**
 * Spring configuration that exposes a SQLite-backed chat memory repository bean.
 */
@Configuration
public class SqliteMemoryConfig {

    /**
     * Creates a {@link SQLiteChatMemoryRepository} on top of a {@link JdbcTemplate} bound to a
     * {@link DriverManagerDataSource} using the {@code org.sqlite.JDBC} driver.
     *
     * @return the repository bean
     */
    @Bean
    public SQLiteChatMemoryRepository sqliteChatMemoryRepository() {
        DriverManagerDataSource dataSource = new DriverManagerDataSource();
        dataSource.setDriverClassName("org.sqlite.JDBC");
        // NOTE(review): relative database path - presumably resolved against the JVM working
        // directory, so the app is expected to start from the repository root; confirm.
        dataSource.setUrl("jdbc:sqlite:advisor/advisor-memory-sqlite/src/main/resources/chat-memory.db");

        JdbcTemplate jdbcTemplate = new JdbcTemplate(dataSource);
        return SQLiteChatMemoryRepository.sqliteBuilder()
                .jdbcTemplate(jdbcTemplate)
                .build();
    }
}
SqliteMemoryController
package com.spring.ai.tutorial.advisor.memory.controller;
import com.alibaba.cloud.ai.memory.jdbc.SQLiteChatMemoryRepository;import org.springframework.ai.chat.client.ChatClient;import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;import org.springframework.ai.chat.memory.MessageWindowChatMemory;import org.springframework.ai.chat.messages.Message;import org.springframework.web.bind.annotation.GetMapping;import org.springframework.web.bind.annotation.RequestMapping;import org.springframework.web.bind.annotation.RequestParam;import org.springframework.web.bind.annotation.RestController;
import java.util.List;
import static org.springframework.ai.chat.memory.ChatMemory.CONVERSATION_ID;

/**
 * Demo controller: chat with history stored through the injected
 * {@link SQLiteChatMemoryRepository}.
 */
@RestController
@RequestMapping("/advisor/memory/sqlite")
public class SqliteMemoryController {

    /** Value handed to {@code MessageWindowChatMemory.Builder#maxMessages}. */
    private static final int MAX_MESSAGES = 100;

    private final ChatClient chatClient;

    private final MessageWindowChatMemory messageWindowChatMemory;

    /**
     * Wires the chat memory to the SQLite repository and registers a
     * {@link MessageChatMemoryAdvisor} as the client's default advisor.
     *
     * @param builder                    the chat client builder to configure
     * @param sqliteChatMemoryRepository the repository that persists the messages
     */
    public SqliteMemoryController(ChatClient.Builder builder,
            SQLiteChatMemoryRepository sqliteChatMemoryRepository) {
        this.messageWindowChatMemory = MessageWindowChatMemory.builder()
                .chatMemoryRepository(sqliteChatMemoryRepository)
                .maxMessages(MAX_MESSAGES)
                .build();

        this.chatClient = builder
                .defaultAdvisors(MessageChatMemoryAdvisor.builder(messageWindowChatMemory).build())
                .build();
    }

    /**
     * Sends {@code query} to the model, passing the conversation id to the advisors under
     * {@code ChatMemory.CONVERSATION_ID}.
     *
     * @param query          the user prompt
     * @param conversationId the conversation whose history is used
     * @return the content of the model response
     */
    @GetMapping("/call")
    public String call(
            @RequestParam(value = "query", defaultValue = "你好,我的外号是影子,请记住呀") String query,
            @RequestParam(value = "conversationid", defaultValue = "yingzi") String conversationId) {
        return chatClient.prompt(query)
                .advisors(a -> a.param(CONVERSATION_ID, conversationId))
                .call()
                .content();
    }

    /**
     * Returns what the chat memory currently holds for a conversation.
     *
     * @param conversationId the conversation to look up
     * @return the messages returned by {@code messageWindowChatMemory.get(conversationId)}
     */
    @GetMapping("/messages")
    public List<Message> messages(
            @RequestParam(value = "conversationid", defaultValue = "yingzi") String conversationId) {
        return messageWindowChatMemory.get(conversationId);
    }
}
效果
以会话“yingzi”发送消息,此时消息存储至 sqlite
从 sqlite 获取会话“yingzi”对应的消息
Mysql
MysqlMemoryConfig
package com.spring.ai.tutorial.advisor.memory.config;
import com.alibaba.cloud.ai.memory.jdbc.MysqlChatMemoryRepository;import org.springframework.beans.factory.annotation.Value;import org.springframework.context.annotation.Bean;import org.springframework.context.annotation.Configuration;import org.springframework.jdbc.core.JdbcTemplate;import org.springframework.jdbc.datasource.DriverManagerDataSource;
/**
 * Spring configuration that exposes a MySQL-backed chat memory repository bean.
 * Connection settings are read from {@code spring.ai.chat.memory.repository.jdbc.mysql.*}.
 */
@Configuration
public class MysqlMemoryConfig {

    @Value("${spring.ai.chat.memory.repository.jdbc.mysql.jdbc-url}")
    private String mysqlJdbcUrl;

    @Value("${spring.ai.chat.memory.repository.jdbc.mysql.username}")
    private String mysqlUsername;

    @Value("${spring.ai.chat.memory.repository.jdbc.mysql.password}")
    private String mysqlPassword;

    @Value("${spring.ai.chat.memory.repository.jdbc.mysql.driver-class-name}")
    private String mysqlDriverClassName;

    /**
     * Creates a {@link MysqlChatMemoryRepository} on top of a {@link JdbcTemplate} bound to a
     * {@link DriverManagerDataSource} configured from the injected properties.
     *
     * @return the repository bean
     */
    @Bean
    public MysqlChatMemoryRepository mysqlChatMemoryRepository() {
        DriverManagerDataSource dataSource = new DriverManagerDataSource();
        dataSource.setDriverClassName(mysqlDriverClassName);
        dataSource.setUrl(mysqlJdbcUrl);
        dataSource.setUsername(mysqlUsername);
        dataSource.setPassword(mysqlPassword);

        JdbcTemplate jdbcTemplate = new JdbcTemplate(dataSource);
        return MysqlChatMemoryRepository.mysqlBuilder()
                .jdbcTemplate(jdbcTemplate)
                .build();
    }
}
MysqlMemoryController
package com.spring.ai.tutorial.advisor.memory.controller;
import com.alibaba.cloud.ai.memory.jdbc.MysqlChatMemoryRepository;import org.springframework.ai.chat.client.ChatClient;import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;import org.springframework.ai.chat.memory.MessageWindowChatMemory;import org.springframework.ai.chat.messages.Message;import org.springframework.web.bind.annotation.GetMapping;import org.springframework.web.bind.annotation.RequestMapping;import org.springframework.web.bind.annotation.RequestParam;import org.springframework.web.bind.annotation.RestController;
import java.util.List;
import static org.springframework.ai.chat.memory.ChatMemory.CONVERSATION_ID;

/**
 * Demo controller: chat with history stored through the injected
 * {@link MysqlChatMemoryRepository}.
 */
@RestController
@RequestMapping("/advisor/memory/mysql")
public class MysqlMemoryController {

    /** Value handed to {@code MessageWindowChatMemory.Builder#maxMessages}. */
    private static final int MAX_MESSAGES = 100;

    private final ChatClient chatClient;

    private final MessageWindowChatMemory messageWindowChatMemory;

    /**
     * Wires the chat memory to the MySQL repository and registers a
     * {@link MessageChatMemoryAdvisor} as the client's default advisor.
     *
     * @param builder                   the chat client builder to configure
     * @param mysqlChatMemoryRepository the repository that persists the messages
     */
    public MysqlMemoryController(ChatClient.Builder builder,
            MysqlChatMemoryRepository mysqlChatMemoryRepository) {
        this.messageWindowChatMemory = MessageWindowChatMemory.builder()
                .chatMemoryRepository(mysqlChatMemoryRepository)
                .maxMessages(MAX_MESSAGES)
                .build();

        this.chatClient = builder
                .defaultAdvisors(MessageChatMemoryAdvisor.builder(messageWindowChatMemory).build())
                .build();
    }

    /**
     * Sends {@code query} to the model, passing the conversation id to the advisors under
     * {@code ChatMemory.CONVERSATION_ID}.
     *
     * @param query          the user prompt
     * @param conversationId the conversation whose history is used
     * @return the content of the model response
     */
    @GetMapping("/call")
    public String call(
            @RequestParam(value = "query", defaultValue = "你好,我的外号是影子,请记住呀") String query,
            @RequestParam(value = "conversationid", defaultValue = "yingzi") String conversationId) {
        return chatClient.prompt(query)
                .advisors(a -> a.param(CONVERSATION_ID, conversationId))
                .call()
                .content();
    }

    /**
     * Returns what the chat memory currently holds for a conversation.
     *
     * @param conversationId the conversation to look up
     * @return the messages returned by {@code messageWindowChatMemory.get(conversationId)}
     */
    @GetMapping("/messages")
    public List<Message> messages(
            @RequestParam(value = "conversationid", defaultValue = "yingzi") String conversationId) {
        return messageWindowChatMemory.get(conversationId);
    }
}
效果
以会话“yingzi”发送消息,此时消息存储至 mysql
消息被存储至 mysql 中
从 mysql 获取会话“yingzi”对应的消息
Redis
RedisMemoryConfig
package com.spring.ai.tutorial.advisor.memory.config;
import com.alibaba.cloud.ai.memory.redis.RedisChatMemoryRepository;import org.springframework.beans.factory.annotation.Value;import org.springframework.context.annotation.Bean;import org.springframework.context.annotation.Configuration;
@Configurationpublic class RedisMemoryConfig {
@Value("${spring.ai.memory.redis.host}") private String redisHost; @Value("${spring.ai.memory.redis.port}") private int redisPort; @Value("${spring.ai.memory.redis.password}") private String redisPassword; @Value("${spring.ai.memory.redis.timeout}") private int redisTimeout;
@Bean public RedisChatMemoryRepository redisChatMemoryRepository() { return RedisChatMemoryRepository.builder() .host(redisHost) .port(redisPort) // 若没有设置密码则注释该项// .password(redisPassword) .timeout(redisTimeout) .build(); }}
RedisMemoryController
package com.spring.ai.tutorial.advisor.memory.controller;
import com.alibaba.cloud.ai.memory.redis.RedisChatMemoryRepository;import org.springframework.ai.chat.client.ChatClient;import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;import org.springframework.ai.chat.memory.MessageWindowChatMemory;import org.springframework.ai.chat.messages.Message;import org.springframework.web.bind.annotation.GetMapping;import org.springframework.web.bind.annotation.RequestMapping;import org.springframework.web.bind.annotation.RequestParam;import org.springframework.web.bind.annotation.RestController;
import java.util.List;
import static org.springframework.ai.chat.memory.ChatMemory.CONVERSATION_ID;

/**
 * Demo controller: chat with history stored through the injected
 * {@link RedisChatMemoryRepository}.
 */
@RestController
@RequestMapping("/advisor/memory/redis")
public class RedisMemoryController {

    /** Value handed to {@code MessageWindowChatMemory.Builder#maxMessages}. */
    private static final int MAX_MESSAGES = 100;

    private final ChatClient chatClient;

    private final MessageWindowChatMemory messageWindowChatMemory;

    /**
     * Wires the chat memory to the Redis repository and registers a
     * {@link MessageChatMemoryAdvisor} as the client's default advisor.
     *
     * @param builder                   the chat client builder to configure
     * @param redisChatMemoryRepository the repository that persists the messages
     */
    public RedisMemoryController(ChatClient.Builder builder,
            RedisChatMemoryRepository redisChatMemoryRepository) {
        this.messageWindowChatMemory = MessageWindowChatMemory.builder()
                .chatMemoryRepository(redisChatMemoryRepository)
                .maxMessages(MAX_MESSAGES)
                .build();

        this.chatClient = builder
                .defaultAdvisors(MessageChatMemoryAdvisor.builder(messageWindowChatMemory).build())
                .build();
    }

    /**
     * Sends {@code query} to the model, passing the conversation id to the advisors under
     * {@code ChatMemory.CONVERSATION_ID}.
     *
     * @param query          the user prompt
     * @param conversationId the conversation whose history is used
     * @return the content of the model response
     */
    @GetMapping("/call")
    public String call(
            @RequestParam(value = "query", defaultValue = "你好,我的外号是影子,请记住呀") String query,
            @RequestParam(value = "conversationid", defaultValue = "yingzi") String conversationId) {
        return chatClient.prompt(query)
                .advisors(a -> a.param(CONVERSATION_ID, conversationId))
                .call()
                .content();
    }

    /**
     * Returns what the chat memory currently holds for a conversation.
     *
     * @param conversationId the conversation to look up
     * @return the messages returned by {@code messageWindowChatMemory.get(conversationId)}
     */
    @GetMapping("/messages")
    public List<Message> messages(
            @RequestParam(value = "conversationid", defaultValue = "yingzi") String conversationId) {
        return messageWindowChatMemory.get(conversationId);
    }
}
效果
以会话“yingzi”发送消息,此时消息存储至 redis
消息被存储至 redis 中
从 redis 获取会话“yingzi”对应的消息
Advisor 基础
基础提供了 SafeGuardAdvisor、SimpleLoggerAdvisor、ChatModelCallAdvisor、ChatModelStreamAdvisor、基于 BaseChatMemoryAdvisor 扩展的记忆功能
架构图
Advisor
advisor 基础信息配置
- name:指定名字,确保唯一性
- order:数值越小,执行越靠前
package org.springframework.ai.chat.client.advisor.api;
import org.springframework.core.Ordered;
/**
 * Base contract of every advisor: a name plus the ordering inherited from {@link Ordered}.
 */
public interface Advisor extends Ordered {

    /**
     * Precedence constant for chat-memory advisors: 1000 above the highest precedence,
     * i.e. {@code Integer.MIN_VALUE + 1000} (-2147482648).
     */
    int DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER = Ordered.HIGHEST_PRECEDENCE + 1000;

    /**
     * @return the advisor's name
     */
    String getName();
}
package org.springframework.core;
/**
 * Ordering contract: implementations expose an integer order value.
 */
public interface Ordered {

    /** Smallest possible order value. */
    int HIGHEST_PRECEDENCE = Integer.MIN_VALUE;

    /** Largest possible order value. */
    int LOWEST_PRECEDENCE = Integer.MAX_VALUE;

    /**
     * @return the order value of this object
     */
    int getOrder();
}
CallAdvisor
call 调用,跟 AI 模型交互前、后的一些逻辑
package org.springframework.ai.chat.client.advisor.api;
import org.springframework.ai.chat.client.ChatClientRequest;import org.springframework.ai.chat.client.ChatClientResponse;
/**
 * Advisor hook for synchronous ({@code call}) interactions.
 */
public interface CallAdvisor extends Advisor {

    /**
     * Handles a call request, with access to the rest of the advisor chain.
     *
     * @param chatClientRequest the incoming request
     * @param callAdvisorChain  the chain this advisor is part of
     * @return the response for the request
     */
    ChatClientResponse adviseCall(
            ChatClientRequest chatClientRequest,
            CallAdvisorChain callAdvisorChain);
}
StreamAdvisor
Stream 调用,跟 AI 模型交互前、后的一些逻辑
package org.springframework.ai.chat.client.advisor.api;
import org.springframework.ai.chat.client.ChatClientRequest;import org.springframework.ai.chat.client.ChatClientResponse;import reactor.core.publisher.Flux;
/**
 * Advisor hook for streaming ({@code stream}) interactions.
 */
public interface StreamAdvisor extends Advisor {

    /**
     * Handles a streaming request, with access to the rest of the advisor chain.
     *
     * @param chatClientRequest  the incoming request
     * @param streamAdvisorChain the chain this advisor is part of
     * @return the stream of responses for the request
     */
    Flux<ChatClientResponse> adviseStream(
            ChatClientRequest chatClientRequest,
            StreamAdvisorChain streamAdvisorChain);
}
BaseAdvisor
类说明:继承 CallAdvisor、StreamAdvisor,提供统一扩展点,通过拦截机制实现 AI 模型请求前和响应后的统一交互逻辑。
字段说明
字段名称 | 类型 | 描述 |
DEFAULT_SCHEDULER | Scheduler | 定义默认调度器 Schedulers.boundedElastic(),用于流式处理时的线程调度 |
方法说明
方法名称 | 描述 |
adviseCall | 同步调用,拦截call调用AI模型的请求和响应。子类实现before、after方法 |
adviseStream | 流式调用,拦截stream调用AI模型的请求和响应。子类实现before、after方法 |
before | AI模型请求前的逻辑,需要子类实现 |
after | AI模型响应后的逻辑,需要子类实现 |
package org.springframework.ai.chat.client.advisor.api;
import java.util.Objects;import org.springframework.ai.chat.client.ChatClientRequest;import org.springframework.ai.chat.client.ChatClientResponse;import org.springframework.ai.chat.client.advisor.AdvisorUtils;import org.springframework.util.Assert;import reactor.core.publisher.Flux;import reactor.core.publisher.Mono;import reactor.core.scheduler.Scheduler;import reactor.core.scheduler.Schedulers;
public interface BaseAdvisor extends CallAdvisor, StreamAdvisor { Scheduler DEFAULTSCHEDULER = Schedulers.boundedElastic();
default ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) { Assert.notNull(chatClientRequest, "chatClientRequest cannot be null"); Assert.notNull(callAdvisorChain, "callAdvisorChain cannot be null"); ChatClientRequest processedChatClientRequest = this.before(chatClientRequest, callAdvisorChain); ChatClientResponse chatClientResponse = callAdvisorChain.nextCall(processedChatClientRequest); return this.after(chatClientResponse, callAdvisorChain); }
default Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest, StreamAdvisorChain streamAdvisorChain) { Assert.notNull(chatClientRequest, "chatClientRequest cannot be null"); Assert.notNull(streamAdvisorChain, "streamAdvisorChain cannot be null"); Assert.notNull(this.getScheduler(), "scheduler cannot be null"); Mono var10000 = Mono.just(chatClientRequest).publishOn(this.getScheduler()).map((request) -> this.before(request, streamAdvisorChain)); Objects.requireNonNull(streamAdvisorChain); Flux<ChatClientResponse> chatClientResponseFlux = var10000.flatMapMany(streamAdvisorChain::nextStream); return chatClientResponseFlux.map((response) -> { if (AdvisorUtils.onFinishReason().test(response)) { response = this.after(response, streamAdvisorChain); }
return response; }).onErrorResume((error) -> Flux.error(new IllegalStateException("Stream processing failed", error))); }
default String getName() { return this.getClass().getSimpleName(); }
ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChain advisorChain);
ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain);
default Scheduler getScheduler() { return DEFAULTSCHEDULER; }}
SafeGuardAdvisor
类的作用:在用户输入中检测敏感词,并在发现敏感词时阻止调用模型并返回预设的失败响应
敏感词匹配规则:事先设置一组敏感词列表,校验提示词中是否包含其中任意一个敏感词
// Excerpt: when the sensitive-word list is non-empty and the prompt contents contain any entry,
// return the failure response built by createFailureResponse instead of going further.
if (!CollectionUtils.isEmpty(this.sensitiveWords) && this.sensitiveWords.stream().anyMatch(w -> chatClientRequest.prompt().getContents().contains(w))) { return createFailureResponse(chatClientRequest);}
package org.springframework.ai.chat.client.advisor;
import java.util.List;import java.util.Map;import org.springframework.ai.chat.client.ChatClientRequest;import org.springframework.ai.chat.client.ChatClientResponse;import org.springframework.ai.chat.client.advisor.api.CallAdvisor;import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain;import org.springframework.ai.chat.client.advisor.api.StreamAdvisor;import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;import org.springframework.ai.chat.messages.AssistantMessage;import org.springframework.ai.chat.model.ChatResponse;import org.springframework.ai.chat.model.Generation;import org.springframework.util.Assert;import org.springframework.util.CollectionUtils;import reactor.core.publisher.Flux;
public class SafeGuardAdvisor implements CallAdvisor, StreamAdvisor { private static final String DEFAULTFAILURERESPONSE = "I'm unable to respond to that due to sensitive content. Could we rephrase or discuss something else?"; private static final int DEFAULTORDER = 0; private final String failureResponse; private final List<String> sensitiveWords; private final int order;
public SafeGuardAdvisor(List<String> sensitiveWords) { this(sensitiveWords, "I'm unable to respond to that due to sensitive content. Could we rephrase or discuss something else?", 0); }
public SafeGuardAdvisor(List<String> sensitiveWords, String failureResponse, int order) { Assert.notNull(sensitiveWords, "Sensitive words must not be null!"); Assert.notNull(failureResponse, "Failure response must not be null!"); this.sensitiveWords = sensitiveWords; this.failureResponse = failureResponse; this.order = order; }
public static Builder builder() { return new Builder(); }
public String getName() { return this.getClass().getSimpleName(); }
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) { return !CollectionUtils.isEmpty(this.sensitiveWords) && this.sensitiveWords.stream().anyMatch((w) -> chatClientRequest.prompt().getContents().contains(w)) ? this.createFailureResponse(chatClientRequest) : callAdvisorChain.nextCall(chatClientRequest); }
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest, StreamAdvisorChain streamAdvisorChain) { return !CollectionUtils.isEmpty(this.sensitiveWords) && this.sensitiveWords.stream().anyMatch((w) -> chatClientRequest.prompt().getContents().contains(w)) ? Flux.just(this.createFailureResponse(chatClientRequest)) : streamAdvisorChain.nextStream(chatClientRequest); }
private ChatClientResponse createFailureResponse(ChatClientRequest chatClientRequest) { return ChatClientResponse.builder().chatResponse(ChatResponse.builder().generations(List.of(new Generation(new AssistantMessage(this.failureResponse)))).build()).context(Map.copyOf(chatClientRequest.context())).build(); }
public int getOrder() { return this.order; }
public static final class Builder { private List<String> sensitiveWords; private String failureResponse = "I'm unable to respond to that due to sensitive content. Could we rephrase or discuss something else?"; private int order = 0;
private Builder() { }
public Builder sensitiveWords(List<String> sensitiveWords) { this.sensitiveWords = sensitiveWords; return this; }
public Builder failureResponse(String failureResponse) { this.failureResponse = failureResponse; return this; }
public Builder order(int order) { this.order = order; return this; }
public SafeGuardAdvisor build() { return new SafeGuardAdvisor(this.sensitiveWords, this.failureResponse, this.order); } }}
SimpleLoggerAdvisor
类的作用:主要用于日志记录,打印请求、响应等信息,默认 JSON 格式化输出
package org.springframework.ai.chat.client.advisor;
import java.util.function.Function;import org.slf4j.Logger;import org.slf4j.LoggerFactory;import org.springframework.ai.chat.client.ChatClientMessageAggregator;import org.springframework.ai.chat.client.ChatClientRequest;import org.springframework.ai.chat.client.ChatClientResponse;import org.springframework.ai.chat.client.advisor.api.CallAdvisor;import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain;import org.springframework.ai.chat.client.advisor.api.StreamAdvisor;import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;import org.springframework.ai.chat.model.ChatResponse;import org.springframework.ai.model.ModelOptionsUtils;import org.springframework.lang.Nullable;import reactor.core.publisher.Flux;
public class SimpleLoggerAdvisor implements CallAdvisor, StreamAdvisor { public static final Function<ChatClientRequest, String> DEFAULTREQUESTTOSTRING = ChatClientRequest::toString; public static final Function<ChatResponse, String> DEFAULTRESPONSETOSTRING = ModelOptionsUtils::toJsonStringPrettyPrinter; private static final Logger logger = LoggerFactory.getLogger(SimpleLoggerAdvisor.class); private final Function<ChatClientRequest, String> requestToString; private final Function<ChatResponse, String> responseToString; private final int order;
public SimpleLoggerAdvisor() { this(DEFAULTREQUESTTOSTRING, DEFAULTRESPONSETOSTRING, 0); }
public SimpleLoggerAdvisor(int order) { this(DEFAULTREQUESTTOSTRING, DEFAULTRESPONSETOSTRING, order); }
public SimpleLoggerAdvisor(@Nullable Function<ChatClientRequest, String> requestToString, @Nullable Function<ChatResponse, String> responseToString, int order) { this.requestToString = requestToString != null ? requestToString : DEFAULTREQUESTTOSTRING; this.responseToString = responseToString != null ? responseToString : DEFAULTRESPONSETOSTRING; this.order = order; }
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) { this.logRequest(chatClientRequest); ChatClientResponse chatClientResponse = callAdvisorChain.nextCall(chatClientRequest); this.logResponse(chatClientResponse); return chatClientResponse; }
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest, StreamAdvisorChain streamAdvisorChain) { this.logRequest(chatClientRequest); Flux<ChatClientResponse> chatClientResponses = streamAdvisorChain.nextStream(chatClientRequest); return (new ChatClientMessageAggregator()).aggregateChatClientResponse(chatClientResponses, this::logResponse); }
private void logRequest(ChatClientRequest request) { logger.debug("request: {}", this.requestToString.apply(request)); }
private void logResponse(ChatClientResponse chatClientResponse) { logger.debug("response: {}", this.responseToString.apply(chatClientResponse.chatResponse())); }
public String getName() { return this.getClass().getSimpleName(); }
public int getOrder() { return this.order; }
public String toString() { return SimpleLoggerAdvisor.class.getSimpleName(); }
public static Builder builder() { return new Builder(); }
public static final class Builder { private Function<ChatClientRequest, String> requestToString; private Function<ChatResponse, String> responseToString; private int order = 0;
private Builder() { }
public Builder requestToString(Function<ChatClientRequest, String> requestToString) { this.requestToString = requestToString; return this; }
public Builder responseToString(Function<ChatResponse, String> responseToString) { this.responseToString = responseToString; return this; }
public Builder order(int order) { this.order = order; return this; }
public SimpleLoggerAdvisor build() { return new SimpleLoggerAdvisor(this.requestToString, this.responseToString, this.order); } }}
ChatModelCallAdvisor
类的作用:使用注入的 ChatModel 实例执行 AI 模型 call 调用;若上下文中包含 OUTPUT_FORMAT,会将其附加到用户提示中,以指导模型生成符合预期格式的内容。通常作为增强器链的最后一个
package org.springframework.ai.chat.client.advisor;
import java.util.Map;import org.springframework.ai.chat.client.ChatClientAttributes;import org.springframework.ai.chat.client.ChatClientRequest;import org.springframework.ai.chat.client.ChatClientResponse;import org.springframework.ai.chat.client.advisor.api.CallAdvisor;import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain;import org.springframework.ai.chat.messages.UserMessage;import org.springframework.ai.chat.model.ChatModel;import org.springframework.ai.chat.model.ChatResponse;import org.springframework.ai.chat.prompt.Prompt;import org.springframework.util.Assert;import org.springframework.util.StringUtils;
public final class ChatModelCallAdvisor implements CallAdvisor { private final ChatModel chatModel;
private ChatModelCallAdvisor(ChatModel chatModel) { Assert.notNull(chatModel, "chatModel cannot be null"); this.chatModel = chatModel; }
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) { Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null"); ChatClientRequest formattedChatClientRequest = augmentWithFormatInstructions(chatClientRequest); ChatResponse chatResponse = this.chatModel.call(formattedChatClientRequest.prompt()); return ChatClientResponse.builder().chatResponse(chatResponse).context(Map.copyOf(formattedChatClientRequest.context())).build(); }
private static ChatClientRequest augmentWithFormatInstructions(ChatClientRequest chatClientRequest) { String outputFormat = (String)chatClientRequest.context().get(ChatClientAttributes.OUTPUTFORMAT.getKey()); if (!StringUtils.hasText(outputFormat)) { return chatClientRequest; } else { Prompt augmentedPrompt = chatClientRequest.prompt().augmentUserMessage((userMessage) -> { UserMessage.Builder var10000 = userMessage.mutate(); String var10001 = userMessage.getText(); return var10000.text(var10001 + System.lineSeparator() + outputFormat).build(); }); return ChatClientRequest.builder().prompt(augmentedPrompt).context(Map.copyOf(chatClientRequest.context())).build(); } }
public String getName() { return "call"; }
public int getOrder() { return Integer.MAXVALUE; }
public static Builder builder() { return new Builder(); }
public static final class Builder { private ChatModel chatModel;
private Builder() { }
public Builder chatModel(ChatModel chatModel) { this.chatModel = chatModel; return this; }
public ChatModelCallAdvisor build() { return new ChatModelCallAdvisor(this.chatModel); } }}
ChatModelStreamAdvisor
类的作用:使用注入的 ChatModel 实例执行 AI 模型 Stream 调用,将模型返回的 Flux<ChatResponse> 逐条转换为 Flux<ChatClientResponse>
- 默认使用 Schedulers.boundedElastic() 进行线程切换,以避免阻塞主线程或影响响应性
package org.springframework.ai.chat.client.advisor;
import java.util.Map;import org.springframework.ai.chat.client.ChatClientRequest;import org.springframework.ai.chat.client.ChatClientResponse;import org.springframework.ai.chat.client.advisor.api.StreamAdvisor;import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;import org.springframework.ai.chat.model.ChatModel;import org.springframework.util.Assert;import reactor.core.publisher.Flux;import reactor.core.scheduler.Schedulers;
public final class ChatModelStreamAdvisor implements StreamAdvisor { private final ChatModel chatModel;
private ChatModelStreamAdvisor(ChatModel chatModel) { Assert.notNull(chatModel, "chatModel cannot be null"); this.chatModel = chatModel; }
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest, StreamAdvisorChain streamAdvisorChain) { Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null"); return this.chatModel.stream(chatClientRequest.prompt()).map((chatResponse) -> ChatClientResponse.builder().chatResponse(chatResponse).context(Map.copyOf(chatClientRequest.context())).build()).publishOn(Schedulers.boundedElastic()); }
public String getName() { return "stream"; }
public int getOrder() { return Integer.MAXVALUE; }
public static Builder builder() { return new Builder(); }
public static final class Builder { private ChatModel chatModel;
private Builder() { }
public Builder chatModel(ChatModel chatModel) { this.chatModel = chatModel; return this; }
public ChatModelStreamAdvisor build() { return new ChatModelStreamAdvisor(this.chatModel); } }}
BaseChatMemoryAdvisor 解读篇
BaseChatMemoryAdvisor
类说明:从传入的上下文 Map 中提取会话 Id,若不存在则使用默认值
package org.springframework.ai.chat.client.advisor.api;
import java.util.Map;import org.springframework.util.Assert;
/**
 * Contract shared by the chat-memory advisors: adds conversation-id resolution on top of
 * {@link BaseAdvisor}.
 */
public interface BaseChatMemoryAdvisor extends BaseAdvisor {

    /**
     * Resolves the conversation id for the current exchange.
     *
     * @param context the advisor context; must not be {@code null} nor contain {@code null} keys
     * @param defaultConversationId fallback id; must contain text
     * @return the string form of the context value stored under the conversation-id key when
     *         that key is present, otherwise {@code defaultConversationId}
     */
    default String getConversationId(Map<String, Object> context, String defaultConversationId) {
        Assert.notNull(context, "context cannot be null");
        Assert.noNullElements(context.keySet().toArray(), "context cannot contain null keys");
        Assert.hasText(defaultConversationId, "defaultConversationId cannot be null or empty");

        // NOTE(review): upstream Spring AI appears to spell this key "chat_memory_conversation_id";
        // the underscores may have been lost in this listing - confirm before relying on the literal.
        if (context.containsKey("chatmemoryconversationid")) {
            return context.get("chatmemoryconversationid").toString();
        }
        return defaultConversationId;
    }

}
MessageChatMemoryAdvisor
类的说明:消息记忆存储的 Advisor 类
字段说明
字段名称 | 类型 | 描述 |
order | int | 指定顺序 |
defaultConversationId | String | 默认会话Id |
chatMemory | ChatMemory | 聊天记忆接口 |
scheduler | Scheduler | 流式处理时的线程调度 |
方法说明
方法名称 | 描述 |
before | 1. 从ChatMemory中取出历史消息 2. 当前消息加入ChatMemory 3. 整合当前消息+历史消息 |
after | 将模型响应的消息加入ChatMemory |
adviseStream | 覆盖了BaseAdvisor默认实现逻辑 - 注:先将多个流式响应合并成一个完整响应对象,再调用 after,确保只保留完整的模型输出,避免部分信息写入 memory 导致混乱 |
package org.springframework.ai.chat.client.advisor;
import java.util.ArrayList;import java.util.List;import java.util.Objects;import org.springframework.ai.chat.client.ChatClientMessageAggregator;import org.springframework.ai.chat.client.ChatClientRequest;import org.springframework.ai.chat.client.ChatClientResponse;import org.springframework.ai.chat.client.advisor.api.AdvisorChain;import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;import org.springframework.ai.chat.client.advisor.api.BaseChatMemoryAdvisor;import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;import org.springframework.ai.chat.memory.ChatMemory;import org.springframework.ai.chat.messages.Message;import org.springframework.ai.chat.messages.UserMessage;import org.springframework.util.Assert;import reactor.core.publisher.Flux;import reactor.core.publisher.Mono;import reactor.core.scheduler.Scheduler;
public final class MessageChatMemoryAdvisor implements BaseChatMemoryAdvisor { private final ChatMemory chatMemory; private final String defaultConversationId; private final int order; private final Scheduler scheduler;
private MessageChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversationId, int order, Scheduler scheduler) { Assert.notNull(chatMemory, "chatMemory cannot be null"); Assert.hasText(defaultConversationId, "defaultConversationId cannot be null or empty"); Assert.notNull(scheduler, "scheduler cannot be null"); this.chatMemory = chatMemory; this.defaultConversationId = defaultConversationId; this.order = order; this.scheduler = scheduler; }
public int getOrder() { return this.order; }
public Scheduler getScheduler() { return this.scheduler; }
public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChain advisorChain) { String conversationId = this.getConversationId(chatClientRequest.context(), this.defaultConversationId); List<Message> memoryMessages = this.chatMemory.get(conversationId); List<Message> processedMessages = new ArrayList(memoryMessages); processedMessages.addAll(chatClientRequest.prompt().getInstructions()); ChatClientRequest processedChatClientRequest = chatClientRequest.mutate().prompt(chatClientRequest.prompt().mutate().messages(processedMessages).build()).build(); UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage(); this.chatMemory.add(conversationId, userMessage); return processedChatClientRequest; }
public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain) { List<Message> assistantMessages = new ArrayList(); if (chatClientResponse.chatResponse() != null) { assistantMessages = chatClientResponse.chatResponse().getResults().stream().map((g) -> g.getOutput()).toList(); }
this.chatMemory.add(this.getConversationId(chatClientResponse.context(), this.defaultConversationId), assistantMessages); return chatClientResponse; }
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest, StreamAdvisorChain streamAdvisorChain) { Scheduler scheduler = this.getScheduler(); Mono var10000 = Mono.just(chatClientRequest).publishOn(scheduler).map((request) -> this.before(request, streamAdvisorChain)); Objects.requireNonNull(streamAdvisorChain); return var10000.flatMapMany(streamAdvisorChain::nextStream).transform((flux) -> (new ChatClientMessageAggregator()).aggregateChatClientResponse(flux, (response) -> this.after(response, streamAdvisorChain))); }
public static Builder builder(ChatMemory chatMemory) { return new Builder(chatMemory); }
public static final class Builder { private String conversationId = "default"; private int order = -2147482648; private Scheduler scheduler; private ChatMemory chatMemory;
private Builder(ChatMemory chatMemory) { this.scheduler = BaseAdvisor.DEFAULTSCHEDULER; this.chatMemory = chatMemory; }
public Builder conversationId(String conversationId) { this.conversationId = conversationId; return this; }
public Builder order(int order) { this.order = order; return this; }
public Builder scheduler(Scheduler scheduler) { this.scheduler = scheduler; return this; }
public MessageChatMemoryAdvisor build() { return new MessageChatMemoryAdvisor(this.chatMemory, this.conversationId, this.order, this.scheduler); } }}
PromptChatMemoryAdvisor
类的说明:将聊天记忆嵌入到系统提示词的 Advisor 类
字段说明
字段名称 | 类型 | 描述 |
order | int | 指定顺序 |
defaultConversationId | String | 默认会话Id |
chatMemory | ChatMemory | 聊天记忆接口 |
scheduler | Scheduler | 流式处理时的线程调度 |
systemPromptTemplate | PromptTemplate | 当前使用的系统提示模板 |
方法说明
方法名称 | 描述 |
before | 1. 从ChatMemory中取出历史消息 2. 当前消息加入ChatMemory 3. 将历史消息结合系统提示消息作为最新的系统提示 |
after | 将模型响应的消息加入ChatMemory |
adviseStream | 覆盖了BaseAdvisor默认实现逻辑 - 注:先将多个流式响应合并成一个完整响应对象,再调用 after,确保只保留完整的模型输出,避免部分信息写入 memory 导致混乱 |
package org.springframework.ai.chat.client.advisor;
import java.util.ArrayList;import java.util.List;import java.util.Map;import java.util.Objects;import java.util.stream.Collectors;import org.slf4j.Logger;import org.slf4j.LoggerFactory;import org.springframework.ai.chat.client.ChatClientMessageAggregator;import org.springframework.ai.chat.client.ChatClientRequest;import org.springframework.ai.chat.client.ChatClientResponse;import org.springframework.ai.chat.client.advisor.api.AdvisorChain;import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;import org.springframework.ai.chat.client.advisor.api.BaseChatMemoryAdvisor;import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;import org.springframework.ai.chat.memory.ChatMemory;import org.springframework.ai.chat.messages.Message;import org.springframework.ai.chat.messages.MessageType;import org.springframework.ai.chat.messages.SystemMessage;import org.springframework.ai.chat.messages.UserMessage;import org.springframework.ai.chat.prompt.PromptTemplate;import org.springframework.util.Assert;import reactor.core.publisher.Flux;import reactor.core.publisher.Mono;import reactor.core.scheduler.Scheduler;
public final class PromptChatMemoryAdvisor implements BaseChatMemoryAdvisor { private static final Logger logger = LoggerFactory.getLogger(PromptChatMemoryAdvisor.class); private static final PromptTemplate DEFAULTSYSTEMPROMPTTEMPLATE = new PromptTemplate("{instructions}\n\nUse the conversation memory from the MEMORY section to provide accurate answers.\n\n---------------------\nMEMORY:\n{memory}\n---------------------\n\n"); private final PromptTemplate systemPromptTemplate; private final String defaultConversationId; private final int order; private final Scheduler scheduler; private final ChatMemory chatMemory;
private PromptChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversationId, int order, Scheduler scheduler, PromptTemplate systemPromptTemplate) { Assert.notNull(chatMemory, "chatMemory cannot be null"); Assert.hasText(defaultConversationId, "defaultConversationId cannot be null or empty"); Assert.notNull(scheduler, "scheduler cannot be null"); Assert.notNull(systemPromptTemplate, "systemPromptTemplate cannot be null"); this.chatMemory = chatMemory; this.defaultConversationId = defaultConversationId; this.order = order; this.scheduler = scheduler; this.systemPromptTemplate = systemPromptTemplate; }
public static Builder builder(ChatMemory chatMemory) { return new Builder(chatMemory); }
public int getOrder() { return this.order; }
public Scheduler getScheduler() { return this.scheduler; }
public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChain advisorChain) { String conversationId = this.getConversationId(chatClientRequest.context(), this.defaultConversationId); List<Message> memoryMessages = this.chatMemory.get(conversationId); logger.debug("[PromptChatMemoryAdvisor.before] Memory before processing for conversationId={}: {}", conversationId, memoryMessages); String memory = (String)memoryMessages.stream().filter((m) -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT).map((m) -> { String var10000 = String.valueOf(m.getMessageType()); return var10000 + ":" + m.getText(); }).collect(Collectors.joining(System.lineSeparator())); SystemMessage systemMessage = chatClientRequest.prompt().getSystemMessage(); String augmentedSystemText = this.systemPromptTemplate.render(Map.of("instructions", systemMessage.getText(), "memory", memory)); ChatClientRequest processedChatClientRequest = chatClientRequest.mutate().prompt(chatClientRequest.prompt().augmentSystemMessage(augmentedSystemText)).build(); UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage(); this.chatMemory.add(conversationId, userMessage); return processedChatClientRequest; }
public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain) { List<Message> assistantMessages = new ArrayList(); if (chatClientResponse.chatResponse() != null) { assistantMessages = chatClientResponse.chatResponse().getResults().stream().map((g) -> g.getOutput()).toList(); } else if (chatClientResponse.chatResponse() != null && chatClientResponse.chatResponse().getResult() != null && chatClientResponse.chatResponse().getResult().getOutput() != null) { assistantMessages = List.of(chatClientResponse.chatResponse().getResult().getOutput()); }
if (!assistantMessages.isEmpty()) { this.chatMemory.add(this.getConversationId(chatClientResponse.context(), this.defaultConversationId), assistantMessages); logger.debug("[PromptChatMemoryAdvisor.after] Added ASSISTANT messages to memory for conversationId={}: {}", this.getConversationId(chatClientResponse.context(), this.defaultConversationId), assistantMessages); List<Message> memoryMessages = this.chatMemory.get(this.getConversationId(chatClientResponse.context(), this.defaultConversationId)); logger.debug("[PromptChatMemoryAdvisor.after] Memory after ASSISTANT add for conversationId={}: {}", this.getConversationId(chatClientResponse.context(), this.defaultConversationId), memoryMessages); }
return chatClientResponse; }
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest, StreamAdvisorChain streamAdvisorChain) { Scheduler scheduler = this.getScheduler(); Mono var10000 = Mono.just(chatClientRequest).publishOn(scheduler).map((request) -> this.before(request, streamAdvisorChain)); Objects.requireNonNull(streamAdvisorChain); return var10000.flatMapMany(streamAdvisorChain::nextStream).transform((flux) -> (new ChatClientMessageAggregator()).aggregateChatClientResponse(flux, (response) -> this.after(response, streamAdvisorChain))); }
public static final class Builder { private PromptTemplate systemPromptTemplate; private String conversationId; private int order; private Scheduler scheduler; private ChatMemory chatMemory;
private Builder(ChatMemory chatMemory) { this.systemPromptTemplate = PromptChatMemoryAdvisor.DEFAULTSYSTEMPROMPTTEMPLATE; this.conversationId = "default"; this.order = -2147482648; this.scheduler = BaseAdvisor.DEFAULTSCHEDULER; this.chatMemory = chatMemory; }
public Builder systemPromptTemplate(PromptTemplate systemPromptTemplate) { this.systemPromptTemplate = systemPromptTemplate; return this; }
public Builder conversationId(String conversationId) { this.conversationId = conversationId; return this; }
public Builder scheduler(Scheduler scheduler) { this.scheduler = scheduler; return this; }
public Builder order(int order) { this.order = order; return this; }
public PromptChatMemoryAdvisor build() { return new PromptChatMemoryAdvisor(this.chatMemory, this.conversationId, this.order, this.scheduler, this.systemPromptTemplate); } }}
ChatMemory
类说明:管理会话聊天记忆的接口,提供了保存、获取、清除对话消息的基本功能
字段说明
字段名称 | 类型 | 描述 |
CONVERSATIONID | String | 会话Id。当作键,方便提取对应的List |
方法说明
方法名称 | 描述 |
add | 添加消息到指定会话Id中 |
get | 根据指定会话Id获取消息 |
clear | 根据会话Id清除消息 |
package org.springframework.ai.chat.memory;
import java.util.List;import org.springframework.ai.chat.messages.Message;import org.springframework.util.Assert;
/**
 * Contract for storing, reading and clearing the messages of a conversation, addressed by
 * conversation id.
 */
public interface ChatMemory {

    /** Conversation id used when none is specified. */
    String DEFAULTCONVERSATIONID = "default";

    /**
     * Context key under which a conversation id is looked up.
     */
    // NOTE(review): upstream Spring AI appears to name these constants with underscores and to
    // use the value "chat_memory_conversation_id"; the underscores may have been lost in this
    // listing - confirm before relying on the names or the literal.
    String CONVERSATIONID = "chatmemoryconversationid";

    /**
     * Convenience overload that stores a single message by delegating to
     * {@link #add(String, List)}.
     *
     * @param conversationId target conversation; must contain text
     * @param message the message to store; must not be {@code null}
     */
    default void add(String conversationId, Message message) {
        Assert.hasText(conversationId, "conversationId cannot be null or empty");
        Assert.notNull(message, "message cannot be null");
        List<Message> single = List.of(message);
        this.add(conversationId, single);
    }

    /**
     * Stores the given messages for the conversation.
     *
     * @param conversationId target conversation
     * @param messages the messages to store
     */
    void add(String conversationId, List<Message> messages);

    /**
     * @param conversationId conversation to read
     * @return the messages held for that conversation
     */
    List<Message> get(String conversationId);

    /**
     * Removes the messages held for the conversation.
     *
     * @param conversationId conversation to clear
     */
    void clear(String conversationId);

}
MessageWindowChatMemory
类说明:消息窗口类,提供了保存、获取、清除对话消息的基本功能
字段说明
字段名称 | 类型 | 描述 |
maxMessages | int | 当前会话最多保留的消息数量 |
chatMemoryRepository | ChatMemoryRepository | 存储后端,实现该接口可拓展内存、Mysql、Redis、ES等数据库存储消息 |
其他方法说明
方法名称 | 描述 |
process | 用于控制消息数量,核心逻辑如下 1. 新增 SystemMessage 时,清除之前的 SystemMessage 2. 若消息数超过限制,优先保留 SystemMessage |
package org.springframework.ai.chat.memory;
import java.util.ArrayList;import java.util.HashSet;import java.util.List;import java.util.Objects;import java.util.Set;import java.util.stream.Stream;import org.springframework.ai.chat.messages.Message;import org.springframework.ai.chat.messages.SystemMessage;import org.springframework.util.Assert;
public final class MessageWindowChatMemory implements ChatMemory { private static final int DEFAULTMAXMESSAGES = 20; private final ChatMemoryRepository chatMemoryRepository; private final int maxMessages;
private MessageWindowChatMemory(ChatMemoryRepository chatMemoryRepository, int maxMessages) { Assert.notNull(chatMemoryRepository, "chatMemoryRepository cannot be null"); Assert.isTrue(maxMessages > 0, "maxMessages must be greater than 0"); this.chatMemoryRepository = chatMemoryRepository; this.maxMessages = maxMessages; }
public void add(String conversationId, List<Message> messages) { Assert.hasText(conversationId, "conversationId cannot be null or empty"); Assert.notNull(messages, "messages cannot be null"); Assert.noNullElements(messages, "messages cannot contain null elements"); List<Message> memoryMessages = this.chatMemoryRepository.findByConversationId(conversationId); List<Message> processedMessages = this.process(memoryMessages, messages); this.chatMemoryRepository.saveAll(conversationId, processedMessages); }
public List<Message> get(String conversationId) { Assert.hasText(conversationId, "conversationId cannot be null or empty"); return this.chatMemoryRepository.findByConversationId(conversationId); }
public void clear(String conversationId) { Assert.hasText(conversationId, "conversationId cannot be null or empty"); this.chatMemoryRepository.deleteByConversationId(conversationId); }
private List<Message> process(List<Message> memoryMessages, List<Message> newMessages) { List<Message> processedMessages = new ArrayList(); Set<Message> memoryMessagesSet = new HashSet(memoryMessages); Stream var10000 = newMessages.stream(); Objects.requireNonNull(SystemMessage.class); boolean hasNewSystemMessage = var10000.filter(SystemMessage.class::isInstance).anyMatch((messagex) -> !memoryMessagesSet.contains(messagex)); var10000 = memoryMessages.stream().filter((messagex) -> !hasNewSystemMessage || !(messagex instanceof SystemMessage)); Objects.requireNonNull(processedMessages); var10000.forEach(processedMessages::add); processedMessages.addAll(newMessages); if (processedMessages.size() <= this.maxMessages) { return processedMessages; } else { int messagesToRemove = processedMessages.size() - this.maxMessages; List<Message> trimmedMessages = new ArrayList(); int removed = 0;
for(Message message : processedMessages) { if (!(message instanceof SystemMessage) && removed < messagesToRemove) { ++removed; } else { trimmedMessages.add(message); } }
return trimmedMessages; } }
public static Builder builder() { return new Builder(); }
public static final class Builder { private ChatMemoryRepository chatMemoryRepository; private int maxMessages = 20;
private Builder() { }
public Builder chatMemoryRepository(ChatMemoryRepository chatMemoryRepository) { this.chatMemoryRepository = chatMemoryRepository; return this; }
public Builder maxMessages(int maxMessages) { this.maxMessages = maxMessages; return this; }
public MessageWindowChatMemory build() { if (this.chatMemoryRepository == null) { this.chatMemoryRepository = new InMemoryChatMemoryRepository(); }
return new MessageWindowChatMemory(this.chatMemoryRepository, this.maxMessages); } }}
ChatMemoryRepository
package org.springframework.ai.chat.memory;
import java.util.List;import org.springframework.ai.chat.messages.Message;
/**
 * Storage backend contract for conversation messages, keyed by conversation id.
 */
public interface ChatMemoryRepository {

    /**
     * @return the ids of the conversations known to this repository
     */
    List<String> findConversationIds();

    /**
     * @param conversationId conversation to look up
     * @return the messages stored for that conversation
     */
    List<Message> findByConversationId(String conversationId);

    /**
     * Stores the given messages under the conversation id.
     *
     * @param conversationId conversation to write
     * @param messages messages to store
     */
    void saveAll(String conversationId, List<Message> messages);

    /**
     * Removes what is stored for the conversation.
     *
     * @param conversationId conversation to delete
     */
    void deleteByConversationId(String conversationId);

}
InMemoryChatMemoryRepository
类说明:基于内存的实际存储数据,维护一个会话 Id 到消息列表的键值对
package org.springframework.ai.chat.memory;
import java.util.ArrayList;import java.util.List;import java.util.Map;import java.util.concurrent.ConcurrentHashMap;import org.springframework.ai.chat.messages.Message;import org.springframework.util.Assert;
public final class InMemoryChatMemoryRepository implements ChatMemoryRepository { Map<String, List<Message>> chatMemoryStore = new ConcurrentHashMap();
public List<String> findConversationIds() { return new ArrayList(this.chatMemoryStore.keySet()); }
public List<Message> findByConversationId(String conversationId) { Assert.hasText(conversationId, "conversationId cannot be null or empty"); List<Message> messages = (List)this.chatMemoryStore.get(conversationId); return (List<Message>)(messages != null ? new ArrayList(messages) : List.of()); }
public void saveAll(String conversationId, List<Message> messages) { Assert.hasText(conversationId, "conversationId cannot be null or empty"); Assert.notNull(messages, "messages cannot be null"); Assert.noNullElements(messages, "messages cannot contain null elements"); this.chatMemoryStore.put(conversationId, messages); }
public void deleteByConversationId(String conversationId) { Assert.hasText(conversationId, "conversationId cannot be null or empty"); this.chatMemoryStore.remove(conversationId); }}
问题
MessageChatMemoryAdvisor 和 PromptChatMemoryAdvisor 的区别是什么?
PromptChatMemoryAdvisor
-
有些模型可能不支持 message
- 如本地部署,LLaMA、BLOOM 等 text-in/text-out 模型
-
调试时,希望快速看到完整上下文
MessageChatMemoryAdvisor
- 需要精确控制消息类型(用户、系统、助手)
- 使用 OpenAI GPT-3.5/4 等 chat 模型
总结:MessageChatMemoryAdvisor 是面向结构化对话记忆的最佳实践,而 PromptChatMemoryAdvisor 是面向文本提示增强的经典方案
为什么需要覆盖 adviseStream 方法?
先将多个流式响应合并成一个完整响应对象,再调用 after,确保只保留完整的模型输出,避免部分信息写入 memory 导致混乱
AdvisorChain 链
管理一系列 Advisor调用链路
AdvisorChain
接口类作用:组织和管理多个 Advisor
package org.springframework.ai.chat.client.advisor.api;
import io.micrometer.observation.ObservationRegistry;
/**
 * Root contract of an advisor chain; it only exposes the observation registry.
 */
public interface AdvisorChain {

    /**
     * @return the registry used for observations; {@link ObservationRegistry#NOOP} unless an
     *         implementation overrides this method
     */
    default ObservationRegistry getObservationRegistry() {
        return ObservationRegistry.NOOP;
    }

}
CallAdvisorChain
接口类:call 增强链接口,链式调用机制将请求传递给下一个增强器
- 支持获取链中所有增强器集合,便于调试
package org.springframework.ai.chat.client.advisor.api;
import java.util.List;import org.springframework.ai.chat.client.ChatClientRequest;import org.springframework.ai.chat.client.ChatClientResponse;
/**
 * Advisor chain for blocking (call) requests.
 */
public interface CallAdvisorChain extends AdvisorChain {

    /**
     * Passes the request on to the next call advisor in the chain.
     *
     * @param chatClientRequest the request to process
     * @return the response produced further down the chain
     */
    ChatClientResponse nextCall(ChatClientRequest chatClientRequest);

    /**
     * @return the call advisors that belong to this chain
     */
    List<CallAdvisor> getCallAdvisors();

}
StreamAdvisorChain
接口类:Stream 增强链接口,链式调用机制将请求传递给下一个增强器
- 支持获取链中所有增强器集合,便于调试
package org.springframework.ai.chat.client.advisor.api;
import java.util.List;import org.springframework.ai.chat.client.ChatClientRequest;import org.springframework.ai.chat.client.ChatClientResponse;import reactor.core.publisher.Flux;
/**
 * Advisor chain for streaming requests.
 */
public interface StreamAdvisorChain extends AdvisorChain {

    /**
     * Passes the request on to the next stream advisor in the chain.
     *
     * @param chatClientRequest the request to process
     * @return the stream of responses produced further down the chain
     */
    Flux<ChatClientResponse> nextStream(ChatClientRequest chatClientRequest);

    /**
     * @return the stream advisors that belong to this chain
     */
    List<StreamAdvisor> getStreamAdvisors();

}
BaseAdvisorChain
接口类:统一 CallAdvisorChain、StreamAdvisorChain,逻辑复用
/**
 * Combines {@link CallAdvisorChain} and {@link StreamAdvisorChain} into a single type; it
 * declares no members of its own.
 */
public interface BaseAdvisorChain extends CallAdvisorChain, StreamAdvisorChain {

}
DefaultAroundAdvisorChain
- 管理多个增强器的执行顺序,通过 reOrder 方法
- 采取责任链调用机制,使用 nextCall、nextStream 方法将请求传递给下一个增强器
- 支持观测日志记录,集成 Micrometer Observations,可记录每个增强器的执行上下文和耗时
package org.springframework.ai.chat.client.advisor;
import io.micrometer.observation.Observation;import io.micrometer.observation.ObservationConvention;import io.micrometer.observation.ObservationRegistry;import java.util.ArrayList;import java.util.Deque;import java.util.List;import java.util.Objects;import java.util.concurrent.ConcurrentLinkedDeque;import org.springframework.ai.chat.client.ChatClientRequest;import org.springframework.ai.chat.client.ChatClientResponse;import org.springframework.ai.chat.client.advisor.api.Advisor;import org.springframework.ai.chat.client.advisor.api.BaseAdvisorChain;import org.springframework.ai.chat.client.advisor.api.CallAdvisor;import org.springframework.ai.chat.client.advisor.api.StreamAdvisor;import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationContext;import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationConvention;import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationDocumentation;import org.springframework.ai.chat.client.advisor.observation.DefaultAdvisorObservationConvention;import org.springframework.ai.template.TemplateRenderer;import org.springframework.ai.template.st.StTemplateRenderer;import org.springframework.core.OrderComparator;import org.springframework.lang.Nullable;import org.springframework.util.Assert;import org.springframework.util.CollectionUtils;import reactor.core.publisher.Flux;
/**
 * {@link BaseAdvisorChain} implementation that holds the call and stream advisors in two
 * deques and hands the request to the advisor at the head of the matching deque.
 *
 * <p>Each {@link #nextCall} / {@link #nextStream} step pops one advisor, so advancing the
 * chain consumes it. {@link #getCallAdvisors()} and {@link #getStreamAdvisors()} return
 * snapshots taken at construction time and are not affected by that. Every step is wrapped
 * in a Micrometer observation.
 */
public class DefaultAroundAdvisorChain implements BaseAdvisorChain {

    // NOTE(review): the constant names in this listing carry no underscores; upstream
    // presumably uses e.g. DEFAULT_OBSERVATION_CONVENTION and AI_ADVISOR - confirm.
    public static final AdvisorObservationConvention DEFAULTOBSERVATIONCONVENTION = new DefaultAdvisorObservationConvention();

    private static final TemplateRenderer DEFAULTTEMPLATERENDERER = StTemplateRenderer.builder().build();

    private final List<CallAdvisor> originalCallAdvisors;

    private final List<StreamAdvisor> originalStreamAdvisors;

    private final Deque<CallAdvisor> callAdvisors;

    private final Deque<StreamAdvisor> streamAdvisors;

    private final ObservationRegistry observationRegistry;

    private final TemplateRenderer templateRenderer;

    /**
     * @param observationRegistry registry used for the per-advisor observations; must not be {@code null}
     * @param templateRenderer renderer to keep; the default renderer is used when {@code null}
     * @param callAdvisors call advisors in execution order; must not be {@code null}
     * @param streamAdvisors stream advisors in execution order; must not be {@code null}
     */
    DefaultAroundAdvisorChain(ObservationRegistry observationRegistry, @Nullable TemplateRenderer templateRenderer,
            Deque<CallAdvisor> callAdvisors, Deque<StreamAdvisor> streamAdvisors) {
        Assert.notNull(observationRegistry, "the observationRegistry must be non-null");
        Assert.notNull(callAdvisors, "the callAdvisors must be non-null");
        Assert.notNull(streamAdvisors, "the streamAdvisors must be non-null");
        this.observationRegistry = observationRegistry;
        this.templateRenderer = templateRenderer != null ? templateRenderer : DEFAULTTEMPLATERENDERER;
        this.callAdvisors = callAdvisors;
        this.streamAdvisors = streamAdvisors;
        // Snapshots for the getters; the deques above shrink as the chain advances.
        this.originalCallAdvisors = List.copyOf(callAdvisors);
        this.originalStreamAdvisors = List.copyOf(streamAdvisors);
    }

    /**
     * @param observationRegistry registry the built chain will use
     * @return a new builder with empty advisor deques
     */
    public static Builder builder(ObservationRegistry observationRegistry) {
        return new Builder(observationRegistry);
    }

    /**
     * Pops the next call advisor and invokes it inside an observation.
     *
     * @param chatClientRequest the request to process; must not be {@code null}
     * @return the response returned by the advisor
     * @throws IllegalStateException if no call advisor is left
     */
    @Override
    public ChatClientResponse nextCall(ChatClientRequest chatClientRequest) {
        Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null");

        if (this.callAdvisors.isEmpty()) {
            throw new IllegalStateException("No CallAdvisors available to execute");
        }

        CallAdvisor advisor = this.callAdvisors.pop();

        AdvisorObservationContext observationContext = AdvisorObservationContext.builder()
            .advisorName(advisor.getName())
            .chatClientRequest(chatClientRequest)
            .order(advisor.getOrder())
            .build();

        return AdvisorObservationDocumentation.AIADVISOR
            .observation(null, DEFAULTOBSERVATIONCONVENTION, () -> observationContext, this.observationRegistry)
            .observe(() -> advisor.adviseCall(chatClientRequest, this));
    }

    /**
     * Pops the next stream advisor at subscription time and invokes it inside an observation
     * whose parent is taken from the Reactor context.
     *
     * @param chatClientRequest the request to process; must not be {@code null}
     * @return the advisor's response stream, or an error signal carrying an
     *         {@link IllegalStateException} if no stream advisor is left
     */
    @Override
    public Flux<ChatClientResponse> nextStream(ChatClientRequest chatClientRequest) {
        Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null");

        return Flux.deferContextual(contextView -> {
            if (this.streamAdvisors.isEmpty()) {
                return Flux.error(new IllegalStateException("No StreamAdvisors available to execute"));
            }

            StreamAdvisor advisor = this.streamAdvisors.pop();

            AdvisorObservationContext observationContext = AdvisorObservationContext.builder()
                .advisorName(advisor.getName())
                .chatClientRequest(chatClientRequest)
                .order(advisor.getOrder())
                .build();

            Observation observation = AdvisorObservationDocumentation.AIADVISOR.observation(null,
                    DEFAULTOBSERVATIONCONVENTION, () -> observationContext, this.observationRegistry);

            // The parent observation, if any, travels in the Reactor context under this key.
            Observation parent = contextView.getOrDefault("micrometer.observation", null);
            observation.parentObservation(parent).start();

            return Flux.defer(() -> advisor.adviseStream(chatClientRequest, this)
                .doOnError(observation::error)
                .doFinally(signal -> observation.stop())
                .contextWrite(ctx -> ctx.put("micrometer.observation", observation)));
        });
    }

    /**
     * @return immutable snapshot of the call advisors taken when the chain was created
     */
    @Override
    public List<CallAdvisor> getCallAdvisors() {
        return this.originalCallAdvisors;
    }

    /**
     * @return immutable snapshot of the stream advisors taken when the chain was created
     */
    @Override
    public List<StreamAdvisor> getStreamAdvisors() {
        return this.originalStreamAdvisors;
    }

    @Override
    public ObservationRegistry getObservationRegistry() {
        return this.observationRegistry;
    }

    /**
     * Builder that collects advisors and keeps both deques sorted with {@link OrderComparator}.
     */
    public static class Builder {

        private final ObservationRegistry observationRegistry;

        private final Deque<CallAdvisor> callAdvisors;

        private final Deque<StreamAdvisor> streamAdvisors;

        private TemplateRenderer templateRenderer;

        /**
         * @param observationRegistry registry handed to the built chain
         */
        public Builder(ObservationRegistry observationRegistry) {
            this.observationRegistry = observationRegistry;
            this.callAdvisors = new ConcurrentLinkedDeque<>();
            this.streamAdvisors = new ConcurrentLinkedDeque<>();
        }

        /**
         * @param templateRenderer renderer handed to the built chain
         * @return this builder
         */
        public Builder templateRenderer(TemplateRenderer templateRenderer) {
            this.templateRenderer = templateRenderer;
            return this;
        }

        /**
         * Adds a single advisor; see {@link #pushAll(List)}.
         *
         * @param advisor the advisor to add; must not be {@code null}
         * @return this builder
         */
        public Builder push(Advisor advisor) {
            Assert.notNull(advisor, "the advisor must be non-null");
            return this.pushAll(List.of(advisor));
        }

        /**
         * Adds the advisors to the call deque, the stream deque, or both, depending on the
         * interfaces each one implements, and then re-sorts both deques.
         *
         * @param advisors advisors to add; must not be {@code null} nor contain {@code null}
         * @return this builder
         */
        public Builder pushAll(List<? extends Advisor> advisors) {
            Assert.notNull(advisors, "the advisors must be non-null");
            Assert.noNullElements(advisors, "the advisors must not contain null elements");
            if (!CollectionUtils.isEmpty(advisors)) {
                List<CallAdvisor> callAroundAdvisorList = advisors.stream()
                    .filter(a -> a instanceof CallAdvisor)
                    .map(a -> (CallAdvisor) a)
                    .toList();
                if (!CollectionUtils.isEmpty(callAroundAdvisorList)) {
                    callAroundAdvisorList.forEach(this.callAdvisors::push);
                }

                List<StreamAdvisor> streamAroundAdvisorList = advisors.stream()
                    .filter(a -> a instanceof StreamAdvisor)
                    .map(a -> (StreamAdvisor) a)
                    .toList();
                if (!CollectionUtils.isEmpty(streamAroundAdvisorList)) {
                    streamAroundAdvisorList.forEach(this.streamAdvisors::push);
                }

                this.reOrder();
            }
            return this;
        }

        /**
         * Sorts both deques with {@link OrderComparator} by copying them into lists, sorting
         * the lists, and refilling the deques from the tail.
         */
        private void reOrder() {
            List<CallAdvisor> callAdvisors = new ArrayList<>(this.callAdvisors);
            OrderComparator.sort(callAdvisors);
            this.callAdvisors.clear();
            callAdvisors.forEach(this.callAdvisors::addLast);

            List<StreamAdvisor> streamAdvisors = new ArrayList<>(this.streamAdvisors);
            OrderComparator.sort(streamAdvisors);
            this.streamAdvisors.clear();
            streamAdvisors.forEach(this.streamAdvisors::addLast);
        }

        /**
         * @return a chain that uses this builder's deques directly (they are not copied)
         */
        public DefaultAroundAdvisorChain build() {
            return new DefaultAroundAdvisorChain(this.observationRegistry, this.templateRenderer, this.callAdvisors,
                    this.streamAdvisors);
        }

    }

}