跳到主要内容

分配MCP工具给指定节点

将指定 MCP 服务提供的工具分配给指定的图节点（node）

实战代码见 spring-ai-alibaba-examples 仓库下的 graph 目录，本章代码对应其中的 mcp-node 模块。

pom.xml

这里使用 1.0.0.3-SNAPSHOT 版本，它在定义 StateGraph 的方式上与 1.0.0.2 有些变动。

<properties>
    <spring-ai-alibaba.version>1.0.0.3-SNAPSHOT</spring-ai-alibaba.version>
</properties>

<dependencies>

    <!-- OpenAI-compatible chat model; application.yml points base-url at DashScope's compatible mode -->
    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-autoconfigure-model-openai</artifactId>
    </dependency>

    <!-- Auto-configuration for the ChatClient.Builder injected into the graph configuration -->
    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-autoconfigure-model-chat-client</artifactId>
    </dependency>

    <!-- MCP client: binds spring.ai.mcp.client.* (SSE connections) and provides McpToolUtils,
         McpClientCommonProperties and the MCP ToolCallbackProvider used by
         McpClientToolCallbackProvider. The original listing omitted it.
         NOTE(review): version assumed to be managed by the same Spring AI BOM as the other
         org.springframework.ai artifacts here; drop this entry if it is already inherited. -->
    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-starter-mcp-client-webflux</artifactId>
    </dependency>

    <!-- Spring AI Alibaba Graph: StateGraph, node/edge actions, compiled graphs -->
    <dependency>
        <groupId>com.alibaba.cloud.ai</groupId>
        <artifactId>spring-ai-alibaba-graph-core</artifactId>
        <version>${spring-ai-alibaba.version}</version>
    </dependency>

    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
</dependencies>

application.yml

注意 spring.ai.graph.nodes 下的配置：node2servers 定义了图节点（node）到 MCP 服务连接的映射，例如下面的配置让 mcp-node 节点只使用 server1 提供的工具。

server:
  port: 8080

spring:
  application:
    name: mcp-node
  ai:
    openai:
      # DashScope is called through its OpenAI-compatible mode (see base-url).
      api-key: ${AI_DASHSCOPE_API_KEY}
      base-url: https://dashscope.aliyuncs.com/compatible-mode
      chat:
        options:
          model: qwen-max
    mcp:
      client:
        enabled: true
        name: my-mcp-client
        version: 1.0.0
        request-timeout: 30s
        type: ASYNC # SYNC or ASYNC; ASYNC suits reactive applications
        sse:
          connections:
            # Connection names (e.g. server1) are what spring.ai.graph.nodes.node2servers refers to.
            server1:
              url: http://localhost:19000

    # Custom mapping bound by McpNodeProperties (prefix spring.ai.graph.nodes):
    # graph node name -> MCP connection names whose tools that node may use.
    graph:
      nodes:
        node2servers:
          mcp-node:
            - server1

config

McpNodeProperties

承载“图节点（node）→ MCP 服务”映射配置的属性类

package com.spring.ai.tutorial.graph.mcp.config;

import org.springframework.boot.context.properties.ConfigurationProperties;

import java.util.Map;
import java.util.Set;

/**
 * Binds {@code spring.ai.graph.nodes.*}: maps each graph node name to the set of MCP
 * connection names (the keys under {@code spring.ai.mcp.client.sse.connections}) whose tools
 * that node may use.
 */
@ConfigurationProperties(prefix = McpNodeProperties.PREFIX)
public class McpNodeProperties {

    /** Configuration prefix this class binds to. */
    public static final String PREFIX = "spring.ai.graph.nodes";

    /**
     * Node name -> MCP connection names. Defaults to an empty map so that consumers never see
     * {@code null} when the property is absent.
     */
    private Map<String, Set<String>> node2servers = Map.of();

    /**
     * Returns the node-to-MCP-connection mapping.
     *
     * @return the mapping; an empty map when nothing is configured (never {@code null})
     */
    public Map<String, Set<String>> getNode2servers() {
        return node2servers;
    }

    /**
     * Sets the node-to-MCP-connection mapping.
     *
     * @param node2servers node name -> MCP connection names; {@code null} is treated as empty
     */
    public void setNode2servers(Map<String, Set<String>> node2servers) {
        this.node2servers = node2servers == null ? Map.of() : node2servers;
    }
}

McpGaphConfiguration

注入 McpClientToolCallbackProvider，并将其提供给 McpNode

package com.spring.ai.tutorial.graph.mcp.config;

import com.alibaba.cloud.ai.graph.GraphRepresentation;
import com.alibaba.cloud.ai.graph.KeyStrategy;
import com.alibaba.cloud.ai.graph.KeyStrategyFactory;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;
import com.spring.ai.tutorial.graph.mcp.node.McpNode;
import com.spring.ai.tutorial.graph.mcp.tool.McpClientToolCallbackProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import java.util.HashMap;

import static com.alibaba.cloud.ai.graph.action.AsyncNodeAction.node_async;

/**
 * Defines the MCP demo graph: START -> "mcp" -> END, where "mcp" is an {@link McpNode} whose
 * ChatClient only sees the MCP tools mapped to it through {@link McpNodeProperties}.
 */
@Configuration
@EnableConfigurationProperties({ McpNodeProperties.class })
public class McpGaphConfiguration {

    private static final Logger logger = LoggerFactory.getLogger(McpGaphConfiguration.class);

    @Autowired
    private McpClientToolCallbackProvider mcpClientToolCallbackProvider;

    /**
     * Builds the (uncompiled) MCP state graph and logs its PlantUML representation.
     *
     * @param chatClientBuilder builder handed to {@link McpNode} to create its ChatClient
     * @return the graph definition; it is compiled by its consumer (e.g. McpController)
     * @throws GraphStateException if the graph definition is invalid
     */
    @Bean
    public StateGraph mcpGraph(ChatClient.Builder chatClientBuilder) throws GraphStateException {
        // Each state key uses ReplaceStrategy: a newly written value replaces the previous one.
        KeyStrategyFactory keyStrategyFactory = () -> {
            HashMap<String, KeyStrategy> keyStrategyHashMap = new HashMap<>();
            // user input
            keyStrategyHashMap.put("query", new ReplaceStrategy());
            // answer written by McpNode
            keyStrategyHashMap.put("mcpcontent", new ReplaceStrategy());
            return keyStrategyHashMap;
        };

        // node_async wraps the synchronous McpNode so it can be registered as a graph node.
        StateGraph stateGraph = new StateGraph(keyStrategyFactory)
                .addNode("mcp", node_async(new McpNode(chatClientBuilder, mcpClientToolCallbackProvider)))
                .addEdge(StateGraph.START, "mcp")
                .addEdge("mcp", StateGraph.END);

        // Print the graph as PlantUML for debugging.
        GraphRepresentation representation = stateGraph.getGraph(GraphRepresentation.Type.PLANTUML,
                "mcp flow");
        logger.info("\n=== mcp UML Flow ===");
        logger.info(representation.content());
        logger.info("==================================\n");

        return stateGraph;
    }
}

tool

McpClientToolCallbackProvider

根据节点名称，匹配该节点所映射的 MCP 服务提供的 ToolCallback

package com.spring.ai.tutorial.graph.mcp.tool;

import com.spring.ai.tutorial.graph.mcp.config.McpNodeProperties;
import org.apache.commons.compress.utils.Lists;
import org.glassfish.jersey.internal.guava.Sets;
import org.springframework.ai.mcp.McpToolUtils;
import org.springframework.ai.mcp.client.autoconfigure.properties.McpClientCommonProperties;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.ToolCallbackProvider;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.stereotype.Service;

import java.util.List;
import java.util.Set;

@Service
public class McpClientToolCallbackProvider {

private final ToolCallbackProvider toolCallbackProvider;

private final McpClientCommonProperties commonProperties;

private final McpNodeProperties mcpNodeProperties;

public McpClientToolCallbackProvider(ToolCallbackProvider toolCallbackProvider,
McpClientCommonProperties commonProperties, McpNodeProperties mcpNodeProperties) {
this.toolCallbackProvider = toolCallbackProvider;
this.commonProperties = commonProperties;
this.mcpNodeProperties = mcpNodeProperties;
}

public Set<ToolCallback> findToolCallbacks(String nodeName) {
Set<ToolCallback> defineCallback = Sets.newHashSet();
Set<String> mcpClients = mcpNodeProperties.getNode2servers().get(nodeName);
if (mcpClients == null || mcpClients.isEmpty()) {
return defineCallback;
}

List<String> exceptMcpClientNames = Lists.newArrayList();
for (String mcpClient : mcpClients) {
// my-mcp-client
String name = commonProperties.getName();
// mymcpclientserver1
String prefixedMcpClientName = McpToolUtils.prefixedToolName(name, mcpClient);
exceptMcpClientNames.add(prefixedMcpClientName);
}

ToolCallback[] toolCallbacks = toolCallbackProvider.getToolCallbacks();
for (ToolCallback toolCallback : toolCallbacks) {
ToolDefinition toolDefinition = toolCallback.getToolDefinition();
// mymcpclientserver1getCityTimeMethod
String name = toolDefinition.name();
for (String exceptMcpClientName : exceptMcpClientNames) {
if (name.startsWith(exceptMcpClientName)) {
defineCallback.add(toolCallback);
}
}
}
return defineCallback;
}
}

node

McpNode

通过 McpClientToolCallbackProvider 找到当前节点的 ToolCallback

package com.spring.ai.tutorial.graph.mcp.node;

import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.spring.ai.tutorial.graph.mcp.tool.McpClientToolCallbackProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.tool.ToolCallback;
import reactor.core.publisher.Flux;

import java.util.HashMap;
import java.util.Map;
import java.util.Set;

/**
 * Graph node that answers the {@code query} state value with a ChatClient whose default tools
 * are the MCP tools mapped to {@code "mcp-node"}, and writes the answer to {@code mcpcontent}.
 */
public class McpNode implements NodeAction {

    private static final Logger logger = LoggerFactory.getLogger(McpNode.class);

    /** Key looked up in {@code spring.ai.graph.nodes.node2servers} to select this node's tools. */
    private static final String NODE_NAME = "mcp-node";

    private final ChatClient chatClient;

    /**
     * Resolves this node's MCP tools and builds a ChatClient that uses them by default.
     *
     * @param chatClientBuilder builder the ChatClient is created from
     * @param mcpClientToolCallbackProvider source of the MCP tool callbacks for {@code "mcp-node"}
     */
    public McpNode(ChatClient.Builder chatClientBuilder, McpClientToolCallbackProvider mcpClientToolCallbackProvider) {
        Set<ToolCallback> toolCallbacks = mcpClientToolCallbackProvider.findToolCallbacks(NODE_NAME);
        if (toolCallbacks.isEmpty()) {
            // A missing or misspelled node2servers mapping otherwise fails silently.
            logger.warn("Mcp Node found no ToolCallback for node '{}'", NODE_NAME);
        }
        for (ToolCallback toolCallback : toolCallbacks) {
            logger.info("Mcp Node load ToolCallback: {}", toolCallback.getToolDefinition().name());
        }

        this.chatClient = chatClientBuilder
                .defaultToolCallbacks(toolCallbacks.toArray(ToolCallback[]::new))
                .build();
    }

    /**
     * Streams the model's answer to {@code query} (default {@code ""}) and blocks until it is
     * complete.
     *
     * @param state graph state; reads {@code query}
     * @return a map holding the full answer text under {@code mcpcontent}
     */
    @Override
    public Map<String, Object> apply(OverAllState state) {
        String query = state.value("query", "");
        Flux<String> streamResult = chatClient.prompt(query).stream().content();
        // Join all chunks once; reducing with '+' re-copies the accumulated text per chunk.
        String result = streamResult.collectList()
                .map(chunks -> String.join("", chunks))
                .block();

        HashMap<String, Object> resultMap = new HashMap<>();
        resultMap.put("mcpcontent", result);
        return resultMap;
    }
}

controller

McpController

package com.spring.ai.tutorial.graph.mcp.controller;

import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.RunnableConfig;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

/**
 * HTTP entry point for the MCP graph: {@code GET /graph/mcp/call} runs the graph once for the
 * given query and returns the final state.
 */
@RestController
@RequestMapping("/graph/mcp")
public class McpController {

    private static final Logger logger = LoggerFactory.getLogger(McpController.class);

    private final CompiledGraph compiledGraph;

    /**
     * Compiles the {@code mcpGraph} bean once when the controller is created.
     *
     * @param stateGraph the graph definition to compile
     * @throws GraphStateException if compilation fails
     */
    public McpController(@Qualifier("mcpGraph") StateGraph stateGraph) throws GraphStateException {
        compiledGraph = stateGraph.compile();
    }

    /**
     * Runs the graph with {@code query} as its input.
     *
     * @param query user question placed under the {@code query} state key
     * @param threadId thread id passed to the graph run configuration
     * @return the final state data, or an empty map when the run yields no state
     */
    @GetMapping("/call")
    public Map<String, Object> call(
            @RequestParam(value = "query", defaultValue = "北京时间现在几点钟", required = false) String query,
            @RequestParam(value = "threadid", defaultValue = "yingzi", required = false) String threadId) {
        Map<String, Object> inputs = new HashMap<>();
        inputs.put("query", query);

        RunnableConfig config = RunnableConfig.builder()
                .threadId(threadId)
                .build();

        Optional<OverAllState> finalState = compiledGraph.invoke(inputs, config);
        return finalState
                .map(OverAllState::data)
                .orElseGet(HashMap::new);
    }

}

MCP Server 服务提供

需要提供一个 MCP Server 服务，远程或本地均可

这里在本地启动一个 MCP Server（地址 http://localhost:19000，即 application.yml 中 server1 的 url），提供时间查询服务

效果

启动本地 MCP Server,提供时间服务

调用 /graph/mcp/call 接口，触发本地 MCP Server 提供的时间服务

Spring AI Alibaba 开源项目基于 Spring AI 构建,是阿里云通义系列模型及服务在 Java AI 应用开发领域的最佳实践,提供高层次的 AI API 抽象与云原生基础设施集成方案,帮助开发者快速构建 AI 应用。