人类反馈(Human-in-the-Loop)
在实际业务场景中,经常会遇到人类介入的场景,人类的不同操作将影响工作流不同的走向。Spring AI Alibaba Graph 提供了两种方式来实现人类反馈:
- InterruptionMetadata 模式:可以在任意节点随时中断,通过实现
InterruptableAction接口来控制中断时机 - interruptBefore 模式:需要提前在编译配置中定义中断点,在指定节点执行前中断
模式一:InterruptionMetadata 模式
InterruptionMetadata 模式允许节点在运行时动态决定是否需要中断,提供了最大的灵活性。节点通过实现 InterruptableAction 接口,可以在任意时刻返回 InterruptionMetadata 来中断执行。
优势
- 灵活性强:可以在任意节点根据运行时状 态决定是否中断
- 动态控制:中断逻辑由节点自身控制,不需要提前配置
- 状态感知:可以根据当前状态动态决定是否需要等待用户输入
定义带中断的 Graph
定义带中断的 Graph查看完整代码
import com.alibaba.cloud.ai.graph.CompileConfig;
import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.KeyStrategy;
import com.alibaba.cloud.ai.graph.KeyStrategyFactory;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.RunnableConfig;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.action.AsyncNodeActionWithConfig;
import com.alibaba.cloud.ai.graph.action.InterruptableAction;
import com.alibaba.cloud.ai.graph.action.InterruptionMetadata;
import com.alibaba.cloud.ai.graph.checkpoint.config.SaverConfig;
import com.alibaba.cloud.ai.graph.checkpoint.savers.MemorySaver;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.state.strategy.AppendStrategy;
import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import static com.alibaba.cloud.ai.graph.StateGraph.END;
import static com.alibaba.cloud.ai.graph.StateGraph.START;
import static com.alibaba.cloud.ai.graph.action.AsyncEdgeAction.edge_async;
import static com.alibaba.cloud.ai.graph.action.AsyncNodeAction.node_async;
/**
* 定义带中断的 Graph
* 使用 InterruptableAction 实现中断,不需要 interruptBefore 配置
*/
public static CompiledGraph createGraphWithInterrupt() throws GraphStateException {
// 定义普通节点
var step1 = node_async(state -> {
return Map.of("messages", "Step 1");
});
// 定义可中断节点(实现 InterruptableAction)
var humanFeedback = new InterruptableNodeAction("human_feedback", "等待用户输入");
var step3 = node_async(state -> {
return Map.of("messages", "Step 3");
});
// 定义条件边:根据 human_feedback 的值决定路由
var evalHumanFeedback = edge_async(state -> {
var feedback = (String) state.value("human_feedback").orElse("unknown");
return (feedback.equals("next") || feedback.equals("back")) ? feedback : "unknown";
});
// 配置 KeyStrategyFactory
KeyStrategyFactory keyStrategyFactory = () -> {
HashMap<String, KeyStrategy> keyStrategyHashMap = new HashMap<>();
keyStrategyHashMap.put("messages", new AppendStrategy());
keyStrategyHashMap.put("human_feedback", new ReplaceStrategy());
return keyStrategyHashMap;
};
// 构建 Graph
StateGraph builder = new StateGraph(keyStrategyFactory)
.addNode("step_1", step1)
.addNode("human_feedback", humanFeedback) // 使用可中断节点
.addNode("step_3", step3)
.addEdge(START, "step_1")
.addEdge("step_1", "human_feedback")
.addConditionalEdges("human_feedback", evalHumanFeedback,
Map.of("back", "step_1", "next", "step_3", "unknown", "human_feedback"))
.addEdge("step_3", END);
// 配置内存保存器(用于状态持久化)
var saver = new MemorySaver();
var compileConfig = CompileConfig.builder()
.saverConfig(SaverConfig.builder()
.register(saver)
.build())
// 不再需要 interruptBefore 配置,中断由 InterruptableAction 控制
.build();
return builder.compile(compileConfig);
}
实现 InterruptableNodeAction
实现 InterruptableNodeAction查看完整代码
/**
* 可中断的节点动作
* 实现 InterruptableAction 接口,可以在任意节点中断执行
*/
public static class InterruptableNodeAction implements AsyncNodeActionWithConfig, InterruptableAction {
private final String nodeId;
private final String message;
public InterruptableNodeAction(String nodeId, String message) {
this.nodeId = nodeId;
this.message = message;
}
@Override
public CompletableFuture<Map<String, Object>> apply(OverAllState state, RunnableConfig config) {
// 正常节点逻辑:更新状态
return CompletableFuture.completedFuture(Map.of("messages", message));
}
@Override
public Optional<InterruptionMetadata> interrupt(String nodeId, OverAllState state, RunnableConfig config) {
// 检查是否需要中断
// 如果状态中没有 human_feedback,则中断等待用户输入
Optional<Object> humanFeedback = state.value("human_feedback");
if (humanFeedback.isEmpty()) {
// 返回 InterruptionMetadata 来中断执行
InterruptionMetadata interruption = InterruptionMetadata.builder(nodeId, state)
.addMetadata("message", "等待用户输入...")
.addMetadata("node", nodeId)
.build();
return Optional.of(interruption);
}
// 如果已经有 human_feedback,继续执行
return Optional.empty();
}
}
执行 Graph 直到中断
执行 Graph 直到中断查看完整代码
import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.NodeOutput;
import com.alibaba.cloud.ai.graph.RunnableConfig;
import com.alibaba.cloud.ai.graph.action.InterruptionMetadata;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import reactor.core.publisher.Flux;
/**
* 执行 Graph 直到中断
* 检查流式输出中的 InterruptionMetadata
*/
public static InterruptionMetadata executeUntilInterrupt(CompiledGraph graph) {
// 初始输入
Map<String, Object> initialInput = Map.of("messages", "Step 0");
// 配置线程 ID
var invokeConfig = RunnableConfig.builder()
.threadId("Thread1")
.build();
// 用于保存最后一个输出
AtomicReference<NodeOutput> lastOutputRef = new AtomicReference<>();
// 运行 Graph 直到第一个中断点
graph.stream(initialInput, invokeConfig)
.doOnNext(event -> {
System.out.println("节点输出: " + event);
lastOutputRef.set(event);
})
.doOnError(error -> System.err.println("流错误: " + error.getMessage()))
.doOnComplete(() -> System.out.println("流完成"))
.blockLast();
// 检查最后一个输出是否是 InterruptionMetadata
NodeOutput lastOutput = lastOutputRef.get();
if (lastOutput instanceof InterruptionMetadata) {
System.out.println("
检测到中断: " + lastOutput);
return (InterruptionMetadata) lastOutput;
}
return null;
}
输出:
节点输出: NodeOutput{node=__START__, state={messages=[Step 0]}}
节点输出: NodeOutput{node=step_1, state={messages=[Step 0, Step 1]}}
检测到中断: InterruptionMetadata{node=human_feedback, state={messages=[Step 0, Step 1]}, metadata={message=等待用户输入..., node=human_feedback}}