spring-ai-mcp-server-patterns by giuseppe-trisciuoglio/developer-kit
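npx skills add https://github.com/giuseppe-trisciuoglio/developer-kit --skill spring-ai-mcp-server-patterns
Implement Model Context Protocol (MCP) servers with Spring AI. The servers extend AI capabilities with standardized tools, resources, and prompt templates, built on Spring's native AI abstractions.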
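The Model Context Protocol (MCP) is a standardized protocol for connecting AI applications to external data sources and tools. Spring AI provides native support for building MCP servers. These servers expose Spring components to AI models as callable tools, resources, and prompt templates.
This skill covers implementation patterns for creating production-ready MCP servers with Spring AI, including:
- The @Tool annotation, which exposes Spring methods as AI-callable functions
- Reusable prompt templates declared with the @PromptTemplate annotation
Use this skill when building:
Follow these steps to implement an MCP server with Spring AI: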
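1. Add the Spring AI MCP dependencies to pom.xml or build.gradle
2. Configure the AI model (OpenAI, Anthropic, etc.) in application.properties
3. Enable the MCP server with the @EnableMcpServer annotation
4. Create tool classes as Spring components (@Component)
5. Annotate tool methods with @Tool(description = "...")
6. Document parameters with @ToolParam so the AI understands them
7. Create reusable prompts with @PromptTemplate
8. Define prompt parameters with @PromptParam
9. Return Prompt objects with system and user messages
10. Choose a transport: stdio, http, or sse
11. Configure the transport properties in application.yml
Create a simple MCP server with function calling: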
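@SpringBootApplication
@EnableMcpServer
public class WeatherMcpApplication {

    public static void main(String[] args) {
        SpringApplication.run(WeatherMcpApplication.class, args);
    }
}

import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.annotation.ToolParam;
import org.springframework.stereotype.Component;

@Component
public class WeatherTools {

    // @ToolParam takes its text through the description attribute
    @Tool(description = "Get the current weather for a city")
    public WeatherData getWeather(@ToolParam(description = "City name") String city) {
        // Implementation: call a real weather service here
        return new WeatherData(city, "Sunny", 22.5);
    }
}

// Result type for this example (field names are illustrative)
record WeatherData(String city, String condition, double temperatureCelsius) {}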
Configure function calling in application.properties:
spring.ai.openai.api-key=${OPENAI_API_KEY}
spring.ai.mcp.enabled=true
spring.ai.mcp.transport.type=stdio
Add the Spring AI MCP dependencies to your project:
Maven:
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-mcp-server</artifactId>
<version>1.0.0</version>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-starter-model-openai</artifactId>
<version>1.0.0</version>
</dependency>
Gradle:
dependencies {
implementation 'org.springframework.ai:spring-ai-mcp-server:1.0.0'
implementation 'org.springframework.ai:spring-ai-starter-model-openai:1.0.0'
}
Or use the Spring Boot starter:
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-starter-mcp-server</artifactId>
<version>1.0.0</version>
</dependency>
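MCP standardizes how AI applications connect to external capabilities, using Spring AI abstractions:
- Tools: executable functions annotated with @Tool
- Resources: data sources accessible through Spring components
- Prompts: template-based interactions defined with @PromptTemplate
- Transport: communication channels managed by Spring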
AI 应用程序 ←→ MCP 客户端 ←→ Spring AI ←→ MCP 服务器 ←→ Spring 服务
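Underneath the diagram, the MCP client and server exchange JSON-RPC 2.0 messages. Invoking a tool is a tools/call request whose params carry the tool name and its arguments. The sketch below only illustrates that message shape in plain Java. McpMessages is a hypothetical helper; in a real server the MCP SDK produces and parses these messages.

```java
import java.util.Map;
import java.util.stream.Collectors;

// Hypothetical helper that builds an MCP "tools/call" request as a JSON-RPC 2.0 string.
// Only string-valued arguments are supported, to keep the sketch dependency-free.
final class McpMessages {

    private McpMessages() {}

    static String toolsCall(int id, String toolName, Map<String, String> arguments) {
        String args = arguments.entrySet().stream()
                .map(e -> quote(e.getKey()) + ":" + quote(e.getValue()))
                .collect(Collectors.joining(","));
        return "{\"jsonrpc\":\"2.0\",\"id\":" + id
                + ",\"method\":\"tools/call\""
                + ",\"params\":{\"name\":" + quote(toolName)
                + ",\"arguments\":{" + args + "}}}";
    }

    // Minimal JSON string escaping for quotes, backslashes and common whitespace characters.
    static String quote(String s) {
        StringBuilder sb = new StringBuilder("\"");
        for (char c : s.toCharArray()) {
            switch (c) {
                case '"' -> sb.append("\\\"");
                case '\\' -> sb.append("\\\\");
                case '\n' -> sb.append("\\n");
                case '\r' -> sb.append("\\r");
                case '\t' -> sb.append("\\t");
                default -> sb.append(c);
            }
        }
        return sb.append('"').toString();
    }
}
```

A JSON library should replace the hand-rolled escaping for anything beyond illustration.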
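- @Tool: declares a method as a function the AI model can call
- @ToolParam: documents a parameter's purpose so the AI understands it
- @PromptTemplate: defines a reusable prompt pattern
- @Model: specifies the AI model configuration
Create tools with Spring AI's declarative approach: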
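import java.util.List;
import java.util.Map;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.annotation.ToolParam;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate;
import org.springframework.stereotype.Component;

@Component
public class DatabaseTools {

    // Named parameters let a Map of arguments be bound safely; JdbcTemplate's
    // varargs overload would bind the whole Map as one positional argument
    private final NamedParameterJdbcTemplate jdbcTemplate;

    public DatabaseTools(NamedParameterJdbcTemplate jdbcTemplate) {
        this.jdbcTemplate = jdbcTemplate;
    }

    @Tool(description = "Execute a safe read-only SQL query")
    public List<Map<String, Object>> executeQuery(
            @ToolParam(description = "SQL SELECT query using :name placeholders") String query,
            @ToolParam(description = "Named query parameters", required = false)
            Map<String, Object> params) {
        // Validate that the query is read-only (a coarse prefix check)
        if (!query.trim().toUpperCase().startsWith("SELECT")) {
            throw new IllegalArgumentException("Only SELECT queries are allowed");
        }
        Map<String, Object> bindings = params != null ? params : Map.of();
        return jdbcTemplate.queryForList(query, bindings);
    }

    @Tool(description = "Get table schema information")
    public TableSchema getTableSchema(
            @ToolParam(description = "Table name") String tableName) {
        String sql = "SELECT column_name, data_type " +
                "FROM information_schema.columns " +
                "WHERE table_name = :tableName";
        List<Map<String, Object>> columns =
                jdbcTemplate.queryForList(sql, Map.of("tableName", tableName));
        return new TableSchema(tableName, columns);
    }
}

record TableSchema(String tableName, List<Map<String, Object>> columns) {}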
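A prefix check such as startsWith("SELECT") has two weaknesses. It still accepts input like "SELECT 1; DROP TABLE users" on databases and drivers that run several statements per call, and it rejects read-only WITH queries. The plain-Java sketch below is a somewhat stricter pre-check, but it is still only a heuristic. ReadOnlySqlGuard is hypothetical, and it errs on the side of rejecting: a forbidden keyword inside a string literal also fails. The real safeguard remains a read-only database account or transaction.

```java
import java.util.Locale;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical, stricter pre-check for the executeQuery tool (heuristic only).
final class ReadOnlySqlGuard {

    private static final Pattern WORD = Pattern.compile("[A-Z0-9_]+");
    private static final Set<String> FORBIDDEN = Set.of(
            "INSERT", "UPDATE", "DELETE", "MERGE", "DROP", "ALTER",
            "TRUNCATE", "CREATE", "GRANT", "REVOKE", "CALL", "EXEC");

    private ReadOnlySqlGuard() {}

    static boolean isAllowed(String sql) {
        if (sql == null || sql.isBlank()) {
            return false;
        }
        String s = sql.strip();
        if (s.endsWith(";")) {
            s = s.substring(0, s.length() - 1).strip(); // tolerate one trailing semicolon
        }
        // Reject statement chaining and comments, which can hide a second statement
        if (s.contains(";") || s.contains("--") || s.contains("/*")) {
            return false;
        }
        String upper = s.toUpperCase(Locale.ROOT);
        if (!upper.startsWith("SELECT") && !upper.startsWith("WITH")) {
            return false;
        }
        // Reject data-modifying keywords anywhere, e.g. inside a CTE;
        // whole words only, so identifiers like CREATED_AT are fine
        Matcher m = WORD.matcher(upper);
        while (m.find()) {
            if (FORBIDDEN.contains(m.group())) {
                return false;
            }
        }
        return true;
    }
}
```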
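import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.net.MalformedURLException;
import java.net.URI;
import java.util.Map;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.annotation.ToolParam;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpHeaders;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component;
import org.springframework.web.reactive.function.client.WebClient;

@Component
public class ApiTools {

    private static final ObjectMapper MAPPER = new ObjectMapper();

    private final WebClient webClient;

    public ApiTools(WebClient.Builder webClientBuilder) {
        this.webClient = webClientBuilder.build();
    }

    @Tool(description = "Make an HTTP GET request to an API")
    public ApiResponse callApi(
            @ToolParam(description = "API URL") String url,
            @ToolParam(description = "Request headers as a JSON object string", required = false)
            String headersJson) {
        // Validate the URL (rejects malformed and relative URLs)
        try {
            URI.create(url).toURL();
        } catch (IllegalArgumentException | MalformedURLException e) {
            throw new IllegalArgumentException("Invalid URL format", e);
        }

        // Parse the headers, if provided
        HttpHeaders headers = new HttpHeaders();
        if (headersJson != null && !headersJson.isBlank()) {
            try {
                Map<String, String> headersMap =
                        MAPPER.readValue(headersJson, new TypeReference<Map<String, String>>() {});
                headersMap.forEach(headers::add);
            } catch (JsonProcessingException e) {
                throw new IllegalArgumentException("Invalid headers JSON", e);
            }
        }

        // Read the whole response entity: the body alone carries neither status nor headers
        ResponseEntity<Map<String, Object>> response = webClient.get()
                .uri(url)
                .headers(h -> h.addAll(headers))
                .retrieve()
                .toEntity(new ParameterizedTypeReference<Map<String, Object>>() {})
                .block();
        return new ApiResponse(response.getStatusCode().value(), response.getBody(), response.getHeaders());
    }
}

record ApiResponse(int status, Map<String, Object> body, HttpHeaders headers) {}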
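callApi fetches whatever URL the model supplies, which makes it a server-side request forgery (SSRF) vector: a crafted prompt could point it at internal services or cloud metadata endpoints. A minimal allowlist check using java.net.URI is sketched below. UrlPolicy is hypothetical, and it does not address redirects or DNS-based tricks, which a production setup also needs to handle.

```java
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Locale;
import java.util.Set;

// Hypothetical allowlist check to run before callApi issues a request.
final class UrlPolicy {

    private final Set<String> allowedHosts; // expected in lower case

    UrlPolicy(Set<String> allowedHosts) {
        this.allowedHosts = Set.copyOf(allowedHosts);
    }

    boolean isAllowed(String url) {
        try {
            URI uri = new URI(url);
            String scheme = uri.getScheme();
            String host = uri.getHost();
            if (scheme == null || host == null) {
                return false;
            }
            // Require HTTPS, no user-info trick ("trusted@evil"), and an exact host match
            return scheme.equalsIgnoreCase("https")
                    && uri.getUserInfo() == null
                    && allowedHosts.contains(host.toLowerCase(Locale.ROOT));
        } catch (URISyntaxException e) {
            return false;
        }
    }
}
```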
Create reusable prompt templates with Spring AI:
import java.util.List;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.stereotype.Component;

// @PromptTemplate and @PromptParam are the annotations described in this skill
@Component
public class CodeReviewPrompts {

    @PromptTemplate(
            name = "java-code-review",
            description = "Review Java code for best practices and issues"
    )
    public Prompt createJavaCodeReviewPrompt(
            @PromptParam("code") String code,
            @PromptParam(value = "focusAreas", required = false)
            List<String> focusAreas) {
        String focus = focusAreas != null
                ? String.join(", ", focusAreas)
                : "general best practices";
        // Substitute {code} last, so placeholder-like text inside the user's code
        // is never replaced
        String userText = """
                Review the following Java code for {focus}:
                ```java
                {code}
                ```
                Provide feedback in the following format:
                1. Critical issues (must fix)
                2. Warnings (should fix)
                3. Suggestions (consider improving)
                4. Positive aspects
                Be specific and provide code examples where relevant.
                """.replace("{focus}", focus).replace("{code}", code);
        return new Prompt(List.of(
                new SystemMessage("You are an expert Java code reviewer with 20 years of experience."),
                new UserMessage(userText)));
    }

    @PromptTemplate(
            name = "generate-unit-tests",
            description = "Generate comprehensive unit tests for Java code"
    )
    public Prompt createTestGenerationPrompt(
            @PromptParam("code") String code,
            @PromptParam("className") String className,
            @PromptParam(value = "testingFramework", required = false)
            String framework) {
        String testFramework = framework != null ? framework : "JUnit 5";
        String userText = """
                Generate comprehensive unit tests for the following Java class using {testFramework}:
                ```java
                {code}
                ```
                Class: {className}
                Requirements:
                1. Test all public methods
                2. Include edge cases and boundary conditions
                3. Use appropriate assertions
                4. Follow AAA pattern (Arrange, Act, Assert)
                5. Include test method naming best practices
                6. Mock external dependencies
                """.replace("{testFramework}", testFramework)
                .replace("{className}", className)
                .replace("{code}", code);
        return new Prompt(List.of(
                new SystemMessage("You are an expert in test-driven development."),
                new UserMessage(userText)));
    }
}
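Chained String.replace calls substitute placeholders one after another. Placeholder-like text inside an earlier substituted value can therefore be replaced again, for example a literal "{focus}" in the user's code. A single-pass renderer avoids this, because every placeholder in the template is visited exactly once. PromptText is a hypothetical helper built on java.util.regex, not a Spring AI class.

```java
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical single-pass renderer for {name} placeholders.
final class PromptText {

    private static final Pattern PLACEHOLDER = Pattern.compile("\\{(\\w+)\\}");

    private PromptText() {}

    static String render(String template, Map<String, String> values) {
        Matcher m = PLACEHOLDER.matcher(template);
        StringBuilder out = new StringBuilder();
        while (m.find()) {
            String value = values.get(m.group(1));
            // Unknown placeholders are left untouched
            String replacement = value != null ? value : m.group(0);
            // quoteReplacement keeps '$' and '\' in values literal
            m.appendReplacement(out, Matcher.quoteReplacement(replacement));
        }
        m.appendTail(out);
        return out.toString();
    }
}
```

Inserted values are never rescanned, so the order of keys in the map does not matter.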
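Low-level function calling integration:
import java.util.function.Function;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
public class FunctionConfig {

    @Bean
    public FunctionCallback weatherFunction() {
        return FunctionCallback.builder()
                .function("getCurrentWeather", new WeatherService())
                .description("Get the current weather for a location")
                .inputType(WeatherRequest.class)
                .build();
    }

    @Bean
    public FunctionCallback calculatorFunction() {
        // The callback wraps a java.util.function.Function with a typed input,
        // so Calculator implements Function rather than BiFunction
        return FunctionCallback.builder()
                .function("calculate", new Calculator())
                .description("Perform mathematical calculations")
                .inputType(CalculationRequest.class)
                .build();
    }
}

class WeatherService implements Function<WeatherRequest, WeatherResponse> {
    @Override
    public WeatherResponse apply(WeatherRequest request) {
        // Call a weather API here; a fixed value stands in for the response
        return new WeatherResponse(request.location(), 72, "Sunny");
    }
}

record WeatherRequest(String location) {}
record WeatherResponse(String location, double temperature, String condition) {}

// Illustrative input/output types; adapt the operations to the calculations you need
record CalculationRequest(String operation, double a, double b) {}
record CalculationResponse(double result) {}

class Calculator implements Function<CalculationRequest, CalculationResponse> {
    @Override
    public CalculationResponse apply(CalculationRequest request) {
        double result = switch (request.operation()) {
            case "add" -> request.a() + request.b();
            case "subtract" -> request.a() - request.b();
            case "multiply" -> request.a() * request.b();
            case "divide" -> {
                if (request.b() == 0) {
                    throw new IllegalArgumentException("Division by zero");
                }
                yield request.a() / request.b();
            }
            default -> throw new IllegalArgumentException("Unsupported operation: " + request.operation());
        };
        return new CalculationResponse(result);
    }
}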
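Conceptually, function calling comes down to three steps. The server finds the callback by the name the model chose, converts the model's arguments to the declared input type, and applies the function. The JDK-only sketch below shows that dispatch. FunctionDispatcher is hypothetical, and a caller-supplied converter replaces the JSON-to-type conversion that Spring AI derives from inputType.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;

// Hypothetical registry mirroring what a function callback provides:
// a callable name, a description, and a typed function.
final class FunctionDispatcher {

    record Entry<I, O>(String description,
                       Function<Map<String, Object>, I> converter,
                       Function<I, O> function) {
        Object invoke(Map<String, Object> args) {
            return function.apply(converter.apply(args)); // convert, then apply
        }
    }

    private final Map<String, Entry<?, ?>> entries = new LinkedHashMap<>();

    <I, O> void register(String name, String description,
                         Function<Map<String, Object>, I> converter, Function<I, O> function) {
        if (entries.putIfAbsent(name, new Entry<>(description, converter, function)) != null) {
            throw new IllegalArgumentException("Duplicate function name: " + name);
        }
    }

    Set<String> names() {
        return entries.keySet();
    }

    Object call(String name, Map<String, Object> args) {
        Entry<?, ?> entry = entries.get(name);
        if (entry == null) {
            throw new IllegalArgumentException("Unknown function: " + name);
        }
        return entry.invoke(args);
    }
}
```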
Set up the MCP server with Spring Boot auto-configuration:
// Register this class in
// META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports
@AutoConfiguration(after = WebMvcAutoConfiguration.class)
@ConditionalOnClass({McpServer.class, ChatModel.class})
@ConditionalOnProperty(name = "spring.ai.mcp.enabled", havingValue = "true", matchIfMissing = true)
@EnableConfigurationProperties(McpServerProperties.class)
public class McpAutoConfiguration {

    @Bean
    @ConditionalOnMissingBean
    public McpServer mcpServer(
            ObjectProvider<FunctionCallback> functionCallbacks,
            ObjectProvider<PromptTemplate> promptTemplates,
            McpServerProperties properties) {
        McpServer.Builder builder = McpServer.builder()
                .serverInfo("spring-ai-mcp", "1.0.0")
                .transport(properties.getTransport().create());

        // Register function callbacks as tools; ObjectProvider yields an empty
        // stream when no such beans exist, whereas a List parameter would fail
        functionCallbacks.orderedStream().forEach(callback ->
                builder.tool(Tool.fromFunctionCallback(callback)));

        // Register prompt templates
        promptTemplates.orderedStream().forEach(template ->
                builder.prompt(Prompt.fromTemplate(template)));

        return builder.build();
    }

    // Enabled unless switched off, matching the ActuatorConfig default (enabled = true)
    @Bean
    @ConditionalOnProperty(name = "spring.ai.mcp.actuator.enabled", havingValue = "true", matchIfMissing = true)
    public McpHealthIndicator mcpHealthIndicator(McpServer mcpServer) {
        return new McpHealthIndicator(mcpServer);
    }
}
@ConfigurationProperties(prefix = "spring.ai.mcp")
public class McpServerProperties {

    private boolean enabled = true;
    private TransportConfig transport = new TransportConfig();
    private ActuatorConfig actuator = new ActuatorConfig();

    // Getters and setters

    public static class TransportConfig {

        // Bound from spring.ai.mcp.transport.type; Spring Boot's relaxed binding
        // maps values such as "stdio" to the enum constant STDIO
        private TransportType type = TransportType.STDIO;
        private HttpConfig http = new HttpConfig();

        // Getters and setters

        public Transport create() {
            return switch (type) {
                case STDIO -> new StdioTransport();
                case HTTP -> new HttpTransport(http.getPort());
                case SSE -> new SseTransport(http.getPort(), http.getPath());
            };
        }
    }

    public static class HttpConfig {
        private int port = 8080;
        private String path = "/mcp";
        // Getters and setters
    }

    public static class ActuatorConfig {
        private boolean enabled = true;
        // Getters and setters
    }

    public enum TransportType {
        STDIO, HTTP, SSE
    }
}
Configure the MCP server in application.yml:
spring:
  ai:
    mcp:
      enabled: true
      transport:
        type: stdio  # options: stdio, http, sse
        http:
          port: 8080
          path: /mcp
      actuator:
        enabled: true
      tools:
        package-scan: com.example.tools
      prompts:
        package-scan: com.example.prompts
      security:
        enabled: true
        allowed-tools:
          - getWeather
          - executeQuery
        admin-tools:
          - admin_*
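The security block above uses the document's own hypothetical schema. One plausible reading is that allowed-tools lists exact tool names, while admin-tools holds glob patterns such as admin_* that require the ADMIN role. The JDK-only sketch below evaluates both lists under that reading. ToolAccessPolicy and these semantics are assumptions, not a Spring AI feature.

```java
import java.util.Arrays;
import java.util.List;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

// Hypothetical evaluation of the allowed-tools / admin-tools lists.
final class ToolAccessPolicy {

    private final List<String> allowedTools;
    private final List<Pattern> adminPatterns;

    ToolAccessPolicy(List<String> allowedTools, List<String> adminGlobs) {
        this.allowedTools = List.copyOf(allowedTools);
        this.adminPatterns = adminGlobs.stream().map(ToolAccessPolicy::globToRegex).toList();
    }

    // Supports only the '*' wildcard; every other character is matched literally
    static Pattern globToRegex(String glob) {
        String regex = Arrays.stream(glob.split("\\*", -1))
                .map(Pattern::quote)
                .collect(Collectors.joining(".*"));
        return Pattern.compile(regex);
    }

    boolean isAdminTool(String toolName) {
        return adminPatterns.stream().anyMatch(p -> p.matcher(toolName).matches());
    }

    boolean isAllowed(String toolName, boolean callerIsAdmin) {
        if (isAdminTool(toolName)) {
            return callerIsAdmin; // admin tools require the ADMIN role
        }
        return allowedTools.contains(toolName); // everything else must be listed explicitly
    }
}
```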
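For advanced configuration:
import io.micrometer.core.instrument.MeterRegistry;
import java.util.concurrent.TimeUnit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@Configuration
public class CustomMcpConfig {

    private static final Logger log = LoggerFactory.getLogger(CustomMcpConfig.class);

    @Bean
    public McpServerCustomizer mcpServerCustomizer(MeterRegistry meterRegistry) {
        return server -> server.addToolInterceptor((tool, args, chain) -> {
            log.info("Executing tool: {}", tool.name());
            long start = System.nanoTime();
            try {
                return chain.execute(tool, args);
            } finally {
                // Record the duration for failed executions too
                long durationMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start);
                log.info("Tool {} executed in {}ms", tool.name(), durationMs);
                meterRegistry.timer("mcp.tool.execution", "tool", tool.name())
                        .record(durationMs, TimeUnit.MILLISECONDS);
            }
        });
    }

    @Bean
    public ToolFilter toolFilter(SecurityService securityService) {
        return (tool, context) -> {
            User user = securityService.getCurrentUser();
            if (tool.name().startsWith("admin_")) {
                return user.hasRole("ADMIN");
            }
            return securityService.isToolAllowed(user, tool.name());
        };
    }
}

@Service
public class SecurityService {

    public User getCurrentUser() {
        // Resolve the current user, e.g. from the Spring Security context
        throw new UnsupportedOperationException("Not implemented");
    }

    public boolean isToolAllowed(User user, String toolName) {
        // Implement tool access-control logic (this placeholder allows every tool)
        return true;
    }
}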
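The interceptor registered through McpServerCustomizer wraps each tool call. The sketch below shows the same chain-of-responsibility idea in plain Java. The first interceptor in the list is the outermost, and timing sits in a finally block so failed calls are measured too. ToolInterceptor, ToolChain and InterceptingExecutor are hypothetical names, not the document's server API.

```java
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;

// Hypothetical interceptor chain around a tool invocation.
interface ToolInterceptor {
    Object intercept(String tool, Map<String, Object> args, ToolChain chain);
}

interface ToolChain {
    Object proceed(String tool, Map<String, Object> args);
}

final class InterceptingExecutor {

    private final List<ToolInterceptor> interceptors;
    private final BiFunction<String, Map<String, Object>, Object> target;

    InterceptingExecutor(List<ToolInterceptor> interceptors,
                         BiFunction<String, Map<String, Object>, Object> target) {
        this.interceptors = List.copyOf(interceptors);
        this.target = target;
    }

    Object execute(String tool, Map<String, Object> args) {
        return chainFrom(0).proceed(tool, args);
    }

    // Interceptor i receives a chain that continues at i + 1; the last link calls the tool
    private ToolChain chainFrom(int index) {
        if (index == interceptors.size()) {
            return target::apply;
        }
        return (tool, args) -> interceptors.get(index).intercept(tool, args, chainFrom(index + 1));
    }

    // Example interceptor: records the duration even when the tool throws
    static ToolInterceptor timing(List<String> log) {
        return (tool, args, chain) -> {
            long start = System.nanoTime();
            try {
                return chain.proceed(tool, args);
            } finally {
                long micros = (System.nanoTime() - start) / 1_000;
                log.add(tool + " took " + micros + "us");
            }
        };
    }
}
```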
Implement secure tool execution with Spring Security:
import java.util.Map;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Component;

@Component
public class SecureToolExecutor {

    private final McpServer mcpServer;

    // SecurityContextHolder only exposes static methods and is not a bean, so it is not injected
    public SecureToolExecutor(McpServer mcpServer) {
        this.mcpServer = mcpServer;
    }

    public ToolResult executeTool(String toolName, Map<String, Object> arguments) {
        Authentication auth = SecurityContextHolder.getContext().getAuthentication();
        if (!(auth instanceof UserAuthentication userAuth)) {
            throw new AccessDeniedException("User not authenticated");
        }

        // Check tool permissions
        if (!hasToolPermission(userAuth.getUser(), toolName)) {
            throw new AccessDeniedException("Tool not allowed: " + toolName);
        }

        // Validate arguments against injection patterns
        validateArguments(arguments);

        // Execute with audit logging
        logToolExecution(userAuth.getUser(), toolName, arguments);
        try {
            ToolResult result = mcpServer.executeTool(toolName, arguments);
            logToolSuccess(userAuth.getUser(), toolName);
            return result;
        } catch (Exception e) {
            logToolFailure(userAuth.getUser(), toolName, e);
            throw new ToolExecutionException("Tool execution failed", e);
        }
    }

    private boolean hasToolPermission(User user, String toolName) {
        // Permission logic based on user roles and tool sensitivity
        return user.getAuthorities().stream()
                .anyMatch(authority -> authority.getAuthority().equals("TOOL_" + toolName)
                        || authority.getAuthority().equals("ROLE_ADMIN"));
    }

    private void validateArguments(Map<String, Object> arguments) {
        // Coarse heuristic against injection; parameter binding and per-tool
        // validation are the primary defenses
        arguments.forEach((key, value) -> {
            if (value instanceof String str && (str.contains(";") || str.contains("--"))) {
                throw new IllegalArgumentException("Invalid characters in argument: " + key);
            }
        });
    }

    private void logToolExecution(User user, String toolName, Map<String, Object> arguments) {
        // Implement audit logging
    }

    private void logToolSuccess(User user, String toolName) {
        // Record the successful execution
    }

    private void logToolFailure(User user, String toolName, Exception e) {
        // Record the failed execution
    }
}

class ToolExecutionException extends RuntimeException {
    public ToolExecutionException(String message, Throwable cause) {
        super(message, cause);
    }
}
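logToolExecution is left as a placeholder. The arguments it receives come from the model, so they may contain credentials or very large payloads. The JDK-only sketch below masks values whose keys look sensitive and truncates long values before they reach the audit log. AuditArguments, the key hints and the 200-character limit are illustrative assumptions.

```java
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

// Hypothetical helper for logToolExecution: masks sensitive values, truncates long ones.
final class AuditArguments {

    private static final Set<String> SENSITIVE_HINTS = Set.of("password", "secret", "token", "apikey");
    private static final int MAX_LENGTH = 200;

    private AuditArguments() {}

    static Map<String, String> sanitize(Map<String, Object> arguments) {
        Map<String, String> safe = new TreeMap<>(); // sorted keys give stable log lines
        arguments.forEach((key, value) -> safe.put(key, sanitizeValue(key, value)));
        return safe;
    }

    private static String sanitizeValue(String key, Object value) {
        // "api_key", "API-Key" and "apiKey" all normalize to "apikey"
        String normalized = key.toLowerCase(Locale.ROOT).replace("_", "").replace("-", "");
        if (SENSITIVE_HINTS.stream().anyMatch(normalized::contains)) {
            return "***";
        }
        String text = String.valueOf(value);
        return text.length() <= MAX_LENGTH ? text : text.substring(0, MAX_LENGTH) + "...";
    }
}
```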
Use Spring's validation framework:
@Component
@Validated // method validation is only applied when @Validated is on the class
public class ValidatedTools {

    @Tool(description = "Process user data with validation")
    public ProcessingResult processUserData(
            @ToolParam(description = "User data to process") @Valid UserData data) {
        // Implementation
        return new ProcessingResult("success", data);
    }
}

record ProcessingResult(String status, UserData data) {}

record UserData(
        @NotBlank(message = "Name is required")
        @Size(max = 100, message = "Name must be 100 characters or less")
        String name,
        @NotNull(message = "Age is required")
        @Min(value = 18, message = "Must be 18 or older")
        @Max(value = 120, message = "Age must be realistic")
        Integer age,
        @NotBlank(message = "Email is required")
        @Email(message = "Invalid email format")
        String email
) {}

// Custom validator for sensitive operations
@Component
public class SensitiveOperationValidator {

    public void validateOperation(String operation, User user, Map<String, Object> params) {
        if (isSensitiveOperation(operation)) {
            requireAdditionalAuthentication(user);
            validateOperationLimits(user, operation);
            logSensitiveOperation(user, operation, params);
        }
    }

    private boolean isSensitiveOperation(String operation) {
        return operation.startsWith("delete") || operation.startsWith("update");
    }

    private void requireAdditionalAuthentication(User user) {
        // Implement MFA or re-authentication
    }

    private void validateOperationLimits(User user, String operation) {
        // Check rate limits and quotas
    }

    private void logSensitiveOperation(User user, String operation, Map<String, Object> params) {
        // Secure audit logging
    }
}
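validateOperationLimits is meant to check rate limits and quotas. A fixed-window counter per user and operation is one simple way to do that, sketched below in plain Java. FixedWindowRateLimiter is hypothetical. The clock is injected so the window can be tested, rejected calls also count toward the window, and entries are never evicted in this sketch.

```java
import java.time.Duration;
import java.time.Instant;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;

// Hypothetical fixed-window limiter: at most maxCalls per user and operation per window.
final class FixedWindowRateLimiter {

    private record Key(String userId, String operation) {}
    private record Window(Instant start, int count) {}

    private final int maxCalls;
    private final Duration window;
    private final Supplier<Instant> clock;
    private final Map<Key, Window> windows = new ConcurrentHashMap<>();

    FixedWindowRateLimiter(int maxCalls, Duration window, Supplier<Instant> clock) {
        this.maxCalls = maxCalls;
        this.window = window;
        this.clock = clock;
    }

    boolean tryAcquire(String userId, String operation) {
        Instant now = clock.get();
        // ConcurrentHashMap.compute updates each key atomically
        Window updated = windows.compute(new Key(userId, operation), (key, current) ->
                current == null || !now.isBefore(current.start().plus(window))
                        ? new Window(now, 1)                                // start a new window
                        : new Window(current.start(), current.count() + 1)); // count within window
        return updated.count() <= maxCalls;
    }
}
```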
Implement comprehensive error handling. Note that @ControllerAdvice only handles exceptions raised while Spring MVC processes a request, so it applies when the server is exposed over HTTP through Spring MVC, not to the stdio transport:
import java.time.LocalDateTime;
import lombok.Builder;
import lombok.Data;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@ControllerAdvice
public class McpExceptionHandler {

    private static final Logger log = LoggerFactory.getLogger(McpExceptionHandler.class);

    @ExceptionHandler(ToolExecutionException.class)
    public ResponseEntity<ErrorResponse> handleToolExecutionException(
            ToolExecutionException ex, WebRequest request) {
        ErrorResponse error = ErrorResponse.builder()
                .timestamp(LocalDateTime.now())
                .status(HttpStatus.INTERNAL_SERVER_ERROR.value())
                .error("Tool Execution Failed")
                .message(ex.getMessage())
                .path(requestPath(request))
                .build();
        log.error("Tool execution failed: {}", ex.getMessage(), ex);
        return new ResponseEntity<>(error, HttpStatus.INTERNAL_SERVER_ERROR);
    }

    @ExceptionHandler(AccessDeniedException.class)
    public ResponseEntity<ErrorResponse> handleAccessDenied(
            AccessDeniedException ex, WebRequest request) {
        ErrorResponse error = ErrorResponse.builder()
                .timestamp(LocalDateTime.now())
                .status(HttpStatus.FORBIDDEN.value())
                .error("Access Denied")
                .message("You do not have permission to execute this tool")
                .path(requestPath(request))
                .build();
        log.warn("Access denied: {}", ex.getMessage());
        return new ResponseEntity<>(error, HttpStatus.FORBIDDEN);
    }

    @ExceptionHandler(IllegalArgumentException.class)
    public ResponseEntity<ErrorResponse> handleValidationError(
            IllegalArgumentException ex, WebRequest request) {
        ErrorResponse error = ErrorResponse.builder()
                .timestamp(LocalDateTime.now())
                .status(HttpStatus.BAD_REQUEST.value())
                .error("Validation Error")
                .message(ex.getMessage())
                .path(requestPath(request))
                .build();
        return new ResponseEntity<>(error, HttpStatus.BAD_REQUEST);
    }

    @ExceptionHandler(Exception.class)
    public ResponseEntity<ErrorResponse> handleGenericException(
            Exception ex, WebRequest request) {
        ErrorResponse error = ErrorResponse.builder()
                .timestamp(LocalDateTime.now())
                .status(HttpStatus.INTERNAL_SERVER_ERROR.value())
                .error("Internal Server Error")
                .message("An unexpected error occurred")
                .path(requestPath(request))
                .build();
        log.error("Unexpected error: {}", ex.getMessage(), ex);
        return new ResponseEntity<>(error, HttpStatus.INTERNAL_SERVER_ERROR);
    }

    // In Spring MVC the WebRequest is a ServletWebRequest wrapping the servlet request
    private static String requestPath(WebRequest request) {
        return ((ServletWebRequest) request).getRequest().getRequestURI();
    }

    @Data
    @Builder
    static class ErrorResponse {
        private LocalDateTime timestamp;
        private int status;
        private String error;
        private String message;
        private String path;
    }
}
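Over MCP itself, the specification distinguishes two kinds of failure. Protocol errors, such as an unknown tool, travel as JSON-RPC error responses. Tool execution failures are returned as a normal result with isError set to true, so the model can read the message and react. The plain-Java sketch below maps exceptions to such a result without leaking internal details. ToolCallResult and ToolErrors are hypothetical types, not SDK classes, and java.lang.SecurityException stands in for Spring Security's AccessDeniedException.

```java
// Hypothetical result shape mirroring an MCP tool result (text content plus isError).
record ToolCallResult(String text, boolean isError) {}

final class ToolErrors {

    private ToolErrors() {}

    static ToolCallResult fromException(Exception e) {
        // Validation messages are safe to echo, so the model can fix its arguments
        if (e instanceof IllegalArgumentException) {
            return new ToolCallResult("Invalid arguments: " + e.getMessage(), true);
        }
        if (e instanceof SecurityException) {
            return new ToolCallResult("Permission denied for this tool", true);
        }
        // Anything else gets a generic message; log the details server-side instead
        return new ToolCallResult("Tool execution failed", true);
    }
}
```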
Register tools at runtime:
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class DynamicToolRegistry {

    private static final Logger log = LoggerFactory.getLogger(DynamicToolRegistry.class);

    private final McpServer mcpServer;
    private final Map<String, ToolRegistration> registeredTools = new ConcurrentHashMap<>();

    public DynamicToolRegistry(McpServer mcpServer) {
        this.mcpServer = mcpServer;
    }

    public void registerTool(ToolRegistration registration) {
        registeredTools.put(registration.getId(), registration);
        Tool tool = Tool.builder()
                .name(registration.getName())
                .description(registration.getDescription())
                .inputSchema(registration.getInputSchema())
                .function(args -> executeDynamicTool(registration.getId(), args))
                .build();
        mcpServer.addTool(tool);
        log.info("Registered dynamic tool: {}", registration.getName());
    }

    public void unregisterTool(String toolId) {
        ToolRegistration registration = registeredTools.remove(toolId);
        if (registration != null) {
            mcpServer.removeTool(registration.getName());
            log.info("Unregistered dynamic tool: {}", registration.getName());
        }
    }

    private Object executeDynamicTool(String toolId, Map<String, Object> args) {
        ToolRegistration registration = registeredTools.get(toolId);
        if (registration == null) {
            throw new IllegalStateException("Tool not found: " + toolId);
        }
        // Dispatch on the registration type
        return switch (registration.getType()) {
            case GROOVY_SCRIPT -> executeGroovyScript(registration, args);
            case SPRING_BEAN -> executeSpringBeanMethod(registration, args);
            case HTTP_ENDPOINT -> callHttpEndpoint(registration, args);
        };
    }

    private Object executeGroovyScript(ToolRegistration registration, Map<String, Object> args) {
        // Implement Groovy script execution
        return null;
    }

    private Object executeSpringBeanMethod(ToolRegistration registration, Map<String, Object> args) {
        // Implement the Spring bean method invocation
        return null;
    }

    private Object callHttpEndpoint(ToolRegistration registration, Map<String, Object> args) {
        // Implement the HTTP call
        return null;
    }
}

@Data
@Builder
class ToolRegistration {
    private String id;
    private String name;
    private String description;
    private Map<String, Object> inputSchema;
    private ToolType type;
    private String target; // script, bean name, or URL
    private Map<String, String> metadata;
}

enum ToolType {
    GROOVY_SCRIPT,
    SPRING_BEAN,
    HTTP_ENDPOINT
}
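DynamicToolRegistry keys registrations by id, while the server identifies tools by name. Two registrations with the same name could therefore replace each other, and unregistering one would remove the other's tool. The JDK-only sketch below reserves names atomically so that each name has a single owner. ToolNameRegistry is hypothetical; registerTool would call register(...) first and only add the tool to the server when it returns true.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

// Hypothetical registry: each tool name is owned by at most one registration id.
// All methods are synchronized so both maps always change together.
final class ToolNameRegistry {

    private final Map<String, String> idByName = new HashMap<>();
    private final Map<String, String> nameById = new HashMap<>();

    // Returns false if the name or the id is already registered
    synchronized boolean register(String id, String name) {
        if (idByName.containsKey(name) || nameById.containsKey(id)) {
            return false;
        }
        idByName.put(name, id);
        nameById.put(id, name);
        return true;
    }

    synchronized boolean unregister(String id) {
        String name = nameById.remove(id);
        if (name == null) {
            return false;
        }
        idByName.remove(name);
        return true;
    }

    synchronized Set<String> names() {
        return Set.copyOf(idByName.keySet());
    }
}
```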
Support multiple AI models:
@Configuration
public class MultiModelConfig {

    @Bean
    @Primary
    public ChatModel primaryChatModel(@Value("${spring.ai.primary.model}") String modelName) {
        return switch (modelName) {
            case "gpt-4" -> openAi("gpt-4");
            case "claude" -> anthropic("claude-3-opus-20240229");
            default -> throw new IllegalArgumentException("Unsupported model: " + modelName);
        };
    }

    @Bean
    public Map<String, ChatModel> allChatModels() {
        Map<String, ChatModel> models = new HashMap<>();
        models.put("gpt-4", openAi("gpt-4"));
        models.put("gpt-3.5", openAi("gpt-3.5-turbo"));
        models.put("claude-opus", anthropic("claude-3-opus-20240229"));
        return models;
    }

    @Bean
    public ModelSelector modelSelector() {
        // Call allChatModels() directly: a Map<String, ChatModel> parameter would be filled
        // with all ChatModel beans keyed by bean name instead of this map
        return new SpringAiModelSelector(allChatModels());
    }

    // The model name belongs in the chat options; the API client only holds connection settings
    private static ChatModel openAi(String model) {
        return OpenAiChatModel.builder()
                .openAiApi(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build())
                .defaultOptions(OpenAiChatOptions.builder().model(model).build())
                .build();
    }

    private static ChatModel anthropic(String model) {
        return AnthropicChatModel.builder()
                .anthropicApi(AnthropicApi.builder().apiKey(System.getenv("ANTHROPIC_API_KEY")).build())
                .defaultOptions(AnthropicChatOptions.builder().model(model).build())
                .build();
    }
}

// Created by MultiModelConfig.modelSelector(); not annotated with @Component,
// which would register a second instance with a different map injected
public class SpringAiModelSelector implements ModelSelector {

    private final Map<String, ChatModel> models;

    public SpringAiModelSelector(Map<String, ChatModel> models) {
        this.models = models;
    }

    @Override
    public ChatModel selectModel(Prompt prompt, Map<String, Object> context) {
        // Choose a model based on prompt complexity, cost, and latency requirements
        String modelName = determineBestModel(prompt, context);
        return models.get(modelName);
    }

    private String determineBestModel(Prompt prompt, Map<String, Object> context) {
        // Implement the model-selection logic
        // Consider: prompt length, complexity, cost constraints, latency requirements
        return "gpt-4";
    }
}
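determineBestModel is left as a stub that always returns "gpt-4". The sketch below is one possible heuristic over the same map keys (gpt-4, gpt-3.5, claude-opus), driven by prompt length and two context flags. ModelRouting, the flag names lowLatency and complexReasoning, and the thresholds are all illustrative assumptions, not measured values.

```java
import java.util.Map;

// Hypothetical routing heuristic; thresholds are placeholders to tune against real costs.
final class ModelRouting {

    private ModelRouting() {}

    static String chooseModel(int promptChars, Map<String, Object> context) {
        boolean lowLatency = Boolean.TRUE.equals(context.get("lowLatency"));
        boolean complex = Boolean.TRUE.equals(context.get("complexReasoning"));
        if (complex) {
            // Very long, complex prompts go to the model assumed to handle large context best
            return promptChars > 20_000 ? "claude-opus" : "gpt-4";
        }
        if (lowLatency || promptChars < 2_000) {
            return "gpt-3.5"; // cheaper and faster for short or latency-sensitive requests
        }
        return "gpt-4";
    }
}
```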
Implement caching for tools and prompts:
@Configuration
@EnableCaching
public class McpCacheConfig {

    @Bean
    public CacheManager cacheManager() {
        // ConcurrentMapCacheManager never expires or evicts entries; use a provider
        // with TTL and size limits (e.g. Caffeine) in production
        return new ConcurrentMapCacheManager(
                "tool-results",
                "prompt-templates",
                "function-callbacks");
    }
}

@Component
public class CachedToolExecutor {

    private final McpServer mcpServer;

    public CachedToolExecutor(McpServer mcpServer) {
        this.mcpServer = mcpServer;
    }

    // No explicit key: the default key generator combines toolName and args and
    // compares them with equals(), whereas args.hashCode() alone can collide
    @Cacheable(
            value = "tool-results",
            unless = "#result == null || !#result.isCacheable()")
    public ToolResult executeTool(String toolName, Map<String, Object> args) {
        return mcpServer.executeTool(toolName, args);
    }

    @CacheEvict(value = "tool-results", allEntries = true)
    public void clearToolCache() {
        // Clear the cache when tools are updated
    }

    @Cacheable(value = "prompt-templates", key = "#templateName")
    public PromptTemplate getPromptTemplate(String templateName) {
        return mcpServer.getPromptTemplate(templateName);
    }
}
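Cache keys for tool results must compare arguments with equals(), not only with hashCode(). Different argument maps can share a hash code; for example, the strings "Aa" and "BB" have the same String.hashCode(). The JDK-only sketch below keys a small TTL cache on a record of the tool name and the arguments, which compares by content. ToolResultCache is hypothetical. Map.copyOf makes a defensive copy, and it rejects null argument values.

```java
import java.time.Duration;
import java.time.Instant;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;

// Hypothetical TTL cache for tool results, keyed by an equals()-based composite key.
final class ToolResultCache {

    private record Key(String toolName, Map<String, Object> args) {}
    private record Entry(Object value, Instant expiresAt) {}

    private final Map<Key, Entry> entries = new ConcurrentHashMap<>();
    private final Duration ttl;
    private final Supplier<Instant> clock;

    ToolResultCache(Duration ttl, Supplier<Instant> clock) {
        this.ttl = ttl;
        this.clock = clock;
    }

    Object get(String toolName, Map<String, Object> args, Supplier<Object> loader) {
        // Defensive copy: later changes to the caller's map cannot corrupt the key
        Key key = new Key(toolName, Map.copyOf(args));
        Instant now = clock.get();
        Entry entry = entries.get(key);
        if (entry == null || !now.isBefore(entry.expiresAt())) {
            entry = new Entry(loader.get(), now.plus(ttl)); // miss or expired: reload
            entries.put(key, entry);
        }
        return entry.value();
    }
}
```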
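import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.anyMap;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;

import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate;

@SpringBootTest
class DatabaseToolsTest {

    @Autowired
    private DatabaseTools databaseTools;

    // DatabaseTools binds named parameters, so the named-parameter template is mocked
    @MockBean
    private NamedParameterJdbcTemplate jdbcTemplate;

    @Test
    void testExecuteQuery_Success() {
        // Given
        String query = "SELECT * FROM users WHERE id = :id";
        Map<String, Object> params = Map.of("id", 1);
        List<Map<String, Object>> expectedResults = List.of(
                Map.of("id", 1, "name", "John"));
        when(jdbcTemplate.queryForList(anyString(), anyMap()))
                .thenReturn(expectedResults);

        // When
        List<Map<String, Object>> results = databaseTools.executeQuery(query, params);

        // Then
        assertThat(results).isEqualTo(expectedResults);
        verify(jdbcTemplate).queryForList(query, params);
    }

    @Test
    void testExecuteQuery_InvalidQuery_ThrowsException() {
        // Given
        String query = "DROP TABLE users";

        // When & Then
        assertThatThrownBy(() -> databaseTools.executeQuery(query, null))
                .isInstanceOf(IllegalArgumentException.class)
                .hasMessage("Only SELECT queries are allowed");
        verifyNoInteractions(jdbcTemplate);
    }

    @Test
    void testGetTableSchema_Success() {
        // Given
        String tableName = "users";
        List<Map<String, Object>> columns = List.of(
                Map.of("column_name", "id", "data_type", "integer"),
                Map.of("column_name", "name", "data_type", "varchar"));
        when(jdbcTemplate.queryForList(anyString(), eq(Map.of("tableName", tableName))))
                .thenReturn(columns);

        // When
        TableSchema schema = databaseTools.getTableSchema(tableName);

        // Then
        assertThat(schema.tableName()).isEqualTo(tableName);
        assertThat(schema.columns()).isEqualTo(columns);
    }
}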
@SpringBootTest
@AutoConfigureMockMvc
class McpServerIntegrationTest {

    @Autowired
    private MockMvc mockMvc;

    @Autowired
    private McpServer mcpServer;

    @MockBean
    private DatabaseTools databaseTools;

    @Test
    void testExecuteTool_Success() throws Exception {
        // Given
        String toolName = "executeQuery";
        Map<String, Object> args = Map.of(
                "query", "SELECT * FROM users",
                "params", Map.of());
        List<Map<String, Object>> expectedResult = List.of(
                Map.of("id", 1, "name", "Test User"));
        when(databaseTools.executeQuery(anyString(), anyMap()))
                .thenReturn(expectedResult);

        // When & Then
        mockMvc.perform(post("/mcp/tools/executeQuery")
                        .contentType(MediaType.APPLICATION_JSON)
                        .content(new ObjectMapper().writeValueAsString(args)))
                .andExpect(status().isOk())
                .andExpect(jsonPath("$.result").isArray())
                .andExpect(jsonPath("$.result[0].id").value(1));
    }

    @Test
    void testListTools_Success() throws Exception {
        // When & Then
        mockMvc.perform(get("/mcp/tools"))
                .andExpect(status().isOk())
                .andExpect(jsonPath("$.tools").isArray());
    }

    @Test
    void testHealthEndpoint() throws Exception {
        // When & Then
        mockMvc.perform(get("/actuator/health/mcp"))
                .andExpect(status().isOk())
                .andExpect(jsonPath("$.status").value("UP"));
    }
}
@SpringBootTest
@Testcontainers
@AutoConfigureMockMvc
class McpServerPostgresIntegrationTest { // distinct name: McpServerIntegrationTest already exists

    @Container
    static PostgreSQLContainer<?> postgres = new PostgreSQLContainer<>("postgres:15")
            .withDatabaseName("testdb")
            .withUsername("test")
            .withPassword("test");

    @DynamicPropertySource
    static void properties(DynamicPropertyRegistry registry) {
        registry.add("spring.datasource.url", postgres::getJdbcUrl);
        registry.add("spring.datasource.username", postgres::getUsername);
        registry.add("spring.datasource.password", postgres::getPassword);
    }

    @Autowired
    private MockMvc mockMvc;

    @Test
    void testDatabaseToolWithRealDatabase() throws Exception {
        // Given
        String query = "SELECT current_database(), current_user";
        Map<String, Object> request = Map.of(
                "tool", "executeQuery",
                "arguments", Map.of("query", query));

        // When & Then
        mockMvc.perform(post("/mcp/tools/executeQuery")
                        .contentType(MediaType.APPLICATION_JSON)
                        .content(new ObjectMapper().writeValueAsString(request)))
                .andExpect(status().isOk())
                .andExpect(jsonPath("$.success").value(true))
                .andExpect(jsonPath("$.data[0].current_database").value("testdb"))
                .andExpect(jsonPath("$.data[0].current_user").value("test"));
    }
}
Testing with @WebMvcTest (slice test):
@WebMvcTest(controllers = McpController.class)
class McpControllerSliceTest {

    @Autowired
    private MockMvc mockMvc;

    @MockBean
    private McpServer mcpServer;

    @MockBean
    private ToolRegistry toolRegistry;

    @Test
    void testListToolsEndpoint() throws Exception {
        // Given
        Tool tool1 = Tool.builder().name("tool1").description("Tool 1").build();
        Tool tool2 = Tool.builder().name("tool2").description("Tool 2").build();
        when(toolRegistry.listTools()).thenReturn(List.of(tool1, tool2));

        // When & Then
        mockMvc.perform(get("/mcp/tools"))
Implement Model Context Protocol (MCP) servers with Spring AI to extend AI capabilities with standardized tools, resources, and prompt templates using Spring's native AI abstractions.
The Model Context Protocol (MCP) is a standardized protocol for connecting AI applications to external data sources and tools. Spring AI provides native support for building MCP servers that expose Spring components as callable tools, resources, and prompt templates for AI models.
This skill covers the implementation patterns for creating production-ready MCP servers using Spring AI, including:
@Tool annotation@PromptTemplate annotationUse this skill when building:
Follow these steps to implement an MCP server with Spring AI:
pom.xml or build.gradleapplication.properties@EnableMcpServer annotation@Component)@Tool(description = "...")@ToolParam to document parameters for AI understanding@PromptTemplate for reusable prompts@PromptParamPrompt objects with system and user messagesstdio, http, or sseapplication.ymlCreate a simple MCP server with function calling:
@SpringBootApplication
@EnableMcpServer
public class WeatherMcpApplication {
public static void main(String[] args) {
SpringApplication.run(WeatherMcpApplication.class, args);
}
}
@Component
public class WeatherTools {
@Tool(description = "Get current weather for a city")
public WeatherData getWeather(@ToolParam("City name") String city) {
// Implementation
return new WeatherData(city, "Sunny", 22.5);
}
}
Configure function calling in application.properties:
spring.ai.openai.api-key=${OPENAI_API_KEY}
spring.ai.mcp.enabled=true
spring.ai.mcp.transport=stdio
Add Spring AI MCP dependencies to your project:
Maven:
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-mcp-server</artifactId>
<version>1.0.0</version>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-starter-model-openai</artifactId>
<version>1.0.0</version>
</dependency>
Gradle:
dependencies {
implementation 'org.springframework.ai:spring-ai-mcp-server:1.0.0'
implementation 'org.springframework.ai:spring-ai-starter-model-openai:1.0.0'
}
Or use Spring Boot starter:
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-mcp-starter</artifactId>
<version>1.0.0</version>
</dependency>
MCP standardizes AI application connections with Spring AI abstractions:
Tools : Executable functions using @Tool annotation
Resources : Data sources accessible via Spring components
Prompts : Template-based interactions with @PromptTemplate
Transport : Spring-managed communication channels
AI Application ←→ MCP Client ←→ Spring AI ←→ MCP Server ←→ Spring Services
@Tool: Declares methods as callable functions for AI models@ToolParam: Documents parameter purposes for AI understanding@PromptTemplate: Defines reusable prompt patterns@Model: Specifies AI model configurationsCreate tools with Spring AI's declarative approach:
@Component
public class DatabaseTools {
private final JdbcTemplate jdbcTemplate;
public DatabaseTools(JdbcTemplate jdbcTemplate) {
this.jdbcTemplate = jdbcTemplate;
}
@Tool(description = "Execute a safe read-only SQL query")
public List<Map<String, Object>> executeQuery(
@ToolParam("SQL SELECT query") String query,
@ToolParam(value = "Query parameters", required = false)
Map<String, Object> params) {
// Validate query is read-only
if (!query.trim().toUpperCase().startsWith("SELECT")) {
throw new IllegalArgumentException("Only SELECT queries are allowed");
}
return jdbcTemplate.queryForList(query, params);
}
@Tool(description = "Get table schema information")
public TableSchema getTableSchema(
@ToolParam("Table name") String tableName) {
String sql = "SELECT column_name, data_type " +
"FROM information_schema.columns " +
"WHERE table_name = ?";
List<Map<String, Object>> columns = jdbcTemplate.queryForList(sql, tableName);
return new TableSchema(tableName, columns);
}
}
record TableSchema(String tableName, List<Map<String, Object>> columns) {}
@Component
public class ApiTools {
private final WebClient webClient;
public ApiTools(WebClient.Builder webClientBuilder) {
this.webClient = webClientBuilder.build();
}
@Tool(description = "Make HTTP GET request to an API")
public ApiResponse callApi(
@ToolParam("API URL") String url,
@ToolParam(value = "Headers as JSON string", required = false)
String headersJson) {
// Validate URL
try {
new URL(url);
} catch (MalformedURLException e) {
throw new IllegalArgumentException("Invalid URL format");
}
// Parse headers if provided
HttpHeaders headers = new HttpHeaders();
if (headersJson != null && !headersJson.isBlank()) {
try {
Map<String, String> headersMap = new ObjectMapper()
.readValue(headersJson, Map.class);
headersMap.forEach(headers::add);
} catch (JsonProcessingException e) {
throw new IllegalArgumentException("Invalid headers JSON");
}
}
return webClient.get()
.uri(url)
.headers(h -> h.addAll(headers))
.retrieve()
.bodyToMono(ApiResponse.class)
.block();
}
}
record ApiResponse(int status, Map<String, Object> body, HttpHeaders headers) {}
Create reusable prompt templates with Spring AI:
@Component
public class CodeReviewPrompts {
@PromptTemplate(
name = "java-code-review",
description = "Review Java code for best practices and issues"
)
public Prompt createJavaCodeReviewPrompt(
@PromptParam("code") String code,
@PromptParam(value = "focusAreas", required = false)
List<String> focusAreas) {
String focus = focusAreas != null ?
String.join(", ", focusAreas) :
"general best practices";
return Prompt.builder()
.system("You are an expert Java code reviewer with 20 years of experience.")
.user("""
Review the following Java code for {focus}:
```java
{code}
```
Provide feedback in the following format:
1. Critical issues (must fix)
2. Warnings (should fix)
3. Suggestions (consider improving)
4. Positive aspects
Be specific and provide code examples where relevant.
""".replace("{code}", code).replace("{focus}", focus))
.build();
}
@PromptTemplate(
name = "generate-unit-tests",
description = "Generate comprehensive unit tests for Java code"
)
public Prompt createTestGenerationPrompt(
@PromptParam("code") String code,
@PromptParam("className") String className,
@PromptParam(value = "testingFramework", required = false)
String framework) {
String testFramework = framework != null ? framework : "JUnit 5";
return Prompt.builder()
.system("You are an expert in test-driven development.")
.user("""
Generate comprehensive unit tests for the following Java class using {testFramework}:
```java
{code}
```
Class: {className}
Requirements:
1. Test all public methods
2. Include edge cases and boundary conditions
3. Use appropriate assertions
4. Follow AAA pattern (Arrange, Act, Assert)
5. Include test method naming best practices
6. Mock external dependencies
""".replace("{code}", code)
.replace("{className}", className)
.replace("{testFramework}", testFramework))
.build();
}
}
Low-level function calling integration:
@Configuration
public class FunctionConfig {
@Bean
public FunctionCallback weatherFunction() {
return FunctionCallback.builder()
.function("getCurrentWeather", new WeatherService())
.description("Get the current weather for a location")
.inputType(WeatherRequest.class)
.build();
}
@Bean
public FunctionCallback calculatorFunction() {
// Same builder style as above; the function's input type must match inputType(...)
return FunctionCallback.builder()
.function("calculate", new Calculator())
.description("Perform mathematical calculations")
.inputType(CalculationRequest.class)
.build();
}
}
class WeatherService implements Function<WeatherRequest, WeatherResponse> {
@Override
public WeatherResponse apply(WeatherRequest request) {
// Call weather API
return new WeatherResponse(request.location(), 72, "Sunny");
}
}
record WeatherRequest(String location) {}
record WeatherResponse(String location, double temperature, String condition) {}
record CalculationRequest(String expression) {}
record CalculationResponse(String result) {}
class Calculator implements Function<CalculationRequest, CalculationResponse> {
@Override
public CalculationResponse apply(CalculationRequest request) {
// Evaluate request.expression() here
return new CalculationResponse("result");
}
}
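A function callback essentially pairs a tool name and description with a typed function. The framework converts the model's JSON arguments into the input type, invokes the function, and serializes the result. The dependency-free sketch below models only the dispatch step, under these assumptions:

- `ToolDispatchSketch` and `ToolSpec` are hypothetical names, not Spring AI types.
- JSON conversion is omitted, so inputs arrive already deserialized.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Function;

public final class ToolDispatchSketch {

    /** Hypothetical stand-in for a function callback: name, description, input type, function. */
    public record ToolSpec<I, O>(String name, String description,
                                 Class<I> inputType, Function<I, O> function) {
        public Object call(Object input) {
            // Class.cast fails fast if the caller passes the wrong input type
            return function.apply(inputType.cast(input));
        }
    }

    private final Map<String, ToolSpec<?, ?>> tools = new LinkedHashMap<>();

    public void register(ToolSpec<?, ?> spec) {
        // The model calls tools by name, so names must be unique
        if (tools.putIfAbsent(spec.name(), spec) != null) {
            throw new IllegalStateException("Duplicate tool name: " + spec.name());
        }
    }

    public Object invoke(String name, Object input) {
        ToolSpec<?, ?> spec = tools.get(name);
        if (spec == null) {
            throw new IllegalArgumentException("Unknown tool: " + name);
        }
        return spec.call(input);
    }
}
```

Under this model, the weather example would be registered as `new ToolSpec<>("getCurrentWeather", "Get the current weather for a location", WeatherRequest.class, new WeatherService())`.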
Set up MCP server with Spring Boot auto-configuration:
// Register this class in META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports
@AutoConfiguration(after = WebMvcAutoConfiguration.class)
@ConditionalOnClass({McpServer.class, ChatModel.class})
@ConditionalOnProperty(name = "spring.ai.mcp.enabled", havingValue = "true", matchIfMissing = true)
public class McpAutoConfiguration {
@Bean
@ConditionalOnMissingBean
public McpServerProperties mcpServerProperties() {
return new McpServerProperties();
}
@Bean
@ConditionalOnMissingBean
public McpServer mcpServer(
List<FunctionCallback> functionCallbacks,
List<PromptTemplate> promptTemplates,
McpServerProperties properties
) {
McpServer.Builder builder = McpServer.builder()
.serverInfo("spring-ai-mcp", "1.0.0")
.transport(properties.getTransport().create());
// Register function callbacks as tools
functionCallbacks.forEach(callback ->
builder.tool(Tool.fromFunctionCallback(callback))
);
// Register prompt templates
promptTemplates.forEach(template ->
builder.prompt(Prompt.fromTemplate(template))
);
return builder.build();
}
@Bean
@ConditionalOnProperty(name = "spring.ai.mcp.actuator.enabled", havingValue = "true")
public McpHealthIndicator mcpHealthIndicator(McpServer mcpServer) {
return new McpHealthIndicator(mcpServer);
}
}
@ConfigurationProperties(prefix = "spring.ai.mcp")
public class McpServerProperties {
private boolean enabled = true;
private TransportConfig transport = new TransportConfig();
private ActuatorConfig actuator = new ActuatorConfig();
// Getters and setters
public static class TransportConfig {
private TransportType type = TransportType.STDIO;
private HttpConfig http = new HttpConfig();
public Transport create() {
return switch (type) {
case STDIO -> new StdioTransport();
case HTTP -> new HttpTransport(http.getPort());
case SSE -> new SseTransport(http.getPort(), http.getPath());
};
}
}
public static class HttpConfig {
private int port = 8080;
private String path = "/mcp";
// Getters and setters
}
public static class ActuatorConfig {
private boolean enabled = true;
// Getters and setters
}
public enum TransportType {
STDIO, HTTP, SSE
}
}
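`@ConditionalOnProperty` is evaluated against the Spring `Environment`, not against field defaults in `McpServerProperties`. `mcpHealthIndicator` declares `havingValue = "true"` without `matchIfMissing`. The bean is therefore created only when `spring.ai.mcp.actuator.enabled=true` is set explicitly, even though `ActuatorConfig` defaults `enabled` to `true`.

The simplified model below captures the matching rule for a single property name, with prefix and multi-name handling omitted. `PropertyConditionSketch` is an illustration, not Spring Boot code. The rule it models:

- A missing property falls back to `matchIfMissing`.
- A present property is compared case-insensitively with `havingValue`.

```java
import java.util.Map;

/** Simplified model of @ConditionalOnProperty for one property name (prefix handling omitted). */
public final class PropertyConditionSketch {
    private PropertyConditionSketch() {}

    public static boolean matches(Map<String, String> environment, String name,
                                  String havingValue, boolean matchIfMissing) {
        String value = environment.get(name);
        if (value == null) {
            // Field defaults in @ConfigurationProperties classes play no part here
            return matchIfMissing;
        }
        if (havingValue.isEmpty()) {
            // With no havingValue, any value other than "false" matches
            return !"false".equalsIgnoreCase(value);
        }
        return havingValue.equalsIgnoreCase(value);
    }
}
```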
Configure MCP server in application.yml:
spring:
  ai:
    mcp:
      enabled: true
      transport:
        type: stdio # Options: stdio, http, sse
        http:
          port: 8080
          path: /mcp
      actuator:
        enabled: true
      tools:
        package-scan: com.example.tools
      prompts:
        package-scan: com.example.prompts
      security:
        enabled: true
        allowed-tools:
          - getWeather
          - executeQuery
        admin-tools:
          - admin_*
For advanced configuration:
@Slf4j
@Configuration
public class CustomMcpConfig {
@Bean
public McpServerCustomizer mcpServerCustomizer(MeterRegistry meterRegistry) {
return server -> {
server.addToolInterceptor((tool, args, chain) -> {
log.info("Executing tool: {}", tool.name());
long start = System.nanoTime();
try {
return chain.execute(tool, args);
} finally {
// Runs for failed calls too, so errors are timed as well
long durationNanos = System.nanoTime() - start;
log.info("Tool {} finished in {}ms", tool.name(), TimeUnit.NANOSECONDS.toMillis(durationNanos));
meterRegistry.timer("mcp.tool.execution", "tool", tool.name())
.record(durationNanos, TimeUnit.NANOSECONDS);
}
});
};
}
@Bean
public ToolFilter toolFilter(SecurityService securityService) {
return (tool, context) -> {
User user = securityService.getCurrentUser();
if (tool.name().startsWith("admin_")) {
return user.hasRole("ADMIN");
}
return securityService.isToolAllowed(user, tool.name());
};
}
}
@Service
public class SecurityService {
public boolean isToolAllowed(User user, String toolName) {
// Implement tool access control logic
return true;
}
}
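Two pieces of the configuration above suggest a default-deny rule:

- the `allowed-tools` and `admin-tools` lists in `application.yml`;
- the `admin_` prefix check in `toolFilter`.

The sketch below is one possible reading of those lists. It assumes a trailing `*` means a prefix match and that admin patterns take precedence. `ToolAllowList` is a hypothetical class, not part of Spring AI.

```java
import java.util.List;

/** Hypothetical evaluation of allowed-tools / admin-tools style lists. */
public final class ToolAllowList {
    private final List<String> allowedTools;
    private final List<String> adminTools;

    public ToolAllowList(List<String> allowedTools, List<String> adminTools) {
        this.allowedTools = List.copyOf(allowedTools);
        this.adminTools = List.copyOf(adminTools);
    }

    /** Exact match, or prefix match when the pattern ends with '*'. */
    static boolean matches(String pattern, String toolName) {
        if (pattern.endsWith("*")) {
            return toolName.startsWith(pattern.substring(0, pattern.length() - 1));
        }
        return pattern.equals(toolName);
    }

    public boolean isAllowed(String toolName, boolean isAdmin) {
        // Admin patterns are checked first so an admin tool never slips through the general list
        if (adminTools.stream().anyMatch(p -> matches(p, toolName))) {
            return isAdmin;
        }
        // Default deny: anything not listed is rejected, even for admins
        return allowedTools.stream().anyMatch(p -> matches(p, toolName));
    }
}
```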
Implement secure tool execution with Spring Security:
@Component
public class SecureToolExecutor {
private final McpServer mcpServer;
// SecurityContextHolder is a static utility class, not a bean, so it is used directly below instead of being injected
public SecureToolExecutor(McpServer mcpServer) {
this.mcpServer = mcpServer;
}
public ToolResult executeTool(String toolName, Map<String, Object> arguments) {
Authentication auth = SecurityContextHolder.getContext().getAuthentication();
if (!(auth instanceof UserAuthentication userAuth)) {
throw new AccessDeniedException("User not authenticated");
}
// Check tool permissions
if (!hasToolPermission(userAuth.getUser(), toolName)) {
throw new AccessDeniedException("Tool not allowed: " + toolName);
}
// Validate arguments against injection patterns
validateArguments(arguments);
// Execute with audit logging
logToolExecution(userAuth.getUser(), toolName, arguments);
try {
ToolResult result = mcpServer.executeTool(toolName, arguments);
logToolSuccess(userAuth.getUser(), toolName);
return result;
} catch (Exception e) {
logToolFailure(userAuth.getUser(), toolName, e);
throw new ToolExecutionException("Tool execution failed", e);
}
}
private boolean hasToolPermission(User user, String toolName) {
// Implement permission logic based on user roles and tool sensitivity
return user.getAuthorities().stream()
.anyMatch(auth -> auth.getAuthority().equals("TOOL_" + toolName) ||
auth.getAuthority().equals("ROLE_ADMIN"));
}
private void validateArguments(Map<String, Object> arguments) {
// Naive denylist: easy to bypass and may reject valid input; prefer bind parameters and per-argument allowlists
arguments.forEach((key, value) -> {
if (value instanceof String str) {
if (str.contains(";") || str.contains("--")) {
throw new IllegalArgumentException("Invalid characters in argument: " + key);
}
}
});
}
private void logToolExecution(User user, String toolName, Map<String, Object> arguments) {
// Implement audit logging
}
private void logToolSuccess(User user, String toolName) {
// Log successful execution
}
private void logToolFailure(User user, String toolName, Exception e) {
// Log failed execution
}
}
class ToolExecutionException extends RuntimeException {
public ToolExecutionException(String message, Throwable cause) {
super(message, cause);
}
}
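A character denylist such as the `;` / `--` check has two weaknesses:

- it misses many injection forms;
- it rejects harmless input, such as a note containing a semicolon.

Where an argument has a known shape, validating it against an allowlist pattern is usually sturdier. SQL values should still go through bind parameters. The sketch below rejects unknown arguments outright. `ArgumentRules` and the example patterns are hypothetical.

```java
import java.util.Map;
import java.util.regex.Pattern;

/** Hypothetical per-argument allowlist; unknown arguments are rejected. */
public final class ArgumentRules {
    private final Map<String, Pattern> rules;

    public ArgumentRules(Map<String, Pattern> rules) {
        this.rules = Map.copyOf(rules);
    }

    public void validate(Map<String, Object> arguments) {
        arguments.forEach((name, value) -> {
            Pattern rule = rules.get(name);
            if (rule == null) {
                throw new IllegalArgumentException("Unexpected argument: " + name);
            }
            // matches() requires the whole value to fit the pattern, not just a substring
            if (!(value instanceof String text) || !rule.matcher(text).matches()) {
                throw new IllegalArgumentException("Invalid value for argument: " + name);
            }
        });
    }
}
```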
Use Spring's validation framework:
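@Component
@Validated // Method-level constraints such as @Valid are only enforced when @Validated is on the class
public class ValidatedTools {
@Tool(description = "Process user data with validation")
public ProcessingResult processUserData(
@ToolParam(description = "User data to process") @Valid UserData data) {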
// Implementation
return new ProcessingResult("success", data);
}
}
record UserData(
@NotBlank(message = "Name is required")
@Size(max = 100, message = "Name must be 100 characters or less")
String name,
@NotNull(message = "Age is required")
@Min(value = 18, message = "Must be 18 or older")
@Max(value = 120, message = "Age must be realistic")
Integer age,
@NotBlank(message = "Email is required")
@Email(message = "Invalid email format")
String email
) {}
// Custom validator for sensitive operations
@Component
public class SensitiveOperationValidator {
public void validateOperation(String operation, User user, Map<String, Object> params) {
if (isSensitiveOperation(operation)) {
requireAdditionalAuthentication(user);
validateOperationLimits(user, operation);
logSensitiveOperation(user, operation, params);
}
}
private boolean isSensitiveOperation(String operation) {
return operation.startsWith("delete") || operation.startsWith("update");
}
private void requireAdditionalAuthentication(User user) {
// Implement MFA or re-authentication
}
private void validateOperationLimits(User user, String operation) {
// Check rate limits and quotas
}
private void logSensitiveOperation(User user, String operation, Map<String, Object> params) {
// Secure audit logging
}
}
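The Bean Validation annotations on `UserData` only take effect when a validator actually runs, for example through Spring's method validation proxy. If such objects can also be created elsewhere, one complementary technique (distinct from Bean Validation) is to enforce the same invariants in a record's compact constructor, so an invalid instance can never exist. `CheckedUserData` is a hypothetical variant. Its email check is deliberately loose.

```java
/** Hypothetical variant of UserData that enforces its rules on construction. */
public record CheckedUserData(String name, Integer age, String email) {
    public CheckedUserData {
        if (name == null || name.isBlank()) {
            throw new IllegalArgumentException("Name is required");
        }
        if (name.length() > 100) {
            throw new IllegalArgumentException("Name must be 100 characters or less");
        }
        if (age == null) {
            throw new IllegalArgumentException("Age is required");
        }
        if (age < 18 || age > 120) {
            throw new IllegalArgumentException("Age must be between 18 and 120");
        }
        // Rough shape check only; full address validation is far more involved
        if (email == null || !email.matches("[^@\\s]+@[^@\\s]+\\.[^@\\s]+")) {
            throw new IllegalArgumentException("Invalid email format");
        }
    }
}
```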
Implement comprehensive error handling:
@Slf4j
@ControllerAdvice
public class McpExceptionHandler {
@ExceptionHandler(ToolExecutionException.class)
public ResponseEntity<ErrorResponse> handleToolExecutionException(
ToolExecutionException ex, WebRequest request) {
ErrorResponse error = ErrorResponse.builder()
.timestamp(LocalDateTime.now())
.status(HttpStatus.INTERNAL_SERVER_ERROR.value())
.error("Tool Execution Failed")
.message(ex.getMessage())
.path(((ServletWebRequest) request).getRequest().getRequestURI())
.build();
log.error("Tool execution failed: {}", ex.getMessage(), ex);
return new ResponseEntity<>(error, HttpStatus.INTERNAL_SERVER_ERROR);
}
@ExceptionHandler(AccessDeniedException.class)
public ResponseEntity<ErrorResponse> handleAccessDenied(
AccessDeniedException ex, WebRequest request) {
ErrorResponse error = ErrorResponse.builder()
.timestamp(LocalDateTime.now())
.status(HttpStatus.FORBIDDEN.value())
.error("Access Denied")
.message("You do not have permission to execute this tool")
.path(((ServletWebRequest) request).getRequest().getRequestURI())
.build();
log.warn("Access denied: {}", ex.getMessage());
return new ResponseEntity<>(error, HttpStatus.FORBIDDEN);
}
@ExceptionHandler(IllegalArgumentException.class)
public ResponseEntity<ErrorResponse> handleValidationError(
IllegalArgumentException ex, WebRequest request) {
ErrorResponse error = ErrorResponse.builder()
.timestamp(LocalDateTime.now())
.status(HttpStatus.BAD_REQUEST.value())
.error("Validation Error")
.message(ex.getMessage())
.path(((ServletWebRequest) request).getRequest().getRequestURI())
.build();
return new ResponseEntity<>(error, HttpStatus.BAD_REQUEST);
}
@ExceptionHandler(Exception.class)
public ResponseEntity<ErrorResponse> handleGenericException(
Exception ex, WebRequest request) {
ErrorResponse error = ErrorResponse.builder()
.timestamp(LocalDateTime.now())
.status(HttpStatus.INTERNAL_SERVER_ERROR.value())
.error("Internal Server Error")
.message("An unexpected error occurred")
.path(((ServletWebRequest) request).getRequest().getRequestURI())
.build();
log.error("Unexpected error: {}", ex.getMessage(), ex);
return new ResponseEntity<>(error, HttpStatus.INTERNAL_SERVER_ERROR);
}
@Data
@Builder
static class ErrorResponse {
private LocalDateTime timestamp;
private int status;
private String error;
private String message;
private String path;
}
}
Register tools at runtime:
@Slf4j
public class DynamicToolRegistry {
private final McpServer mcpServer;
private final Map<String, ToolRegistration> registeredTools = new ConcurrentHashMap<>();
public DynamicToolRegistry(McpServer mcpServer) {
this.mcpServer = mcpServer;
}
public void registerTool(ToolRegistration registration) {
registeredTools.put(registration.getId(), registration);
Tool tool = Tool.builder()
.name(registration.getName())
.description(registration.getDescription())
.inputSchema(registration.getInputSchema())
.function(args -> executeDynamicTool(registration.getId(), args))
.build();
mcpServer.addTool(tool);
log.info("Registered dynamic tool: {}", registration.getName());
}
public void unregisterTool(String toolId) {
ToolRegistration registration = registeredTools.remove(toolId);
if (registration != null) {
mcpServer.removeTool(registration.getName());
log.info("Unregistered dynamic tool: {}", registration.getName());
}
}
private Object executeDynamicTool(String toolId, Map<String, Object> args) {
ToolRegistration registration = registeredTools.get(toolId);
if (registration == null) {
throw new IllegalStateException("Tool not found: " + toolId);
}
// Execute based on registration type
return switch (registration.getType()) {
case GROOVY_SCRIPT -> executeGroovyScript(registration, args);
case SPRING_BEAN -> executeSpringBeanMethod(registration, args);
case HTTP_ENDPOINT -> callHttpEndpoint(registration, args);
};
}
private Object executeGroovyScript(ToolRegistration registration, Map<String, Object> args) {
// Implement Groovy script execution
return null;
}
private Object executeSpringBeanMethod(ToolRegistration registration, Map<String, Object> args) {
// Implement Spring bean method invocation
return null;
}
private Object callHttpEndpoint(ToolRegistration registration, Map<String, Object> args) {
// Implement HTTP call
return null;
}
}
@Data
@Builder
class ToolRegistration {
private String id;
private String name;
private String description;
private Map<String, Object> inputSchema;
private ToolType type;
private String target; // script, bean name, or URL
private Map<String, String> metadata;
}
enum ToolType {
GROOVY_SCRIPT,
SPRING_BEAN,
HTTP_ENDPOINT
}
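MCP tools describe their arguments with an `inputSchema`, which is a JSON Schema object of type `object`. The `inputSchema` map in `ToolRegistration` could be assembled as nested maps, as sketched below. `SchemaBuilder` is hypothetical and covers only a small subset of JSON Schema: typed properties with descriptions, plus a `required` list.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical helper that builds a minimal JSON Schema object as nested maps. */
public final class SchemaBuilder {
    private final Map<String, Object> properties = new LinkedHashMap<>();
    private final List<String> required = new ArrayList<>();

    public SchemaBuilder property(String name, String type, String description, boolean isRequired) {
        properties.put(name, Map.of("type", type, "description", description));
        if (isRequired) {
            required.add(name);
        }
        return this;
    }

    public Map<String, Object> build() {
        Map<String, Object> schema = new LinkedHashMap<>();
        // Tool arguments are passed as a single JSON object
        schema.put("type", "object");
        schema.put("properties", new LinkedHashMap<>(properties));
        schema.put("required", List.copyOf(required));
        return schema;
    }
}
```

For example, the result of `new SchemaBuilder().property("city", "string", "City name", true).build()` can be passed to `ToolRegistration.builder().inputSchema(...)`.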
Support multiple AI models:
@Configuration
public class MultiModelConfig {
// Helpers that build a chat model for a given model name; API keys come from environment variables
private static ChatModel openAiModel(String model) {
OpenAiApi api = OpenAiApi.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.build();
// The model name is a chat option, not an OpenAiApi setting
return OpenAiChatModel.builder()
.openAiApi(api)
.defaultOptions(OpenAiChatOptions.builder().model(model).build())
.build();
}
private static ChatModel anthropicModel(String model) {
AnthropicApi api = AnthropicApi.builder()
.apiKey(System.getenv("ANTHROPIC_API_KEY"))
.build();
return AnthropicChatModel.builder()
.anthropicApi(api)
.defaultOptions(AnthropicChatOptions.builder().model(model).build())
.build();
}
@Bean
@Primary
public ChatModel primaryChatModel(@Value("${spring.ai.primary.model}") String modelName) {
return switch (modelName) {
case "gpt-4" -> openAiModel("gpt-4");
case "claude" -> anthropicModel("claude-3-opus-20240229");
default -> throw new IllegalArgumentException("Unsupported model: " + modelName);
};
}
@Bean
public Map<String, ChatModel> allChatModels() {
return Map.of(
"gpt-4", openAiModel("gpt-4"),
"gpt-3.5", openAiModel("gpt-3.5-turbo"),
"claude-opus", anthropicModel("claude-3-opus-20240229"));
}
// Without the qualifier, Spring injects a map of all ChatModel beans keyed by bean name,
// not the allChatModels map defined above
@Bean
public ModelSelector modelSelector(@Qualifier("allChatModels") Map<String, ChatModel> models) {
return new SpringAiModelSelector(models);
}
}
public class SpringAiModelSelector implements ModelSelector {
private final Map<String, ChatModel> models;
public SpringAiModelSelector(Map<String, ChatModel> models) {
this.models = models;
}
@Override
public ChatModel selectModel(Prompt prompt, Map<String, Object> context) {
// Select model based on prompt complexity, cost, latency requirements
String modelName = determineBestModel(prompt, context);
return models.get(modelName);
}
private String determineBestModel(Prompt prompt, Map<String, Object> context) {
// Implement model selection logic
// Consider: prompt length, complexity, cost constraints, latency requirements
return "gpt-4";
}
}
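`determineBestModel` is left as a placeholder. One possible rule is sketched below, assuming callers can flag low-cost requests through a hypothetical `lowCost` context key. It also uses prompt length as a rough proxy for size. `ModelRouting` and its 8,000-character threshold are illustrative, not tuned values. The returned names match the keys used in `allChatModels`.

```java
import java.util.Map;

/** Hypothetical routing rule; the threshold and context key are illustrative only. */
public final class ModelRouting {
    private ModelRouting() {}

    public static String chooseModel(String promptText, Map<String, Object> context) {
        // Callers can force a cheaper model, e.g. for bulk or background work
        if (Boolean.TRUE.equals(context.get("lowCost"))) {
            return "gpt-3.5";
        }
        // Character count is a crude size proxy; a tokenizer would be more accurate
        if (promptText.length() > 8_000) {
            return "claude-opus";
        }
        return "gpt-4";
    }
}
```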
Implement caching for tools and prompts:
@Configuration
@EnableCaching
public class McpCacheConfig {
@Bean
public CacheManager cacheManager() {
return new ConcurrentMapCacheManager(
"tool-results",
"prompt-templates",
"function-callbacks"
);
}
}
@Component
public class CachedToolExecutor {
private final McpServer mcpServer;
public CachedToolExecutor(McpServer mcpServer) {
this.mcpServer = mcpServer;
}
@Cacheable(
value = "tool-results",
// Default key is SimpleKey(toolName, args), compared with equals() rather than a bare hashCode()
unless = "#result == null || !#result.isCacheable()"
)
public ToolResult executeTool(String toolName, Map<String, Object> args) {
return mcpServer.executeTool(toolName, args);
}
@CacheEvict(value = "tool-results", allEntries = true)
public void clearToolCache() {
// Clear cache when tools are updated
}
@Cacheable(value = "prompt-templates", key = "#templateName")
public PromptTemplate getPromptTemplate(String templateName) {
return mcpServer.getPromptTemplate(templateName);
}
}
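Cache keys that embed only a hash code of the arguments can collide. Distinct maps can share a hash: `"Aa"` and `"BB"` both hash to 2112, so `{city=Aa}` and `{city=BB}` do too. The two calls would then share one cached result. Keying on the argument values themselves avoids this, as Spring's default `SimpleKey` does and as a record does. Also note that `ConcurrentMapCacheManager` never expires entries; TTLs need a provider such as Caffeine or Redis. `CacheKeyDemo` and `ToolCacheKey` are hypothetical names.

```java
import java.util.Map;

public final class CacheKeyDemo {
    /** Hypothetical composite key: records derive equals/hashCode from all components. */
    public record ToolCacheKey(String toolName, Map<String, Object> args) {
        public ToolCacheKey {
            // Defensive copy so later changes to the caller's map cannot alter the key
            args = Map.copyOf(args);
        }
    }

    /** A hash-only string key, for comparison. */
    public static String hashBasedKey(String toolName, Map<String, Object> args) {
        return toolName + "_" + args.hashCode();
    }
}
```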
@SpringBootTest
class DatabaseToolsTest {
@Autowired
private DatabaseTools databaseTools;
@MockBean
private JdbcTemplate jdbcTemplate;
@Test
void testExecuteQuery_Success() {
// Given
String query = "SELECT * FROM users WHERE id = ?";
Map<String, Object> params = Map.of("id", 1);
List<Map<String, Object>> expectedResults = List.of(
Map.of("id", 1, "name", "John")
);
when(jdbcTemplate.queryForList(anyString(), anyMap()))
.thenReturn(expectedResults);
// When
List<Map<String, Object>> results = databaseTools.executeQuery(query, params);
// Then
assertThat(results).isEqualTo(expectedResults);
verify(jdbcTemplate).queryForList(query, params);
}
@Test
void testExecuteQuery_InvalidQuery_ThrowsException() {
// Given
String query = "DROP TABLE users";
// When & Then
assertThatThrownBy(() -> databaseTools.executeQuery(query, null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Only SELECT queries are allowed");
verifyNoInteractions(jdbcTemplate);
}
@Test
void testGetTableSchema_Success() {
// Given
String tableName = "users";
List<Map<String, Object>> columns = List.of(
Map.of("column_name", "id", "data_type", "integer"),
Map.of("column_name", "name", "data_type", "varchar")
);
when(jdbcTemplate.queryForList(anyString(), eq(tableName)))
.thenReturn(columns);
// When
TableSchema schema = databaseTools.getTableSchema(tableName);
// Then
assertThat(schema.tableName()).isEqualTo(tableName);
assertThat(schema.columns()).isEqualTo(columns);
}
}
@SpringBootTest
@AutoConfigureMockMvc
class McpServerIntegrationTest {
@Autowired
private MockMvc mockMvc;
@Autowired
private McpServer mcpServer;
@MockBean
private DatabaseTools databaseTools;
@Test
void testExecuteTool_Success() throws Exception {
// Given
String toolName = "executeQuery";
Map<String, Object> args = Map.of(
"query", "SELECT * FROM users",
"params", Map.of()
);
List<Map<String, Object>> expectedResult = List.of(
Map.of("id", 1, "name", "Test User")
);
when(databaseTools.executeQuery(anyString(), anyMap()))
.thenReturn(expectedResult);
// When & Then
mockMvc.perform(post("/mcp/tools/executeQuery")
.contentType(MediaType.APPLICATION_JSON)
.content(new ObjectMapper().writeValueAsString(args)))
.andExpect(status().isOk())
.andExpect(jsonPath("$.result").isArray())
.andExpect(jsonPath("$.result[0].id").value(1));
}
@Test
void testListTools_Success() throws Exception {
// When & Then
mockMvc.perform(get("/mcp/tools"))
.andExpect(status().isOk())
.andExpect(jsonPath("$.tools").isArray());
}
@Test
void testHealthEndpoint() throws Exception {
// When & Then
mockMvc.perform(get("/actuator/health/mcp"))
.andExpect(status().isOk())
.andExpect(jsonPath("$.status").value("UP"));
}
}
@SpringBootTest
@Testcontainers
@AutoConfigureMockMvc
class McpServerPostgresIntegrationTest {
@Container
static PostgreSQLContainer<?> postgres = new PostgreSQLContainer<>("postgres:15")
.withDatabaseName("testdb")
.withUsername("test")
.withPassword("test");
@DynamicPropertySource
static void properties(DynamicPropertyRegistry registry) {
registry.add("spring.datasource.url", postgres::getJdbcUrl);
registry.add("spring.datasource.username", postgres::getUsername);
registry.add("spring.datasource.password", postgres::getPassword);
}
@Autowired
private MockMvc mockMvc;
@Test
void testDatabaseToolWithRealDatabase() throws Exception {
// Given
String query = "SELECT current_database(), current_user";
Map<String, Object> request = Map.of(
"tool", "executeQuery",
"arguments", Map.of("query", query)
);
// When & Then
mockMvc.perform(post("/mcp/tools/executeQuery")
.contentType(MediaType.APPLICATION_JSON)
.content(new ObjectMapper().writeValueAsString(request)))
.andExpect(status().isOk())
.andExpect(jsonPath("$.success").value(true))
.andExpect(jsonPath("$.data[0].current_database").value("testdb"))
.andExpect(jsonPath("$.data[0].current_user").value("test"));
}
}
Slice test with @WebMvcTest:
@WebMvcTest(controllers = McpController.class)
class McpControllerSliceTest {
@Autowired
private MockMvc mockMvc;
@MockBean
private McpServer mcpServer;
@MockBean
private ToolRegistry toolRegistry;
@Test
void testListToolsEndpoint() throws Exception {
// Given
Tool tool1 = Tool.builder().name("tool1").description("Tool 1").build();
Tool tool2 = Tool.builder().name("tool2").description("Tool 2").build();
when(toolRegistry.listTools()).thenReturn(List.of(tool1, tool2));
// When & Then
mockMvc.perform(get("/mcp/tools"))
.andExpect(status().isOk())
.andExpect(jsonPath("$.tools").isArray())
.andExpect(jsonPath("$.tools.length()").value(2))
.andExpect(jsonPath("$.tools[0].name").value("tool1"));
}
}
@ExtendWith(MockitoExtension.class)
class ToolValidationTest {
private ToolValidator validator;
@BeforeEach
void setUp() {
McpServerProperties properties = new McpServerProperties();
properties.getTools().getValidation().setMaxArgumentsSize(1000);
validator = new DefaultToolValidator(properties);
}
@Test
void testValidArguments() {
// Given
Tool tool = Tool.builder()
.name("testTool")
.method(getTestMethod())
.build();
Map<String, Object> args = Map.of("param1", "value1", "param2", 123);
// When & Then
assertDoesNotThrow(() -> validator.validateArguments(tool, args));
}
@Test
void testArgumentsTooLarge() {
// Given
Tool tool = Tool.builder().name("testTool").build();
Map<String, Object> args = Map.of("largeParam", "x".repeat(2000));
// When & Then
ValidationException exception = assertThrows(
ValidationException.class,
() -> validator.validateArguments(tool, args)
);
assertThat(exception.getMessage()).contains("Arguments too large");
}
}
@SpringBootTest
@AutoConfigureMockMvc
@WithMockUser(roles = {"USER"})
class McpSecurityTest {
@Autowired
private MockMvc mockMvc;
@Test
void testUserCanAccessRegularTools() throws Exception {
mockMvc.perform(get("/mcp/tools/getWeather"))
.andExpect(status().isOk());
}
@Test
@WithMockUser(roles = {"USER"})
void testUserCannotAccessAdminTools() throws Exception {
mockMvc.perform(get("/mcp/tools/admin/deleteData"))
.andExpect(status().isForbidden());
}
@Test
@WithMockUser(roles = {"ADMIN"})
void testAdminCanAccessAllTools() throws Exception {
mockMvc.perform(get("/mcp/tools/admin/deleteData"))
.andExpect(status().isOk());
}
}
@SpringBootTest
@EnableConfigurationProperties(McpServerProperties.class)
class McpPropertiesTest {
@Autowired
private McpServerProperties properties;
@Test
void testDefaultValues() {
assertThat(properties.getServer().getName()).isEqualTo("spring-ai-mcp-server");
assertThat(properties.getTransport().getType()).isEqualTo(TransportType.STDIO);
assertThat(properties.getSecurity().isEnabled()).isFalse();
}
}
# Spring AI Configuration
spring.ai.openai.api-key=${OPENAI_API_KEY}
spring.ai.openai.chat.options.model=gpt-4o-mini
spring.ai.openai.chat.options.temperature=0.7
# MCP Server Configuration
spring.ai.mcp.enabled=true
spring.ai.mcp.server.name=spring-ai-mcp-server
spring.ai.mcp.server.version=1.0.0
spring.ai.mcp.transport.type=stdio
# HTTP Transport (if enabled)
spring.ai.mcp.transport.http.port=8080
spring.ai.mcp.transport.http.path=/mcp
spring.ai.mcp.transport.http.cors.enabled=true
spring.ai.mcp.transport.http.cors.allowed-origins=*
# Security Configuration
spring.ai.mcp.security.enabled=true
spring.ai.mcp.security.authorization.mode=role-based
spring.ai.mcp.security.authorization.default-deny=true
spring.ai.mcp.security.audit.enabled=true
# Tool Configuration
spring.ai.mcp.tools.package-scan=com.example.mcp.tools
spring.ai.mcp.tools.validation.enabled=true
spring.ai.mcp.tools.validation.max-execution-time=30s
spring.ai.mcp.tools.caching.enabled=true
spring.ai.mcp.tools.caching.ttl=5m
# Prompt Configuration
spring.ai.mcp.prompts.package-scan=com.example.mcp.prompts
spring.ai.mcp.prompts.caching.enabled=true
spring.ai.mcp.prompts.caching.ttl=1h
# Actuator and Monitoring
spring.ai.mcp.actuator.enabled=true
spring.ai.mcp.metrics.enabled=true
spring.ai.mcp.metrics.export.prometheus.enabled=true
spring.ai.mcp.logging.enabled=true
spring.ai.mcp.logging.level=DEBUG
# Performance Tuning
spring.ai.mcp.thread-pool.core-size=10
spring.ai.mcp.thread-pool.max-size=50
spring.ai.mcp.thread-pool.queue-capacity=100
spring.ai.mcp.rate-limiter.enabled=true
spring.ai.mcp.rate-limiter.requests-per-minute=100
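The enforcement behind `spring.ai.mcp.rate-limiter.requests-per-minute` is not shown. A fixed-window counter per caller is one simple option, sketched below. Libraries such as Bucket4j or the Resilience4j `RateLimiter` offer more refined algorithms. `FixedWindowRateLimiter` is hypothetical, and its clock is injectable for testing. A fixed window can admit up to twice the limit across a window boundary.

```java
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.LongSupplier;

/** Hypothetical fixed-window limiter: at most `limit` calls per key per window. */
public final class FixedWindowRateLimiter {
    private record Window(long start, int count) {}

    private final int limit;
    private final long windowMillis;
    private final LongSupplier clock;
    private final ConcurrentHashMap<String, Window> windows = new ConcurrentHashMap<>();

    public FixedWindowRateLimiter(int limit, long windowMillis, LongSupplier clock) {
        this.limit = limit;
        this.windowMillis = windowMillis;
        this.clock = clock;
    }

    /** Returns true if the call is allowed; every attempt, including rejected ones, is counted. */
    public boolean tryAcquire(String key) {
        long now = clock.getAsLong();
        // compute() is atomic per key, so concurrent callers cannot under-count
        Window w = windows.compute(key, (k, current) ->
                current == null || now - current.start() >= windowMillis
                        ? new Window(now, 1)
                        : new Window(current.start(), current.count() + 1));
        return w.count() <= limit;
    }
}
```

With the properties above, this would be constructed as `new FixedWindowRateLimiter(100, 60_000, System::currentTimeMillis)` and keyed by user or client ID.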
Use Declarative Annotations : Prefer @Tool and @PromptTemplate over manual registration for cleaner, more maintainable code.
Keep Tools Focused : Each tool should do one thing well. Avoid creating monolithic tools that handle multiple unrelated operations.
Use Descriptive Names : Tool names should clearly indicate what they do. Use verbs like get, create, update, delete, search for actions.
Document Parameters : Use @ToolParam with clear descriptions so AI models understand when and how to use each parameter.
Implement Input Validation : Always validate user inputs to prevent injection attacks. Never trust AI-generated parameters without validation.
Use Authorization : Implement role-based access control for sensitive operations. Use Spring Security annotations like @PreAuthorize.
Sanitize Error Messages : Never expose sensitive information in error messages or tool descriptions.
Audit Sensitive Operations : Log all executions of tools that modify data or access sensitive resources.
Rate Limiting : Implement rate limiting for expensive operations to prevent abuse and ensure fair usage.
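Use Caching : Cache results of expensive operations using Spring Cache with @Cacheable. Configure expiry (TTL) in the cache provider, such as Caffeine or Redis; the default ConcurrentMapCacheManager never expires entries.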
Implement Timeouts : Always set timeouts for external API calls to prevent hanging requests.
Use Async Processing : For long-running operations, consider using @Async to return immediately and provide status updates.
Connection Pooling : Configure proper connection pools for database and HTTP clients.
Monitor Performance : Track tool execution times and success rates using Micrometer metrics.
Handle Errors Gracefully : Implement proper exception handling with user-friendly error messages.
Use Specific Exceptions : Create custom exceptions for different error scenarios to enable proper error handling.
Implement Fallback : Provide fallback values or retry logic for transient failures.
Log Context : Include relevant context (user, tool name, parameters) in error logs for debugging.
Write Unit Tests : Test each tool independently with mocked dependencies.
Write Integration Tests : Test the entire MCP server flow including transport and serialization.
Test Security : Verify that authorization rules are properly enforced.
Use Testcontainers : For database-dependent tools, use Testcontainers for realistic testing.
Test Edge Cases : Test with invalid inputs, null values, and boundary conditions.
Document Tool Purpose : Clearly describe what each tool does and when to use it.
Provide Examples : Include usage examples in the tool description or separate documentation.
Version Your API : Maintain backward compatibility when updating tools. Use semantic versioning.
Document Return Types : Clearly describe the structure of return values.
Externalize Configuration : Use application.yml for configurable values like URLs, timeouts, and credentials.
Use Profiles : Define different configurations for dev, test, and production environments.
Validate Configuration : Use @ConfigurationProperties with validation for type-safe configuration.
Leverage Dependency Injection : Use constructor injection for dependencies to improve testability.
Use Qualifiers : When multiple beans of the same type exist, use @Qualifier to disambiguate.
Implement Health Checks : Provide health indicators to monitor MCP server status.
Design for Idempotency : Make tools idempotent where possible, because AI models may retry failed calls.
Be Conservative with Data-Modifying Tools : AI models may call modifying tools unexpectedly, so consider adding confirmation steps.
Provide Context : Include relevant context in responses to help AI models understand results.
Handle Large Responses : For tools that return large data, consider pagination or summary options.
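For tools that cannot be naturally idempotent, a common technique is an idempotency key. Repeated calls carrying the same request ID return the first result instead of repeating the side effect. The sketch below assumes the client or the tool arguments carry a stable ID for each logical operation. `IdempotentExecutor` is hypothetical and has these limitations:

- its map grows without bound, so real use needs expiry;
- `computeIfAbsent` can block other updates while a slow operation runs.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;

/** Hypothetical wrapper: repeats of the same request ID return the first result. */
public final class IdempotentExecutor {
    private final Map<String, Object> results = new ConcurrentHashMap<>();

    public Object execute(String requestId, Supplier<Object> operation) {
        // Runs the operation at most once per ID; if it throws, nothing is stored and a retry can run it
        return results.computeIfAbsent(requestId, id -> operation.get());
    }
}
```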
If migrating from LangChain4j MCP to Spring AI MCP:
Replace LangChain4j's @ToolMethod with Spring AI's @Tool.
Use FunctionCallback when you need low-level control.
Update the settings in application.properties to the corresponding Spring AI properties.
Register tool classes as Spring beans with @Component.
// Before: LangChain4j
@ToolMethod("Get weather information")
public String getWeather(@P("city name") String city) {
return weatherService.getWeather(city);
}
// After: Spring AI
@Component
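public class WeatherTools {
private final WeatherService weatherService;
public WeatherTools(WeatherService weatherService) {
this.weatherService = weatherService;
}
@Tool(description = "Get weather information")
public String getWeather(@ToolParam(description = "City name") String city) {
return weatherService.getWeather(city);
}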
}
A complete, production-ready weather MCP server:
// Application.java
@SpringBootApplication
@EnableMcpServer
public class WeatherMcpServerApplication {
public static void main(String[] args) {
SpringApplication.run(WeatherMcpServerApplication.class, args);
}
}
// WeatherTools.java
@Component
@Slf4j
public class WeatherTools {
private final WeatherService weatherService;
private final MeterRegistry meterRegistry;
public WeatherTools(WeatherService weatherService, MeterRegistry meterRegistry) {
this.weatherService = weatherService;
this.meterRegistry = meterRegistry;
}
@Tool(description = "Get current weather conditions for a city")
public WeatherResponse getCurrentWeather(
@ToolParam(description = "City name (e.g., 'New York, NY')") String city) {
Timer.Sample sample = Timer.start(meterRegistry);
String outcome = "success";
try {
log.info("Fetching weather for: {}", city);
return weatherService.getCurrentWeather(city);
} catch (Exception e) {
outcome = "error";
log.error("Error fetching weather for {}", city, e);
throw new ToolExecutionException("Unable to fetch weather data", e);
} finally {
// Tag by outcome rather than city: free-form user input would create unbounded metric cardinality
sample.stop(Timer.builder("weather.tool.duration")
.tag("outcome", outcome)
.register(meterRegistry));
}
}
@Tool(description = "Get 5-day weather forecast for a city")
public ForecastResponse getForecast(
@ToolParam(description = "City name") String city,
@ToolParam(description = "Temperature unit (celsius or fahrenheit)", required = false)
String unit) {
String tempUnit = unit != null ? unit : "celsius";
return weatherService.getForecast(city, tempUnit);
}
@Tool(description = "Get weather alerts for a location")
public List<WeatherAlert> getWeatherAlerts(
@ToolParam(description = "State or province code (e.g., 'CA', 'NY')") String stateCode) {
return weatherService.getAlerts(stateCode);
}
}
// WeatherService.java
@Service
public class WeatherService {
private final WebClient webClient;
private final WeatherCache cache;
public WeatherService(WebClient.Builder webClientBuilder,
WeatherCache cache) {
this.webClient = webClientBuilder
.baseUrl("https://api.weather.gov")
.build();
this.cache = cache;
}
@Cacheable(value = "weather", key = "#city")
public WeatherResponse getCurrentWeather(String city) {
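return webClient.get()
// api.weather.gov's /points endpoint takes "lat,lon" coordinates, so the city must be geocoded first
.uri("/points/{city}/forecast", city)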
.retrieve()
.bodyToMono(WeatherResponse.class)
.block();
}
}
@Component
@PreAuthorize("hasRole('USER')")
public class DatabaseTools {
private final JdbcTemplate jdbcTemplate;
private final NamedParameterJdbcTemplate namedJdbcTemplate;
private final QueryValidator validator;
public DatabaseTools(JdbcTemplate jdbcTemplate, QueryValidator validator) {
this.jdbcTemplate = jdbcTemplate;
// JdbcTemplate has no Map-based query methods; named parameters need NamedParameterJdbcTemplate
this.namedJdbcTemplate = new NamedParameterJdbcTemplate(jdbcTemplate);
this.validator = validator;
}
@Tool(description = "Execute a read-only SQL query and return results")
public QueryResult executeQuery(
@ToolParam(description = "SQL SELECT query to execute") String sql,
@ToolParam(description = "Query parameters as JSON map", required = false)
String paramsJson) {
// Validate query is read-only
validator.validateReadOnly(sql);
// Parse parameters
Map<String, Object> params = parseParams(paramsJson);
// Query timeout is assumed to be set globally, e.g. spring.jdbc.template.query-timeout
return namedJdbcTemplate.query(sql, params, rs -> {
ResultSetMetaData metaData = rs.getMetaData();
int columnCount = metaData.getColumnCount();
List<Map<String, Object>> rows = new ArrayList<>();
while (rs.next()) {
Map<String, Object> row = new LinkedHashMap<>();
for (int i = 1; i <= columnCount; i++) {
row.put(metaData.getColumnName(i), rs.getObject(i));
}
rows.add(row);
}
return new QueryResult(rows, rows.size());
});
}
@Tool(description = "Get table metadata including columns and types")
public TableMetadata describeTable(
@ToolParam(description = "Table name") String tableName) {
String sql = """
SELECT column_name, data_type, is_nullable, column_default
FROM information_schema.columns
WHERE table_name = ?
ORDER BY ordinal_position
""";
List<ColumnInfo> columns = jdbcTemplate.query(sql,
(rs, rowNum) -> new ColumnInfo(
rs.getString("column_name"),
rs.getString("data_type"),
rs.getString("is_nullable").equals("YES"),
rs.getString("column_default")
),
tableName
);
return new TableMetadata(tableName, columns);
}
}
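`executeQuery` relies on `validator.validateReadOnly(sql)`, whose implementation is not shown. Below is a minimal sketch of one possible keyword-based check. The class name `ReadOnlySqlValidator` and the keyword list are assumptions. A lexical check like this is only a first line of defence: it does not parse SQL, and it will reject a harmless query that uses a forbidden word as an identifier or inside a string literal. The database account behind the `JdbcTemplate` should itself have read-only privileges.

```java
import java.util.Locale;
import java.util.Set;
import java.util.regex.Pattern;

// Hypothetical stand-in for the QueryValidator used by DatabaseTools
final class ReadOnlySqlValidator {
    // Keywords for statements that modify data or schema, or run procedures (assumed list)
    private static final Set<String> FORBIDDEN = Set.of(
        "INSERT", "UPDATE", "DELETE", "MERGE", "DROP", "ALTER", "CREATE",
        "TRUNCATE", "GRANT", "REVOKE", "CALL", "EXEC", "EXECUTE", "COPY");
    private static final Pattern LINE_COMMENT = Pattern.compile("--[^\\n]*");
    private static final Pattern BLOCK_COMMENT = Pattern.compile("/\\*.*?\\*/", Pattern.DOTALL);

    void validateReadOnly(String sql) {
        if (sql == null || sql.isBlank()) {
            throw new IllegalArgumentException("Query must not be empty");
        }
        // Strip comments so they cannot disguise the leading keyword or hide a ';'
        String stripped = BLOCK_COMMENT.matcher(sql).replaceAll(" ");
        stripped = LINE_COMMENT.matcher(stripped).replaceAll(" ").trim();
        // Allow one optional trailing semicolon, but no statement chaining
        if (stripped.endsWith(";")) {
            stripped = stripped.substring(0, stripped.length() - 1);
        }
        if (stripped.contains(";")) {
            throw new IllegalArgumentException("Only a single statement is allowed");
        }
        String upper = stripped.toUpperCase(Locale.ROOT);
        if (!(upper.startsWith("SELECT") || upper.startsWith("WITH"))) {
            throw new IllegalArgumentException("Only SELECT queries are allowed");
        }
        // Reject forbidden keywords anywhere, e.g. a data-modifying statement after WITH
        for (String token : upper.split("[^A-Z_]+")) {
            if (FORBIDDEN.contains(token)) {
                throw new IllegalArgumentException("Forbidden keyword: " + token);
            }
        }
    }
}
```

Splitting on non-letter characters keeps identifiers such as `created_at` whole, so they are not mistaken for `CREATE`.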
@Component
@Slf4j
public class FileSystemTools {
private final Path basePath;
private final FileSecurityFilter securityFilter;
public FileSystemTools(
@Value("${mcp.file.base-path:/tmp}") String basePath,
FileSecurityFilter securityFilter) {
this.basePath = Paths.get(basePath).toAbsolutePath().normalize();
this.securityFilter = securityFilter;
}
@Tool(description = "List files in a directory")
public List<FileInfo> listFiles(
@ToolParam(description = "Directory path (relative to base)", required = false)
String directory) {
Path targetPath = resolvePath(directory != null ? directory : "");
securityFilter.validatePath(targetPath, basePath);
try (Stream<Path> stream = Files.list(targetPath)) {
return stream
.filter(Files::isRegularFile)
.map(this::toFileInfo)
.toList();
} catch (IOException e) {
throw new ToolExecutionException("Failed to list files", e);
}
}
@Tool(description = "Read file contents")
public FileContent readFile(
@ToolParam(description = "File path (relative to base)") String filePath,
@ToolParam(description = "Maximum lines to read", required = false)
Integer maxLines) {
Path targetPath = resolvePath(filePath);
securityFilter.validatePath(targetPath, basePath);
try {
List<String> lines;
if (maxLines != null) {
// Files.lines keeps the file open until the stream is closed
try (Stream<String> stream = Files.lines(targetPath)) {
lines = stream.limit(maxLines).toList();
}
} else {
lines = Files.readAllLines(targetPath);
}
// Report the path relative to the base, as toFileInfo does, to avoid leaking server paths
return new FileContent(basePath.relativize(targetPath).toString(), lines);
} catch (IOException e) {
throw new ToolExecutionException("Failed to read file", e);
}
}
@Tool(description = "Search for files matching a pattern")
public List<FileInfo> searchFiles(
@ToolParam(description = "Search pattern (glob matched against file names, e.g., '*.txt')") String pattern,
@ToolParam(description = "Directory to search in", required = false)
String directory) {
Path targetPath = resolvePath(directory != null ? directory : "");
securityFilter.validatePath(targetPath, basePath);
try (Stream<Path> stream = Files.walk(targetPath, 10)) {
PathMatcher matcher = FileSystems.getDefault()
.getPathMatcher("glob:" + pattern);
return stream
.filter(Files::isRegularFile)
.filter(path -> matcher.matches(path.getFileName()))
.map(this::toFileInfo)
.toList();
} catch (IOException e) {
throw new ToolExecutionException("Search failed", e);
}
}
private Path resolvePath(String relativePath) {
return basePath.resolve(relativePath).normalize();
}
private FileInfo toFileInfo(Path path) {
String relativePath = basePath.relativize(path).toString();
try {
return new FileInfo(
relativePath,
Files.size(path),
Files.getLastModifiedTime(path).toInstant()
);
} catch (IOException e) {
// Metadata unavailable: keep the relative path, use placeholder size and time
return new FileInfo(relativePath, 0, Instant.now());
}
}
}
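`FileSystemTools` delegates path checks to `securityFilter.validatePath(targetPath, basePath)`, which is not shown. One common approach is to reject any resolved path that does not stay under the base directory. The sketch below assumes both paths are compared in absolute, normalized form; `PathContainment` is a hypothetical name. `Path.startsWith` compares whole name elements, so `/srv/data-old` is not treated as inside `/srv/data`, a case a plain string-prefix check would get wrong. The sketch does not follow symbolic links. If the base directory may contain links that point elsewhere, the paths would also need resolving with `toRealPath()`.

```java
import java.nio.file.Path;
import java.nio.file.Paths;

// Hypothetical stand-in for FileSecurityFilter.validatePath
final class PathContainment {
    static void validatePath(Path target, Path base) {
        Path normalizedBase = base.toAbsolutePath().normalize();
        // normalize() collapses ".." segments before the containment check
        Path normalizedTarget = target.toAbsolutePath().normalize();
        // Path.startsWith compares name elements, not raw characters
        if (!normalizedTarget.startsWith(normalizedBase)) {
            throw new SecurityException("Path escapes base directory: " + target);
        }
    }
}
```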
@Component
@Slf4j
public class ApiIntegrationTools {
private final WebClient webClient;
private final ApiCredentialManager credentialManager;
public ApiIntegrationTools(WebClient.Builder webClientBuilder,
ApiCredentialManager credentialManager) {
this.webClient = webClientBuilder.build();
this.credentialManager = credentialManager;
}
@Tool(description = "Call an external REST API endpoint")
public ApiResponse callExternalApi(
@ToolParam(description = "HTTP method (GET, POST, PUT, DELETE)") String method,
@ToolParam(description = "API URL") String url,
@ToolParam(description = "Request body (JSON)", required = false)
String body,
@ToolParam(description = "Headers (JSON map)", required = false)
String headersJson) {
// Validate URL before making any outbound call
validateUrl(url);
// Get credentials
ApiCredentials credentials = credentialManager.getCredentials(url);
// Build request
WebClient.RequestBodySpec request = webClient
.method(HttpMethod.valueOf(method.toUpperCase()))
.uri(url);
// Add headers
addHeaders(request, headersJson, credentials);
// Attach a JSON body only when one was supplied; RequestBodySpec has no
// body(String) overload, so bodyValue(...) is used
WebClient.RequestHeadersSpec<?> spec = (body != null && !body.isBlank())
? request.contentType(MediaType.APPLICATION_JSON).bodyValue(body)
: request;
// Execute and block, since the tool method is synchronous
return spec.retrieve()
.bodyToMono(ApiResponse.class)
.block();
}
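@Tool(description = "Paginate through API results")
public List<ApiResponse> fetchPaginatedResults(
@ToolParam(description = "Base API URL") String baseUrl,
@ToolParam(description = "Page size", required = false)
Integer pageSize) {
validateUrl(baseUrl);
int size = pageSize != null ? pageSize : 100;
// Safety cap so an API that always reports another page cannot loop forever
// (50 is an arbitrary illustrative limit)
int maxPages = 50;
List<ApiResponse> allResults = new ArrayList<>();
for (int page = 0; page < maxPages; page++) {
// UriComponentsBuilder keeps any query string already present in baseUrl
URI uri = UriComponentsBuilder.fromUriString(baseUrl)
.queryParam("page", page)
.queryParam("size", size)
.build()
.toUri();
PaginatedResponse response = webClient.get()
.uri(uri)
.retrieve()
.bodyToMono(PaginatedResponse.class)
.block();
if (response == null) {
break;
}
allResults.addAll(response.getData());
if (!response.hasNext()) {
break;
}
}
return allResults;
}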
}
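Both tools above accept a URL chosen by the model, so `validateUrl` (not shown) is where server-side request forgery has to be stopped. Below is a minimal allowlist-based sketch. It assumes that only HTTPS calls to a fixed set of hosts are acceptable; `UrlPolicy` and the host names are hypothetical. The check parses the URL with `java.net.URI` and compares the parsed host rather than substrings of the raw string. Without that, a URL such as `https://api.example.com@evil.test/`, whose real host is `evil.test`, could slip through. The check does not cover hostnames that resolve to internal addresses.

```java
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Locale;
import java.util.Set;

// Hypothetical allowlist check standing in for ApiIntegrationTools.validateUrl
final class UrlPolicy {
    private final Set<String> allowedHosts; // expected in lower case

    UrlPolicy(Set<String> allowedHosts) {
        this.allowedHosts = allowedHosts;
    }

    void validateUrl(String url) {
        URI uri;
        try {
            uri = new URI(url);
        } catch (URISyntaxException e) {
            throw new IllegalArgumentException("Malformed URL: " + url, e);
        }
        // Require HTTPS and reject embedded credentials (user@host)
        if (!"https".equalsIgnoreCase(uri.getScheme()) || uri.getUserInfo() != null) {
            throw new IllegalArgumentException("Only plain HTTPS URLs are allowed: " + url);
        }
        // Compare the parsed host, never a substring of the raw string
        String host = uri.getHost();
        if (host == null || !allowedHosts.contains(host.toLowerCase(Locale.ROOT))) {
            throw new IllegalArgumentException("Host not allowed: " + host);
        }
    }
}
```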
@Component
public class CodeReviewPrompts {
@PromptTemplate(
name = "java-code-review",
description = "Review Java code for best practices and potential issues"
)
public Prompt createCodeReviewPrompt(
@PromptParam("Java code to review") String code,
@PromptParam(value = "Focus areas (comma-separated)", required = false)
String focusAreas) {
String focus = focusAreas != null ? focusAreas : "general best practices";
return Prompt.builder()
.system(createSystemPrompt())
.user(createUserPrompt(code, focus))
.build();
}
@PromptTemplate(
name = "refactor-suggestion",
description = "Suggest refactoring improvements for code"
)
public Prompt createRefactorPrompt(
@PromptParam("Code to refactor") String code,
@PromptParam("Refactoring goal") String goal) {
return Prompt.builder()
.system("You are an expert software architect specializing in code refactoring.")
.user("""
Analyze the following code and suggest refactoring to achieve: {goal}
```java
{code}
```
Provide:
1. Current issues identified
2. Suggested refactoring approach
3. Refactored code example
4. Benefits of the refactoring
""".replace("{goal}", goal).replace("{code}", code)) // substitute {code} last so "{goal}" inside the reviewed code is left untouched
.build();
}
private String createSystemPrompt() {
return """
You are an expe
Return Structured Data: Use records or DTOs for return values instead of raw strings or maps to provide schema information to AI models.