Node Streaming Output
A Graph connects multiple nodes into an orchestrated workflow. When a node calls an AI model, that node needs to stream the model's response to the frontend as it is generated.
The complete code is in the spring-ai-alibaba-examples repository, under the graph directory. This chapter's code is in the stream-node module.
pom.xml
This chapter uses version 1.0.0.3-SNAPSHOT. The way a StateGraph is defined has changed somewhat compared to 1.0.0.2.
<properties> <spring-ai-alibaba.version>1.0.0.3-SNAPSHOT</spring-ai-alibaba.version></properties>
<dependencies>
<dependency> <groupId>org.springframework.ai</groupId> <artifactId>spring-ai-autoconfigure-model-openai</artifactId> </dependency>
<dependency> <groupId>org.springframework.ai</groupId> <artifactId>spring-ai-autoconfigure-model-chat-client</artifactId> </dependency>
<dependency> <groupId>com.alibaba.cloud.ai</groupId> <artifactId>spring-ai-alibaba-graph-core</artifactId> <version>${spring-ai-alibaba.version}</version> </dependency>
<dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter-web</artifactId> </dependency></dependencies>
application.yml
server:
  port: 8080
spring:
  application:
    name: simple
  ai:
    openai:
      api-key: ${AI_DASHSCOPE_API_KEY}
      base-url: https://dashscope.aliyuncs.com/compatible-mode
      chat:
        options:
          model: qwen-max
config
Fields stored in OverAllState:
- query: user’s question
- expandernumber: number of expansions
- expandercontent: expansion content
Define ExpanderNode, with edges connecting: START -> expander -> END
package com.spring.ai.tutorial.graph.stream.config;
import com.alibaba.cloud.ai.graph.GraphRepresentation;
import com.alibaba.cloud.ai.graph.KeyStrategy;
import com.alibaba.cloud.ai.graph.KeyStrategyFactory;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;
import com.spring.ai.tutorial.graph.stream.node.ExpanderNode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import java.util.HashMap;

import static com.alibaba.cloud.ai.graph.action.AsyncNodeAction.node_async;
@Configurationpublic class GraphNodeStreamConfiguration {
private static final Logger logger = LoggerFactory.getLogger(GraphNodeStreamConfiguration.class);
@Bean public StateGraph streamGraph(ChatClient.Builder chatClientBuilder) throws GraphStateException { KeyStrategyFactory keyStrategyFactory = () -> { HashMap<String, KeyStrategy> keyStrategyHashMap = new HashMap<>();
// User input keyStrategyHashMap.put("query", new ReplaceStrategy()); keyStrategyHashMap.put("expandernumber", new ReplaceStrategy()); keyStrategyHashMap.put("expandercontent", new ReplaceStrategy()); return keyStrategyHashMap; };
StateGraph stateGraph = new StateGraph(keyStrategyFactory) .addNode("expander", nodeasync(new ExpanderNode(chatClientBuilder))) .addEdge(StateGraph.START, "expander") .addEdge("expander", StateGraph.END);
// Add PlantUML printing GraphRepresentation representation = stateGraph.getGraph(GraphRepresentation.Type.PLANTUML, "expander flow"); logger.info("\n=== expander UML Flow ==="); logger.info(representation.content()); logger.info("==================================\n");
return stateGraph; }
}
node
ExpanderNode
package com.spring.ai.tutorial.graph.stream.node;
import com.alibaba.cloud.ai.graph.NodeOutput;import com.alibaba.cloud.ai.graph.OverAllState;import com.alibaba.cloud.ai.graph.action.NodeAction;import com.alibaba.cloud.ai.graph.streaming.StreamingChatGenerator;import org.bsc.async.AsyncGenerator;import org.springframework.ai.chat.client.ChatClient;import org.springframework.ai.chat.model.ChatResponse;import org.springframework.ai.chat.prompt.PromptTemplate;import reactor.core.publisher.Flux;
import java.util.Arrays;import java.util.List;import java.util.Map;
/**
 * Graph node that asks the chat model for several rewrites of the user's query and returns
 * the model's streaming response as the node's {@code expandercontent} value.
 */
public class ExpanderNode implements NodeAction {

    private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("""
            You are an expert at information retrieval and search optimization.
            Your task is to generate {number} different versions of the given query.

            Each variant must cover different perspectives or aspects of the topic,
            while maintaining the core intent of the original query. The goal is to
            expand the search space and improve the chances of finding relevant information.

            Do not explain your choices or add any other text.
            Provide the query variants separated by newlines.

            Original query: {query}

            Query variants:
            """);

    /** Fallback passed to {@code state.value} for the {@code expandernumber} key. */
    private static final Integer DEFAULT_NUMBER = 3;

    private final ChatClient chatClient;

    /**
     * @param chatClientBuilder builder used to create this node's chat client
     */
    public ExpanderNode(ChatClient.Builder chatClientBuilder) {
        this.chatClient = chatClientBuilder.build();
    }

    /**
     * Reads {@code query} and {@code expandernumber} from the state, starts a streaming chat
     * call with the expansion prompt, and wraps the response stream in a generator.
     *
     * @param state the graph state supplying {@code query} and {@code expandernumber}
     * @return a single-entry map whose {@code expandercontent} value is the streaming generator
     * @throws Exception declared by {@link NodeAction#apply}
     */
    @Override
    public Map<String, Object> apply(OverAllState state) throws Exception {
        String query = state.value("query", "");
        Integer expanderNumber = state.value("expandernumber", DEFAULT_NUMBER);

        Flux<ChatResponse> chatResponseFlux = this.chatClient.prompt()
                .user(user -> user.text(DEFAULT_PROMPT_TEMPLATE.getTemplate())
                        .param("number", expanderNumber)
                        .param("query", query))
                .stream()
                .chatResponse();

        AsyncGenerator<? extends NodeOutput> generator = StreamingChatGenerator.builder()
                .startingNode("expanderllmstream")
                .startingState(state)
                .mapResult(response -> Map.of("expandercontent", splitVariants(response)))
                .build(chatResponseFlux);

        return Map.of("expandercontent", generator);
    }

    /** Splits the response text into one entry per line; yields an empty list when there is no text. */
    private static List<String> splitVariants(ChatResponse response) {
        String text = response.getResult().getOutput().getText();
        if (text == null) {
            // NOTE(review): assumes the text can be absent for some responses; without
            // this guard text.split would throw a NullPointerException.
            return List.of();
        }
        return Arrays.asList(text.split("\n"));
    }
}
controller
GraphStreamController
- Sinks.Many<ServerSentEvent<String>> sink: receives the streamed data
package com.spring.ai.tutorial.graph.stream.controller;
import com.alibaba.cloud.ai.graph.CompiledGraph;import com.alibaba.cloud.ai.graph.NodeOutput;import com.alibaba.cloud.ai.graph.RunnableConfig;import com.alibaba.cloud.ai.graph.StateGraph;import com.alibaba.cloud.ai.graph.exception.GraphStateException;import com.spring.ai.tutorial.graph.stream.controller.GraphProcess.GraphProcess;import org.bsc.async.AsyncGenerator;import org.slf4j.Logger;import org.slf4j.LoggerFactory;import org.springframework.beans.factory.annotation.Qualifier;import org.springframework.http.MediaType;import org.springframework.http.codec.ServerSentEvent;import org.springframework.web.bind.annotation.GetMapping;import org.springframework.web.bind.annotation.RequestMapping;import org.springframework.web.bind.annotation.RequestParam;import org.springframework.web.bind.annotation.RestController;import reactor.core.publisher.Flux;import reactor.core.publisher.Sinks;
import java.util.HashMap;import java.util.Map;
@RestController@RequestMapping("/graph/stream")public class GraphStreamController {
private static final Logger logger = LoggerFactory.getLogger(GraphStreamController.class);
private final CompiledGraph compiledGraph;
public GraphStreamController(@Qualifier("streamGraph")StateGraph stateGraph) throws GraphStateException { this.compiledGraph = stateGraph.compile(); }
@GetMapping(value = "/expand", produces = MediaType.TEXTEVENTSTREAMVALUE) public Flux<ServerSentEvent<String>> expand(@RequestParam(value = "query", defaultValue = "你好,很高兴认识你,能简单介绍一下自己吗?", required = false) String query, @RequestParam(value = "expandernumber", defaultValue = "3", required = false) Integer expanderNumber, @RequestParam(value = "threadid", defaultValue = "yingzi", required = false) String threadId){ RunnableConfig runnableConfig = RunnableConfig.builder().threadId(threadId).build(); Map<String, Object> objectMap = new HashMap<>(); objectMap.put("query", query); objectMap.put("expandernumber", expanderNumber);
GraphProcess graphProcess = new GraphProcess(this.compiledGraph); Sinks.Many<ServerSentEvent<String>> sink = Sinks.many().unicast().onBackpressureBuffer(); AsyncGenerator<NodeOutput> resultFuture = compiledGraph.stream(objectMap, runnableConfig); graphProcess.processStream(resultFuture, sink);
return sink.asFlux() .doOnCancel(() -> logger.info("Client disconnected from stream")) .doOnError(e -> logger.error("Error occurred during streaming", e)); }
}
GraphProcess
- ExecutorService executor: the thread pool on which the graph's output stream is consumed
Each result read from the stream is written to the sink.
package com.spring.ai.tutorial.graph.stream.controller.GraphProcess;
import com.alibaba.cloud.ai.graph.CompiledGraph;import com.alibaba.cloud.ai.graph.NodeOutput;import com.alibaba.cloud.ai.graph.streaming.StreamingOutput;import com.alibaba.fastjson.JSON;import com.alibaba.fastjson.JSONObject;import org.bsc.async.AsyncGenerator;import org.slf4j.Logger;import org.slf4j.LoggerFactory;import org.springframework.http.codec.ServerSentEvent;import reactor.core.publisher.Sinks;
import java.util.Map;import java.util.concurrent.CompletionException;import java.util.concurrent.ExecutorService;import java.util.concurrent.Executors;
public class GraphProcess {
private static final Logger logger = LoggerFactory.getLogger(GraphProcess.class);
private final ExecutorService executor = Executors.newSingleThreadExecutor();
private CompiledGraph compiledGraph;
public GraphProcess(CompiledGraph compiledGraph) { this.compiledGraph = compiledGraph; }
public void processStream(AsyncGenerator<NodeOutput> generator, Sinks.Many<ServerSentEvent<String>> sink) { executor.submit(() -> { generator.forEachAsync(output -> { try { logger.info("output = {}", output); String nodeName = output.node(); String content; if (output instanceof StreamingOutput streamingOutput) { content = JSON.toJSONString(Map.of(nodeName, streamingOutput.chunk())); } else { JSONObject nodeOutput = new JSONObject(); nodeOutput.put("data", output.state().data()); nodeOutput.put("node", nodeName); content = JSON.toJSONString(nodeOutput); } sink.tryEmitNext(ServerSentEvent.builder(content).build()); } catch (Exception e) { throw new CompletionException(e); } }).thenAccept(v -> { // Normal completion sink.tryEmitComplete(); }).exceptionally(e -> { sink.tryEmitError(e); return null; }); }); }}