持久化
Spring AI Alibaba Graph 具有 内置的持久化层,通过检查点(Checkpointers)实现。当您使用检查点编译图时,检查点会在每个超级步骤(super-step)保存图状态的检查点。这些检查点保存到一个会话(thread)中,可以在图执行后访问。
由于会话允许在执行后访问图的状态,因此几个强大的功能都成为可能,包括人在回路中(human-in-the-loop)、内存、时间旅行和容错能力。下面,我们将详细讨论这些概念。
会话
会话是分配给检查点器保存的每个检查点的唯一 ID 或会话标识符。它包含一系列运行的累积状态。当执行运行时,图的底层状态将被持久化到会话。
当使用检查点调用图时,您必须在配置的 RunnableConfig 中指定一个 threadId。
RunnableConfig config = RunnableConfig.builder()
.threadId("1")
.build();
可以检索会话的当前和历史状态。要持久化状态,必须在执行运行之前创建会话 。
检查点(Checkpoints)
会话在特定时间点的状态称为检查点。检查点是在每个超级步骤保存的图状态快照,由 StateSnapshot 对象表示,具有以下关键属性:
config: 与此检查点关联的配置。metadata: 与此检查点关联的元数据。values: 此时状态通道的值。next: 图中下一个要执行的节点名称元组。tasks: 包含有关下一个要执行的任务的信息的PregelTask对象元组。
检查点是持久化的,可以用于在稍后的时间恢复会话的状态。
让我们看看当一个简单的图被调用时保存了哪些检查点:
import com.alibaba.cloud.ai.graph.CompileConfig;
import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.KeyStrategy;
import com.alibaba.cloud.ai.graph.KeyStrategyFactory;
import com.alibaba.cloud.ai.graph.RunnableConfig;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.checkpoint.config.SaverConfig;
import com.alibaba.cloud.ai.graph.checkpoint.savers.MemorySaver;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.state.strategy.AppendStrategy;
import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static com.alibaba.cloud.ai.graph.StateGraph.END;
import static com.alibaba.cloud.ai.graph.StateGraph.START;
import static com.alibaba.cloud.ai.graph.action.AsyncNodeAction.node_async;
// 定义状态策略
KeyStrategyFactory keyStrategyFactory = () -> {
Map<String, KeyStrategy> keyStrategyMap = new HashMap<>();
keyStrategyMap.put("foo", new ReplaceStrategy());
keyStrategyMap.put("bar", new AppendStrategy());
return keyStrategyMap;
};
// 定义节点操作
var nodeA = node_async(state -> {
return Map.of("foo", "a", "bar", List.of("a"));
});
var nodeB = node_async(state -> {
return Map.of("foo", "b", "bar", List.of("b"));
});
// 创建图
StateGraph stateGraph = new StateGraph(keyStrategyFactory)
.addNode("node_a", nodeA)
.addNode("node_b", nodeB)
.addEdge(START, "node_a")
.addEdge("node_a", "node_b")
.addEdge("node_b", END);
// 配置检查点
SaverConfig saverConfig = SaverConfig.builder()
.register(new MemorySaver())
.build();
// 编译图
CompiledGraph graph = stateGraph.compile(
CompileConfig.builder()
.saverConfig(saverConfig)
.build()
);
// 运行图
RunnableConfig config = RunnableConfig.builder()
.threadId("1")
.build();
Map<String, Object> input = new HashMap<>();
input.put("foo", "");
graph.invoke(input, config);
运行图后,我们期望看到恰好 4 个检查点:
- 空检查点,
START作为下一个要执行的节点 - 带有用户输入
{'foo': '', 'bar': []}和node_a作为下一个要执行的节点的检查点 - 带有
node_a的输出{'foo': 'a', 'bar': ['a']}和node_b作为下一个要执行的节点的检查点 - 带有
node_b的输出{'foo': 'b', 'bar': ['a', 'b']}且没有下一个要执行的节点的检查点
请注意,bar 通道值包含两个节点的输出,因为我们对 bar 通道使用了追加策略(AppendStrategy)。
获取状态
当与保存的图状态交互时,您必须指定一个会话标识符。您可以通过调用 graph.getState(config) 来查看图的最新状态。这将返回一个 StateSnapshot 对象,该对象对应于与配置中提供的会话 ID 关联的最新检查点,或者如果提供了检查点 ID,则对应 于该会话的检查点。
import com.alibaba.cloud.ai.graph.RunnableConfig;
import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.state.StateSnapshot;
// 获取最新的状态快照
RunnableConfig config = RunnableConfig.builder()
.threadId("1")
.build();
StateSnapshot stateSnapshot = graph.getState(config);
System.out.println("Current state: " + stateSnapshot.state());
System.out.println("Current node: " + stateSnapshot.node());
// 获取特定 checkpoint_id 的状态快照
RunnableConfig configWithCheckpoint = RunnableConfig.builder()
.threadId("1")
.checkPointId("1ef663ba-28fe-6528-8002-5a559208592c")
.build();
StateSnapshot specificSnapshot = graph.getState(configWithCheckpoint);
System.out.println("Specific checkpoint state: " + specificSnapshot.state());
获取状态历史
您可以通过调用 graph.getStateHistory(config) 来获取给定会话的图执行的完整历史记录。这将返回与配置中提供的会话 ID 关联的 StateSnapshot 对象列表。重要的是,检查点将按时间顺序排序,最近的检查点/StateSnapshot 在列表的第一个位置。
import com.alibaba.cloud.ai.graph.RunnableConfig;
import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.state.StateSnapshot;
import java.util.List;
RunnableConfig config = RunnableConfig.builder()
.threadId("1")
.build();
List<StateSnapshot> history = (List<StateSnapshot>) graph.getStateHistory(config);
System.out.println("State history:");
for (int i = 0; i < history.size(); i++) {
StateSnapshot snapshot = history.get(i);
System.out.printf("Step %d: %s
", i, snapshot.state());
System.out.printf(" Checkpoint ID: %s
", snapshot.config().checkPointId());
System.out.printf(" Node: %s
", snapshot.node());
}