快速入门
框架代码地址:https://github.com/alibaba/spring-ai-alibaba/tree/main/spring-ai-alibaba-graph
以下是一个最简单的 graph 示例:将用户的问题扩展为若干条相似的问题
实战代码可见:spring-ai-alibaba-examples 下的 graph 目录,本章代码为其 simple 模块
pom.xml
这里使用 1.0.0.3-SNAPSHOT。在定义 StateGraph 方面和 1.0.0.2 有些变动
<!-- Version property referenced by the spring-ai-alibaba-graph-core dependency below. -->
<properties>
    <spring-ai-alibaba.version>1.0.0.3-SNAPSHOT</spring-ai-alibaba.version>
</properties>

<!-- NOTE(review): the org.springframework.* artifacts declare no <version>;
     presumably a parent POM / BOM outside this snippet manages them. Confirm. -->
<dependencies>
    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-autoconfigure-model-openai</artifactId>
    </dependency>

    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-autoconfigure-model-chat-client</artifactId>
    </dependency>

    <dependency>
        <groupId>com.alibaba.cloud.ai</groupId>
        <artifactId>spring-ai-alibaba-graph-core</artifactId>
        <version>${spring-ai-alibaba.version}</version>
    </dependency>

    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
</dependencies>
application.yml
# HTTP port the application listens on.
server:
  port: 8080

spring:
  application:
    name: simple
  ai:
    openai:
      # Placeholder: the value is taken from the AIDASHSCOPEAPIKEY property / environment variable.
      api-key: ${AIDASHSCOPEAPIKEY}
      base-url: https://dashscope.aliyuncs.com/compatible-mode
      chat:
        options:
          model: qwen-max
config
OverAllState 中存储的字段
- query:用户的问题
- expandernumber:扩展的数量
- expandercontent:扩展的内容
定义 ExpanderNode,边的连接为:START -> expander -> END
package com.spring.ai.tutorial.graph.config;
import com.alibaba.cloud.ai.graph.GraphRepresentation;
import com.alibaba.cloud.ai.graph.KeyStrategy;
import com.alibaba.cloud.ai.graph.KeyStrategyFactory;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;
import com.spring.ai.tutorial.graph.node.ExpanderNode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import java.util.HashMap;

import static com.alibaba.cloud.ai.graph.action.AsyncNodeAction.node_async;
@Configurationpublic class GraphConfiguration {
private static final Logger logger = LoggerFactory.getLogger(GraphConfiguration.class);
@Bean public StateGraph simpleGraph(ChatClient.Builder chatClientBuilder) throws GraphStateException { KeyStrategyFactory keyStrategyFactory = () -> { HashMap<String, KeyStrategy> keyStrategyHashMap = new HashMap<>();
// 用户输入 keyStrategyHashMap.put("query", new ReplaceStrategy()); keyStrategyHashMap.put("expandernumber", new ReplaceStrategy()); keyStrategyHashMap.put("expandercontent", new ReplaceStrategy()); return keyStrategyHashMap; };
StateGraph stateGraph = new StateGraph(keyStrategyFactory) .addNode("expander", nodeasync(new ExpanderNode(chatClientBuilder))) .addEdge(StateGraph.START, "expander") .addEdge("expander", StateGraph.END);
// 添加 PlantUML 打印 GraphRepresentation representation = stateGraph.getGraph(GraphRepresentation.Type.PLANTUML, "expander flow"); logger.info("\n=== expander UML Flow ==="); logger.info(representation.content()); logger.info("==================================\n");
return stateGraph; }}
node
ExpanderNode
- PromptTemplate DEFAULTPROMPTTEMPLATE:扩展文本的提示词
- ChatClient chatClient:调用 AI 模型的 client 端
- Integer NUMBER:默认扩展为 3 条相似问题
最后将 AI 模型的响应内容写入字段 expandercontent 中
package com.spring.ai.tutorial.graph.node;
import com.alibaba.cloud.ai.graph.NodeOutput;import com.alibaba.cloud.ai.graph.OverAllState;import com.alibaba.cloud.ai.graph.action.NodeAction;import com.alibaba.cloud.ai.graph.streaming.StreamingChatGenerator;import org.bsc.async.AsyncGenerator;import org.springframework.ai.chat.client.ChatClient;import org.springframework.ai.chat.model.ChatResponse;import org.springframework.ai.chat.prompt.PromptTemplate;import reactor.core.publisher.Flux;
import java.util.Arrays;import java.util.HashMap;import java.util.List;import java.util.Map;
/**
 * Graph node that asks the chat model for several rephrasings of the user's query.
 *
 * <p>Reads {@code "query"} and {@code "expandernumber"} from the state and returns a map
 * whose {@code "expandercontent"} entry holds the generated variants.
 */
public class ExpanderNode implements NodeAction {

    /** Prompt sent as the user message; {@code {number}} and {@code {query}} are filled per call. */
    private static final PromptTemplate DEFAULTPROMPTTEMPLATE = new PromptTemplate(
            "You are an expert at information retrieval and search optimization.\n"
                    + "Your task is to generate {number} different versions of the given query.\n"
                    + "\n"
                    + "Each variant must cover different perspectives or aspects of the topic,\n"
                    + "while maintaining the core intent of the original query. The goal is to\n"
                    + "expand the search space and improve the chances of finding relevant information.\n"
                    + "\n"
                    + "Do not explain your choices or add any other text.\n"
                    + "Provide the query variants separated by newlines.\n"
                    + "\n"
                    + "Original query: {query}\n"
                    + "\n"
                    + "Query variants:\n");

    /** Fallback passed to {@code state.value} for the {@code "expandernumber"} key. */
    private static final Integer NUMBER = 3;

    private final ChatClient chatClient;

    /**
     * @param chatClientBuilder builder used once, here, to create the client this node calls
     */
    public ExpanderNode(ChatClient.Builder chatClientBuilder) {
        this.chatClient = chatClientBuilder.build();
    }

    /**
     * Calls the model with the prompt template and splits its answer into one variant per line.
     *
     * @param state graph state supplying {@code "query"} and {@code "expandernumber"}
     * @return a map with the single key {@code "expandercontent"} holding the non-blank,
     *         whitespace-stripped lines of the model's answer
     * @throws Exception declared by {@link NodeAction#apply}
     */
    @Override
    public Map<String, Object> apply(OverAllState state) throws Exception {
        String query = state.value("query", "");
        Integer expanderNumber = state.value("expandernumber", NUMBER);

        Flux<String> streamResult = this.chatClient.prompt()
                .user(user -> user.text(DEFAULTPROMPTTEMPLATE.getTemplate())
                        .param("number", expanderNumber)
                        .param("query", query))
                .stream()
                .content();

        // Gather all chunks and join once, rather than re-copying the accumulated text per chunk.
        String result = streamResult
                .collectList()
                .map(chunks -> String.join("", chunks))
                .block();

        // Models often emit blank separator lines or trailing whitespace/CR; drop those
        // so every entry is an actual query variant.
        List<String> queryVariants = Arrays.stream(result.split("\n"))
                .map(String::strip)
                .filter(line -> !line.isEmpty())
                .toList();

        HashMap<String, Object> resultMap = new HashMap<>();
        resultMap.put("expandercontent", queryVariants);
        return resultMap;
    }
}
controller
package com.spring.ai.tutorial.graph.controller;
import com.alibaba.cloud.ai.graph.CompiledGraph;import com.alibaba.cloud.ai.graph.OverAllState;import com.alibaba.cloud.ai.graph.RunnableConfig;import com.alibaba.cloud.ai.graph.StateGraph;import com.alibaba.cloud.ai.graph.exception.GraphStateException;import org.slf4j.Logger;import org.slf4j.LoggerFactory;import org.springframework.beans.factory.annotation.Qualifier;import org.springframework.web.bind.annotation.GetMapping;import org.springframework.web.bind.annotation.RequestMapping;import org.springframework.web.bind.annotation.RequestParam;import org.springframework.web.bind.annotation.RestController;
import java.util.HashMap;import java.util.Map;import java.util.Optional;
/**
 * REST controller exposing the compiled "simpleGraph" under {@code /graph}.
 */
@RestController
@RequestMapping("/graph")
public class SimpleGraphController {

    private static final Logger logger = LoggerFactory.getLogger(SimpleGraphController.class);

    private final CompiledGraph compiledGraph;

    /**
     * Compiles the injected graph once, at construction time.
     *
     * @param stateGraph the bean qualified as {@code "simpleGraph"}
     * @throws GraphStateException propagated from {@code stateGraph.compile()}
     */
    public SimpleGraphController(@Qualifier("simpleGraph") StateGraph stateGraph) throws GraphStateException {
        this.compiledGraph = stateGraph.compile();
    }

    /**
     * Runs the graph for one query and returns the resulting state data.
     *
     * @param query          the question to expand; has a built-in default
     * @param expanderNumber value stored under the {@code "expandernumber"} state key; defaults to 3
     * @param threadId       value given to the run configuration's {@code threadId}; defaults to "yingzi"
     * @return the data of the state returned by the graph, or an empty map when no state is returned
     */
    @GetMapping(value = "/expand")
    public Map<String, Object> expand(
            @RequestParam(value = "query", defaultValue = "你好,很高兴认识你,能简单介绍一下自己吗?", required = false) String query,
            @RequestParam(value = "expandernumber", defaultValue = "3", required = false) Integer expanderNumber,
            @RequestParam(value = "threadid", defaultValue = "yingzi", required = false) String threadId) {

        // Initial state handed to the graph.
        Map<String, Object> inputs = new HashMap<>();
        inputs.put("query", query);
        inputs.put("expandernumber", expanderNumber);

        RunnableConfig config = RunnableConfig.builder()
                .threadId(threadId)
                .build();

        Optional<OverAllState> finalState = this.compiledGraph.invoke(inputs, config);
        return finalState
                .map(OverAllState::data)
                .orElseGet(HashMap::new);
    }
}