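Adding Persistence ("Memory") to a Graph
Many AI applications need memory so that context can be shared across multiple interactions. In Spring AI Alibaba, any StateGraph gets memory through a Checkpointer.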
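Core Concepts
When creating any Spring AI Alibaba workflow, you set up persistence as follows:
- Create a Checkpointer, for example MemorySaver
- Pass the Checkpointer through CompileConfig when compiling the graph
- Use threadId to identify separate conversations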
Initial Setup
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
private static final Logger log = LoggerFactory.getLogger("Persistence");
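Defining State and Key Strategies
State is the data structure shared by all nodes in the graph. Spring AI Alibaba uses a KeyStrategyFactory to define, for each state key, how a value written by a node is merged into the existing state (for example, replaced or appended).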
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.KeyStrategy;
import com.alibaba.cloud.ai.graph.KeyStrategyFactory;
import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;
import com.alibaba.cloud.ai.graph.state.strategy.AppendStrategy;
import java.util.Map;
import java.util.HashMap;
import java.util.List;
import java.util.Optional;
// Custom state class: a typed view over the shared state data
public class ConversationState extends OverAllState {
public ConversationState(Map<String, Object> initData) {
super(initData);
}
public Optional<List<String>> messages() {
return value("messages");
}
public Optional<String> userName() {
return value("user_name");
}
}
// Configure the update strategy for each state key
KeyStrategyFactory keyStrategyFactory = () -> {
HashMap<String, KeyStrategy> strategies = new HashMap<>();
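strategies.put("messages", new AppendStrategy()); // messages: new values are appended
strategies.put("user_name", new ReplaceStrategy()); // user_name: the new value replaces the old one
strategies.put("context", new ReplaceStrategy()); // context: the new value replaces the old one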
return strategies;
};
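To make the effect of these strategies concrete, here is a minimal, framework-free sketch of how a per-key strategy might merge a node's partial update into the current state. `KeyStrategySketch`, `REPLACE`, `APPEND`, and `merge` are hypothetical names for illustration only; this is not how Spring AI Alibaba implements `AppendStrategy` or `ReplaceStrategy` internally, and defaulting keys without a strategy to "replace" is an assumption.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BinaryOperator;

// Hypothetical model of per-key merge strategies (NOT the Spring AI Alibaba classes).
final class KeyStrategySketch {

    // Replace: the new value overwrites the old one.
    static final BinaryOperator<Object> REPLACE = (oldValue, newValue) -> newValue;

    // Append: the new value (a single item or a list) is added after the existing items.
    static final BinaryOperator<Object> APPEND = (oldValue, newValue) -> {
        List<Object> merged = new ArrayList<>();
        if (oldValue instanceof List<?> oldList) {
            merged.addAll(oldList);
        } else if (oldValue != null) {
            merged.add(oldValue);
        }
        if (newValue instanceof List<?> newList) {
            merged.addAll(newList);
        } else if (newValue != null) {
            merged.add(newValue);
        }
        return merged;
    };

    // Merge a node's partial update into the state; keys without a strategy are replaced (assumption).
    static Map<String, Object> merge(Map<String, Object> state,
                                     Map<String, Object> update,
                                     Map<String, BinaryOperator<Object>> strategies) {
        Map<String, Object> next = new HashMap<>(state);
        update.forEach((key, value) -> {
            BinaryOperator<Object> strategy = strategies.getOrDefault(key, REPLACE);
            next.put(key, strategy.apply(state.get(key), value));
        });
        return next;
    }

    public static void main(String[] args) {
        Map<String, BinaryOperator<Object>> strategies = Map.of(
                "messages", APPEND,
                "user_name", REPLACE);
        Map<String, Object> state = Map.of();
        state = merge(state, Map.of("messages", "Hi, I'm Alice", "user_name", "Alice"), strategies);
        state = merge(state, Map.of("messages", "Hello Alice!", "user_name", "Alice B."), strategies);
        System.out.println(state.get("messages"));  // [Hi, I'm Alice, Hello Alice!]
        System.out.println(state.get("user_name")); // Alice B.
    }
}
```

With this model, two successive updates leave `messages` holding both entries while `user_name` keeps only the latest value, which is the behavior the strategy map above asks for.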
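Creating an Agent Node with Tools
We will create a simple search tool to show how tools can be used in a graph that has persistence.
Defining the Tool Function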
import java.util.function.Function;
// Search tool
public class SearchTool implements Function<SearchTool.Request, String> {
public record Request(String query) {}
@Override
public String apply(Request request) {
log.info("Executing search for: {}", request.query());
// Simulated search result
return "Search result: The weather is cold with a low of 13 degrees";
}
}
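When the model decides to call a tool, the framework has to map the tool name it emits back to a registered function. The sketch below models that lookup with a plain `Map`; `ToolRegistrySketch`, `SearchRequest`, and `dispatch` are hypothetical and only stand in for what Spring AI does with a named function callback. They are not its actual API.

```java
import java.util.Map;
import java.util.function.Function;

// Hypothetical name-to-tool lookup, mimicking how a tool registered as "search" is found.
final class ToolRegistrySketch {

    record SearchRequest(String query) {}

    // Stand-in for SearchTool: returns a canned result that echoes the query.
    static final Function<SearchRequest, String> SEARCH =
            request -> "Search result for '" + request.query() + "': cold, low of 13 degrees";

    // Registered tools keyed by the name the model would use.
    static final Map<String, Function<SearchRequest, String>> TOOLS = Map.of("search", SEARCH);

    // Resolve the tool by name and apply it to the (already parsed) arguments.
    static String dispatch(String toolName, String query) {
        Function<SearchRequest, String> tool = TOOLS.get(toolName);
        if (tool == null) {
            throw new IllegalArgumentException("Unknown tool: " + toolName);
        }
        return tool.apply(new SearchRequest(query));
    }

    public static void main(String[] args) {
        System.out.println(dispatch("search", "weather in Hangzhou"));
    }
}
```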
Creating the Agent Node
import com.alibaba.cloud.ai.graph.action.NodeAction;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackWrapper;
class AgentNode implements NodeAction {
private final ChatClient chatClient;
public AgentNode(ChatClient.Builder chatClientBuilder, SearchTool searchTool) {
// Register the search tool as a function callback
FunctionCallback searchCallback = FunctionCallbackWrapper.builder(searchTool)
.withName("search")
.withDescription("Search for information, check weather, and retrieve data")
.build();
this.chatClient = chatClientBuilder
.defaultFunctions(searchCallback)
.build();
}
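@Override
public Map<String, Object> apply(OverAllState state) throws Exception {
// The graph passes a plain OverAllState, so wrap its data in the typed view instead of casting
ConversationState convState = new ConversationState(state.data());
// Full history accumulated through AppendStrategy; the last entry is the newest message
List<String> messages = convState.messages().orElse(List.of());
String lastMessage = messages.isEmpty() ? "" : messages.get(messages.size() - 1);
log.info("Processing message: {}", lastMessage);
// Pass the earlier turns as context; otherwise the model only ever sees the last message
String history = String.join("\n", messages.subList(0, Math.max(0, messages.size() - 1)));
// Call the LLM (ChatClient handles tool calls automatically)
String response = chatClient.prompt()
.system("Conversation so far:\n" + history)
.user(lastMessage)
.call()
.content();
return Map.of("messages", response);
}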
}
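For the model to use earlier turns, those turns have to appear in the prompt it receives. Here is a minimal sketch of turning the accumulated `messages` list into a context string while keeping only the most recent entries to bound prompt size; `HistoryPromptSketch`, `buildContext`, and the `maxTurns` cap are assumptions for illustration, not part of Spring AI Alibaba.

```java
import java.util.List;

// Hypothetical helper: format the most recent maxTurns messages as a context block.
final class HistoryPromptSketch {

    static String buildContext(List<String> messages, int maxTurns) {
        if (maxTurns <= 0) {
            throw new IllegalArgumentException("maxTurns must be positive");
        }
        // Index of the first message to keep (older ones are dropped).
        int from = Math.max(0, messages.size() - maxTurns);
        StringBuilder sb = new StringBuilder("Conversation so far:\n");
        for (int i = from; i < messages.size(); i++) {
            sb.append("- ").append(messages.get(i)).append('\n');
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        List<String> history = List.of(
                "Hi, I'm Alice, nice to meet you",
                "Hello Alice, nice to meet you too!",
                "What's my name?");
        System.out.print(buildContext(history, 10));
    }
}
```

Capping the history is a design choice: it trades recall of very old turns for a predictable prompt length.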
Defining the Routing Logic
import com.alibaba.cloud.ai.graph.action.EdgeAction;
class RouteMessage implements EdgeAction {
@Override
public String apply(OverAllState state) throws Exception {
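// Wrap the shared state in the typed view (the graph passes a plain OverAllState, not a ConversationState)
ConversationState convState = new ConversationState(state.data());
// Read the message list
List<String> messages = convState.messages().orElse(List.of());
if (messages.isEmpty()) {
return "exit";
}
// Simplified routing: inspect the last message, which is the agent's latest reply.
// ChatClient already runs tool calls inside AgentNode, so this keyword check is only illustrative.
String lastMessage = messages.get(messages.size() - 1);
// Loop back to the agent if the reply mentions search or weather; otherwise finish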
if (lastMessage.contains("search") || lastMessage.contains("weather")) {
return "continue";
}
return "exit";
}
}
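Because the router sends the graph back to `agent` whenever the last reply mentions a keyword, a reply that keeps mentioning "weather" could cycle repeatedly. The sketch below is a framework-free version of the same decision with an added step limit. `RouterSketch`, the `agentSteps` counter (which would have to be tracked in state, for example incremented by the agent node), `maxAgentSteps`, and the case-insensitive comparison are all assumptions, not part of the original example.

```java
import java.util.List;
import java.util.Locale;

// Hypothetical keyword router with a step limit to prevent endless agent loops.
final class RouterSketch {

    static final String CONTINUE = "continue";
    static final String EXIT = "exit";

    static String route(List<String> messages, int agentSteps, int maxAgentSteps) {
        // Stop when there is nothing to inspect or the step budget is used up.
        if (messages.isEmpty() || agentSteps >= maxAgentSteps) {
            return EXIT;
        }
        String last = messages.get(messages.size() - 1).toLowerCase(Locale.ROOT);
        return (last.contains("search") || last.contains("weather")) ? CONTINUE : EXIT;
    }

    public static void main(String[] args) {
        System.out.println(route(List.of("Let me search that"), 1, 3)); // continue
        System.out.println(route(List.of("Let me search that"), 3, 3)); // exit
        System.out.println(route(List.of("Hello Alice"), 1, 3));        // exit
    }
}
```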
Building a Graph with Persistence
Without a Checkpointer
First, let's look at the behavior without persistence:
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.CompiledGraph;
import static com.alibaba.cloud.ai.graph.action.AsyncNodeAction.node_async;
import static com.alibaba.cloud.ai.graph.action.AsyncEdgeAction.edge_async;
// Configure the ChatClient (chatModel is assumed to be an already configured ChatModel, e.g. injected by Spring)
ChatClient.Builder chatClientBuilder = ChatClient.builder(chatModel);
// Create the tool, then wrap the node and edge actions as async actions
SearchTool searchTool = new SearchTool();
var agentNode = node_async(new AgentNode(chatClientBuilder, searchTool));
var routeMessage = edge_async(new RouteMessage());
// Build the graph (no checkpointer)
StateGraph workflow = new StateGraph(keyStrategyFactory)
.addNode("agent", agentNode)
.addEdge(StateGraph.START, "agent")
.addConditionalEdges("agent", routeMessage,
Map.of(
"continue", "agent",
"exit", StateGraph.END
));
CompiledGraph graph = workflow.compile();
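The graph above wires `START -> agent` and then lets the router choose between looping back to `agent` and ending. The following sketch simulates that control flow without the framework: a `messages` list stands in for the state, and appending to it stands in for `AppendStrategy`. `MiniGraphSketch` and its `maxSteps` guard are hypothetical; the real `CompiledGraph` does considerably more.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

// Hypothetical simulation of: START -> agent -> (continue -> agent | exit -> END).
final class MiniGraphSketch {

    static List<String> run(List<String> input,
                            Function<List<String>, String> agent,
                            Function<List<String>, String> router,
                            int maxSteps) {
        List<String> messages = new ArrayList<>(input);   // plays the role of state "messages"
        for (int step = 0; step < maxSteps; step++) {
            messages.add(agent.apply(messages));          // append the agent's reply
            if (!"continue".equals(router.apply(messages))) {
                return messages;                          // "exit" -> END
            }
        }
        throw new IllegalStateException("Exceeded " + maxSteps + " steps");
    }

    public static void main(String[] args) {
        Function<List<String>, String> echoAgent = msgs -> "Echo: " + msgs.get(msgs.size() - 1);
        Function<List<String>, String> router = msgs -> "exit";
        System.out.println(run(List.of("Hi"), echoAgent, router, 5)); // [Hi, Echo: Hi]
    }
}
```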
Testing the Graph Without Persistence
// First call: introduce yourself
log.info("=== First call - Introduction ===");
var result1 = graph.invoke(Map.of("messages", "Hi, I'm Alice, nice to meet you")).orElseThrow();
List<String> messages1 = (List<String>) result1.data().get("messages");
log.info("Response: {}", messages1.get(messages1.size() - 1));
// Second call: ask for the name (without persistence, the graph cannot remember it)
log.info("=== Second call - Ask name ===");
var result2 = graph.invoke(Map.of("messages", "What's my name?")).orElseThrow();
List<String> messages2 = (List<String>) result2.data().get("messages");
log.info("Response: {}", messages2.get(messages2.size() - 1));
Output (without persistence):
=== First call - Introduction ===
Response: Hello Alice, nice to meet you too!
=== Second call - Ask name ===
Response: I don't have information about your name. Could you please tell me?
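As you can see, without persistence the graph cannot remember earlier turns: each invoke starts from a state that contains only the new input.
Adding Persistence (Memory)
Now let's add a MemorySaver to enable persistence: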
import com.alibaba.cloud.ai.graph.checkpoint.MemorySaver;
import com.alibaba.cloud.ai.graph.CompileConfig;
import com.alibaba.cloud.ai.graph.RunnableConfig;
// Create the checkpointer
var checkpointer = new MemorySaver();
// Configure persistence
var compileConfig = CompileConfig.builder()
.checkpointSaver(checkpointer)
.build();
// Compile the graph with persistence
CompiledGraph persistentGraph = workflow.compile(compileConfig);
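Conceptually, a checkpointer saves a snapshot of the state after a run and restores the latest snapshot for the same `threadId` before the next one, so new input is merged into the restored state instead of a blank one. The sketch below models that idea in plain Java. `CheckpointSketch`, `latest`, `checkpointCount`, and the one-snapshot-per-invoke policy are assumptions for illustration and are not the `MemorySaver` API (the real saver may record checkpoints at a finer granularity, such as per node).

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Hypothetical in-memory checkpointer keyed by threadId (not the MemorySaver API).
final class CheckpointSketch {

    // threadId -> ordered snapshots of the "messages" list
    private final Map<String, List<List<String>>> checkpoints = new HashMap<>();
    private final Function<List<String>, String> agent;

    CheckpointSketch(Function<List<String>, String> agent) {
        this.agent = agent;
    }

    // Latest snapshot for the thread, or an empty list if none was saved yet.
    List<String> latest(String threadId) {
        List<List<String>> saved = checkpoints.getOrDefault(threadId, List.of());
        return saved.isEmpty() ? List.of() : saved.get(saved.size() - 1);
    }

    int checkpointCount(String threadId) {
        return checkpoints.getOrDefault(threadId, List.of()).size();
    }

    // One invoke: restore state, append the input, run the agent, save a new snapshot.
    List<String> invoke(String threadId, String userMessage) {
        List<String> messages = new ArrayList<>(latest(threadId));
        messages.add(userMessage);
        messages.add(agent.apply(messages));
        List<String> snapshot = List.copyOf(messages);
        checkpoints.computeIfAbsent(threadId, id -> new ArrayList<>()).add(snapshot);
        return snapshot;
    }

    public static void main(String[] args) {
        // Fake agent: reports how many messages it can see, so restored memory is observable.
        CheckpointSketch graph = new CheckpointSketch(msgs -> "I can see " + msgs.size() + " message(s)");
        System.out.println(graph.invoke("user-alice-session", "Hi, I'm Alice"));
        // [Hi, I'm Alice, I can see 1 message(s)]
        System.out.println(graph.invoke("user-alice-session", "What's my name?"));
        // [Hi, I'm Alice, I can see 1 message(s), What's my name?, I can see 3 message(s)]
    }
}
```

The fake agent only counts visible messages, which is enough to show that the second call on the same thread starts from the first call's saved state.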
Testing the Graph With Persistence
// Create the run configuration (threadId identifies the conversation)
var config = RunnableConfig.builder()
.threadId("user-alice-session")
.build();
// First call: introduce yourself
log.info("=== First call with persistence - Introduction ===");
var result1 = persistentGraph.invoke(
Map.of("messages", "Hi, I'm Alice, nice to meet you"),
config
).orElseThrow();
List<String> messages1 = (List<String>) result1.data().get("messages");
log.info("Response: {}", messages1.get(messages1.size() - 1));
// Second call: ask for the name (with persistence, the graph remembers it)
log.info("=== Second call with persistence - Ask name ===");
var result2 = persistentGraph.invoke(
Map.of("messages", "What's my name?"),
config
).orElseThrow();
List<String> messages2 = (List<String>) result2.data().get("messages");
log.info("Response: {}", messages2.get(messages2.size() - 1));
// Third call: continue the conversation
log.info("=== Third call - Continue conversation ===");
var result3 = persistentGraph.invoke(
Map.of("messages", "What did I say in my first message?"),
config
).orElseThrow();
List<String> messages3 = (List<String>) result3.data().get("messages");
log.info("Response: {}", messages3.get(messages3.size() - 1));
Output (with persistence):
=== First call with persistence - Introduction ===
Response: Hello Alice, nice to meet you too! How can I help you today?
=== Second call with persistence - Ask name ===
Response: Your name is Alice!
=== Third call - Continue conversation ===
Response: You said "Hi, I'm Alice, nice to meet you"
Isolating Multiple Conversations
Using different threadId values creates fully independent conversations:
// Alice's session
var aliceConfig = RunnableConfig.builder()
.threadId("user-alice")
.build();
persistentGraph.invoke(Map.of("messages", "Hi, I'm Alice"), aliceConfig);
// Bob's session
var bobConfig = RunnableConfig.builder()
.threadId("user-bob")
.build();
persistentGraph.invoke(Map.of("messages", "Hi, I'm Bob"), bobConfig);
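// Alice asks for her name: her thread remembers it
var aliceResult = persistentGraph.invoke(
Map.of("messages", "What's my name?"),
aliceConfig
).orElseThrow();
List<String> aliceMessages = (List<String>) aliceResult.data().get("messages");
log.info("Alice: {}", aliceMessages.get(aliceMessages.size() - 1));
// Example output: Your name is Alice
// Bob asks for his name: his separate thread remembers it too
var bobResult = persistentGraph.invoke(
Map.of("messages", "What's my name?"),
bobConfig
).orElseThrow();
List<String> bobMessages = (List<String>) bobResult.data().get("messages");
log.info("Bob: {}", bobMessages.get(bobMessages.size() - 1));
// Example output: Your name is Bob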
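Isolation follows from keying stored state by `threadId`: a lookup for one thread never reads another thread's entries. The sketch below shows this with a concurrent map so that sessions updated from different request threads do not interfere. `SessionIsolationSketch`, its methods, and the naive `I'm <name>` parsing in `whoAmI` are hypothetical and only stand in for the LLM's recall.

```java
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;

// Hypothetical per-thread session store: each threadId owns its own message list.
final class SessionIsolationSketch {

    // ConcurrentHashMap + CopyOnWriteArrayList keep concurrent updates to different sessions safe.
    private final Map<String, List<String>> sessions = new ConcurrentHashMap<>();

    void record(String threadId, String message) {
        sessions.computeIfAbsent(threadId, id -> new CopyOnWriteArrayList<>()).add(message);
    }

    // Immutable copy of one thread's history (empty for unknown threads).
    List<String> history(String threadId) {
        return List.copyOf(sessions.getOrDefault(threadId, List.of()));
    }

    // Fake recall: the most recent "I'm <name>" introduction within this thread only.
    String whoAmI(String threadId) {
        List<String> msgs = history(threadId);
        for (int i = msgs.size() - 1; i >= 0; i--) {
            String msg = msgs.get(i);
            int idx = msg.indexOf("I'm ");
            if (idx >= 0) {
                return msg.substring(idx + 4).split("\\W")[0];
            }
        }
        return "unknown";
    }

    public static void main(String[] args) {
        SessionIsolationSketch store = new SessionIsolationSketch();
        store.record("user-alice", "Hi, I'm Alice");
        store.record("user-bob", "Hi, I'm Bob");
        System.out.println(store.whoAmI("user-alice")); // Alice
        System.out.println(store.whoAmI("user-bob"));   // Bob
    }
}
```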