人类反馈(Human-in-the-Loop)
在实际业务场景中,经常会遇到人类介入的场景,人类的不同操作将影响工作流不同的走向。Spring AI Alibaba Graph 提供了两种方式来实现人类反馈:
- InterruptionMetadata 模式:可以在任意节点随时中断,通过实现
InterruptableAction接口来控制中断时机 - interruptBefore 模式:需要提前在编译配置中定义中断点,在指定节点执行前中断
模式一:InterruptionMetadata 模式
InterruptionMetadata 模式允许节点在运行时动态决定是否需要中断,提供了最大的灵活性。节点通过实现 InterruptableAction 接口,可以在任意时刻返回 InterruptionMetadata 来中断执行。
优势
- 灵活性强:可以在任意节点根据运行时状态决定是否中断
- 动态控制:中断逻辑由节点自身控制,不需要提前配置
- 状态感知:可以根据当前状态动态决定是否需要等待用户输入
定义带中断的 Graph
定义带中断的 Graph查看完整代码
import com.alibaba.cloud.ai.graph.CompileConfig;
import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.KeyStrategy;
import com.alibaba.cloud.ai.graph.KeyStrategyFactory;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.RunnableConfig;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.action.AsyncNodeActionWithConfig;
import com.alibaba.cloud.ai.graph.action.InterruptableAction;
import com.alibaba.cloud.ai.graph.action.InterruptionMetadata;
import com.alibaba.cloud.ai.graph.checkpoint.config.SaverConfig;
import com.alibaba.cloud.ai.graph.checkpoint.savers.MemorySaver;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.state.strategy.AppendStrategy;
import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import static com.alibaba.cloud.ai.graph.StateGraph.END;
import static com.alibaba.cloud.ai.graph.StateGraph.START;
import static com.alibaba.cloud.ai.graph.action.AsyncEdgeAction.edge_async;
import static com.alibaba.cloud.ai.graph.action.AsyncNodeAction.node_async;
/**
* 定义带中断的 Graph
* 使用 InterruptableAction 实现中断,不需要 interruptBefore 配置
*/
public static CompiledGraph createGraphWithInterrupt() throws GraphStateException {
// 定义普通节点
var step1 = node_async(state -> {
return Map.of("messages", "Step 1");
});
// 定义可中断节点(实现 InterruptableAction)
var humanFeedback = new InterruptableNodeAction("human_feedback", "等待用户输入");
var step3 = node_async(state -> {
return Map.of("messages", "Step 3");
});
// 定义条件边:根据 human_feedback 的值决定路由
var evalHumanFeedback = edge_async(state -> {
var feedback = (String) state.value("human_feedback").orElse("unknown");
return (feedback.equals("next") || feedback.equals("back")) ? feedback : "unknown";
});
// 配置 KeyStrategyFactory
KeyStrategyFactory keyStrategyFactory = () -> {
HashMap<String, KeyStrategy> keyStrategyHashMap = new HashMap<>();
keyStrategyHashMap.put("messages", new AppendStrategy());
keyStrategyHashMap.put("human_feedback", new ReplaceStrategy());
return keyStrategyHashMap;
};
// 构建 Graph
StateGraph builder = new StateGraph(keyStrategyFactory)
.addNode("step_1", step1)
.addNode("human_feedback", humanFeedback) // 使用可中断节点
.addNode("step_3", step3)
.addEdge(START, "step_1")
.addEdge("step_1", "human_feedback")
.addConditionalEdges("human_feedback", evalHumanFeedback,
Map.of("back", "step_1", "next", "step_3", "unknown", "human_feedback"))
.addEdge("step_3", END);
// 配置内存保存器(用于状态持久化)
var saver = new MemorySaver();
var compileConfig = CompileConfig.builder()
.saverConfig(SaverConfig.builder()
.register(saver)
.build())
// 不再需要 interruptBefore 配置,中断由 InterruptableAction 控制
.build();
return builder.compile(compileConfig);
}
实现 InterruptableNodeAction
实现 InterruptableNodeAction查看完整代码
/**
* 可中断的节点动作
* 实现 InterruptableAction 接口,可以在任意节点中断执行
*/
public static class InterruptableNodeAction implements AsyncNodeActionWithConfig, InterruptableAction {
private final String nodeId;
private final String message;
public InterruptableNodeAction(String nodeId, String message) {
this.nodeId = nodeId;
this.message = message;
}
@Override
public CompletableFuture<Map<String, Object>> apply(OverAllState state, RunnableConfig config) {
// 正常节点逻辑:更新状态
return CompletableFuture.completedFuture(Map.of("messages", message));
}
@Override
public Optional<InterruptionMetadata> interrupt(String nodeId, OverAllState state, RunnableConfig config) {
// 检查是否需要中断
// 如果状态中没有 human_feedback,则中断等待用户输入
Optional<Object> humanFeedback = state.value("human_feedback");
if (humanFeedback.isEmpty()) {
// 返回 InterruptionMetadata 来中断执行
InterruptionMetadata interruption = InterruptionMetadata.builder(nodeId, state)
.addMetadata("message", "等待用户输入...")
.addMetadata("node", nodeId)
.build();
return Optional.of(interruption);
}
// 如果已经有 human_feedback,继续执行
return Optional.empty();
}
}
执行 Graph 直到中断
执行 Graph 直到中断查看完整代码
import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.NodeOutput;
import com.alibaba.cloud.ai.graph.RunnableConfig;
import com.alibaba.cloud.ai.graph.action.InterruptionMetadata;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import reactor.core.publisher.Flux;
/**
* 执行 Graph 直到中断
* 检查流式输出中的 InterruptionMetadata
*/
public static InterruptionMetadata executeUntilInterrupt(CompiledGraph graph) {
// 初始输入
Map<String, Object> initialInput = Map.of("messages", "Step 0");
// 配置线程 ID
var invokeConfig = RunnableConfig.builder()
.threadId("Thread1")
.build();
// 用于保存 最后一个输出
AtomicReference<NodeOutput> lastOutputRef = new AtomicReference<>();
// 运行 Graph 直到第一个中断点
graph.stream(initialInput, invokeConfig)
.doOnNext(event -> {
System.out.println("节点输出: " + event);
lastOutputRef.set(event);
})
.doOnError(error -> System.err.println("流错误: " + error.getMessage()))
.doOnComplete(() -> System.out.println("流完成"))
.blockLast();
// 检查最后一个输出是否是 InterruptionMetadata
NodeOutput lastOutput = lastOutputRef.get();
if (lastOutput instanceof InterruptionMetadata) {
System.out.println("
检测到中断: " + lastOutput);
return (InterruptionMetadata) lastOutput;
}
return null;
}
输出:
节点输出: NodeOutput{node=__START__, state={messages=[Step 0]}}
节点输出: NodeOutput{node=step_1, state={messages=[Step 0, Step 1]}}
检测到中断: InterruptionMetadata{node=human_feedback, state={messages=[Step 0, Step 1]}, metadata={message=等待用户输入..., node=human_feedback}}
等待用户输入并更新状态
等待用户输入并更新状态查看完整代码
/**
* 等待用户输入并更新状态
*/
public static RunnableConfig waitUserInputAndUpdateState(CompiledGraph graph, InterruptionMetadata interruption) throws Exception {
var invokeConfig = RunnableConfig.builder()
.threadId("Thread1")
.build();
// 检查当前状态
System.out.printf("
--State before update--
%s
", graph.getState(invokeConfig));
// 模拟用户输入
var userInput = "back"; // "back" 表示返回上一个节点
System.out.printf("
--User Input--
用户选择: '%s'
", userInput);
// 更新状态:添加 human_feedback
// 使用 updateState 更新状态,传入中断时的节点 ID
var updatedConfig = graph.updateState(invokeConfig, Map.of("human_feedback", userInput), interruption.node());
// 检查更新后的状态
System.out.printf("--State after update--
%s
", graph.getState(updatedConfig));
return updatedConfig;
}
输出:
--State before update--
StateSnapshot{node=step_1, state={messages=[Step 0, Step 1]}, config=RunnableConfig{ threadId=Thread1, nextNode=human_feedback }}
--User Input--
用户选择: 'back'
--State after update--
StateSnapshot{node=step_1, state={messages=[Step 0, Step 1], human_feedback=back}, config=RunnableConfig{ threadId=Thread1, nextNode=human_feedback }}
继续执行 Graph
继续执行 Graph查看完整代码
/**
* 继续执行 Graph
* 使用 HUMAN_FEEDBACK_METADATA_KEY 来恢复执行
*/
public static void continueExecution(CompiledGraph graph, RunnableConfig updatedConfig) {
// 创建恢复配置,添加 HUMAN_FEEDBACK_METADATA_KEY
RunnableConfig resumeConfig = RunnableConfig.builder(updatedConfig)
.addMetadata(RunnableConfig.HUMAN_FEEDBACK_METADATA_KEY, "placeholder")
.build();
System.out.println("
--继续执行 Graph--");
// 继续执行 Graph(input 为 null,使用之前的状态)
graph.stream(null, resumeConfig)
.doOnNext(event -> System.out.println("节点输出: " + event))
.doOnError(error -> System.err.println("流错误: " + error.getMessage()))
.doOnComplete(() -> System.out.println("流完成"))
.blockLast();
}
输出:
--继续执行 Graph--
节点输出: NodeOutput{node=human_feedback, state={messages=[Step 0, Step 1], human_feedback=back}}
节点输出: NodeOutput{node=step_1, state={messages=[Step 0, Step 1], human_feedback=back}}
流完成
模式二:interruptBefore 模式
interruptBefore 模式需要在编译 Graph 时提前指定中断点,在指定节点执行前自动中断。这种方式适合已知的中断点,配置简单直接。
优势
- 配置简单:只需在编译配置中指定中断点
- 无需修改节点:普通节点即可,不需要实现特殊接口
- 明确的中断点:中断位置在编译时确定,易于理解和维护
定义带中断的 Graph
定义带中断的 Graph (interruptBefore 模式)查看完整代码
import com.alibaba.cloud.ai.graph.CompileConfig;
import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.KeyStrategy;
import com.alibaba.cloud.ai.graph.KeyStrategyFactory;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.checkpoint.config.SaverConfig;
import com.alibaba.cloud.ai.graph.checkpoint.savers.MemorySaver;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.state.strategy.AppendStrategy;
import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;
import java.util.HashMap;
import java.util.Map;
import static com.alibaba.cloud.ai.graph.StateGraph.END;
import static com.alibaba.cloud.ai.graph.StateGraph.START;
import static com.alibaba.cloud.ai.graph.action.AsyncEdgeAction.edge_async;
import static com.alibaba.cloud.ai.graph.action.AsyncNodeAction.node_async;
/**
* 定义带中断的 Graph
* 使用 interruptBefore 配置在指定节点前中断
*/
public static CompiledGraph createGraphWithInterrupt() throws GraphStateException {
// 定义节点
var step1 = node_async(state -> {
return Map.of("messages", "Step 1");
});
var humanFeedback = node_async(state -> {
return Map.of(); // 等待用户输入,不修改状态
});
var step3 = node_async(state -> {
return Map.of("messages", "Step 3");
});
// 定义条件边
var evalHumanFeedback = edge_async(state -> {
var feedback = (String) state.value("human_feedback").orElse("unknown");
return (feedback.equals("next") || feedback.equals("back")) ? feedback : "unknown";
});
// 配置 KeyStrategyFactory
KeyStrategyFactory keyStrategyFactory = () -> {
HashMap<String, KeyStrategy> keyStrategyHashMap = new HashMap<>();
keyStrategyHashMap.put("messages", new AppendStrategy());
keyStrategyHashMap.put("human_feedback", new ReplaceStrategy());
return keyStrategyHashMap;
};
// 构建 Graph
StateGraph builder = new StateGraph(keyStrategyFactory)
.addNode("step_1", step1)
.addNode("human_feedback", humanFeedback)
.addNode("step_3", step3)
.addEdge(START, "step_1")
.addEdge("step_1", "human_feedback")
.addConditionalEdges("human_feedback", evalHumanFeedback,
Map.of("back", "step_1", "next", "step_3", "unknown", "human_feedback"))
.addEdge("step_3", END);
// 配置内存保存器和中断点
var saver = new MemorySaver();
var compileConfig = CompileConfig.builder()
.saverConfig(SaverConfig.builder()
.register(saver)
.build())
.interruptBefore("human_feedback") // 在 human_feedback 节点前中 断
.build();
return builder.compile(compileConfig);
}
执行 Graph 直到中断
执行 Graph 直到中断 (interruptBefore 模式)查看完整代码
import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.RunnableConfig;
import java.util.Map;
import reactor.core.publisher.Flux;
/**
* 执行 Graph 直到中断
*/
public static void executeUntilInterrupt(CompiledGraph graph) {
// 初始输入
Map<String, Object> initialInput = Map.of("messages", "Step 0");
// 配置线程 ID
var invokeConfig = RunnableConfig.builder()
.threadId("Thread1")
.build();
// 运行 Graph 直到第一个中断点
graph.stream(initialInput, invokeConfig)
.doOnNext(event -> System.out.println(event))
.doOnError(error -> System.err.println("流错误: " + error.getMessage()))
.doOnComplete(() -> System.out.println("流完成"))
.blockLast();
}
输出:
NodeOutput{node=__START__, state={messages=[Step 0]}}
NodeOutput{node=step_1, state={messages=[Step 0, Step 1]}}
流完成
等待用户输入并更新状态
等待用户输入并更新状态 (interruptBefore 模式)查看完整代码
import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.RunnableConfig;
import java.util.Map;
/**
* 等待用户输入并更新状态
*/
public static RunnableConfig waitUserInputAndUpdateState(CompiledGraph graph) throws Exception {
var invokeConfig = RunnableConfig.builder()
.threadId("Thread1")
.build();
// 检查当前状态
System.out.printf("--State before update--
%s
", graph.getState(invokeConfig));
// 模拟用户输入
var userInput = "back"; // "back" 表示返回上一个节点
System.out.printf("
--User Input--
用户选择: '%s'
", userInput);
// 更新状态(模拟 human_feedback 节点的输出)
// 注意:interruptBefore 模式下,传入 null 作为节点 ID
var updateConfig = graph.updateState(invokeConfig, Map.of("human_feedback", userInput), null);
// 检查更新后的状态
System.out.printf("--State after update--
%s
", graph.getState(updateConfig));
return updateConfig;
}
输出:
--State before update--
StateSnapshot{node=step_1, state={messages=[Step 0, Step 1]}, config=RunnableConfig{ threadId=Thread1, nextNode=human_feedback }}
--User Input--
用户选择: 'back'
--State after update--
StateSnapshot{node=step_1, state={messages=[Step 0, Step 1], human_feedback=back}, config=RunnableConfig{ threadId=Thread1, nextNode=human_feedback }}
继续执行 Graph
继续执行 Graph (interruptBefore 模式)查看完整代码
/**
* 继续执行 Graph
*/
public static void continueExecution(CompiledGraph graph, RunnableConfig updateConfig) {
// 继续执行 Graph(input 为 null,使用之前的状态)
graph.stream(null, updateConfig)
.doOnNext(event -> System.out.println(event))
.doOnError(error -> System.err.println("流错误: " + error.getMessage()))
.doOnComplete(() -> System.out.println("流完成"))
.blockLast();
}
输出:
NodeOutput{node=human_feedback, state={messages=[Step 0, Step 1], human_feedback=back}}
NodeOutput{node=step_1, state={messages=[Step 0, Step 1], human_feedback=back}}
流完成