Advanced Server Features
Explore powerful features that make MCP-Go servers production-ready: typed tools, session management, middleware, hooks, and more.
Typed Tools
Typed tools provide compile-time type safety and automatic parameter validation, reducing boilerplate and preventing runtime errors.
Basic Typed Tool
package main
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
)
// Define input and output types
// CalculateInput is the typed argument payload for the "calculate" tool.
//
// Only Operation carries a `required` validation tag. With
// go-playground/validator, `required` on a numeric field rejects the zero
// value, so tagging X and Y would make x=0 or y=0 invalid and would leave the
// handler's division-by-zero check unreachable. The tool schema already
// declares x and y as required arguments.
type CalculateInput struct {
	// Operation selects the arithmetic operation: add, subtract, multiply or divide.
	Operation string `json:"operation" validate:"required,oneof=add subtract multiply divide"`
	// X is the left-hand operand; zero is a legal value.
	X float64 `json:"x"`
	// Y is the right-hand operand; zero is a legal value (division by zero is
	// rejected by the handler, not by validation).
	Y float64 `json:"y"`
}
// CalculateOutput is the JSON payload returned by the "calculate" tool.
type CalculateOutput struct {
	// Result is the computed value.
	Result float64 `json:"result"`
	// Operation echoes the operation that was requested.
	Operation string `json:"operation"`
}
// main registers the typed "calculate" tool on a new MCP server and serves it
// over stdio.
func main() {
	s := server.NewMCPServer("Typed Server", "1.0.0",
		server.WithToolCapabilities(true),
	)

	// Create typed tool
	tool := mcp.NewTool("calculate",
		mcp.WithDescription("Perform arithmetic operations"),
		mcp.WithString("operation", mcp.Required()),
		mcp.WithNumber("x", mcp.Required()),
		mcp.WithNumber("y", mcp.Required()),
	)

	// Add tool with typed handler
	s.AddTool(tool, mcp.NewTypedToolHandler(handleCalculateTyped))

	// Report a serve failure instead of silently discarding the error.
	if err := server.ServeStdio(s); err != nil {
		fmt.Printf("Server error: %v\n", err)
	}
}
// handleCalculateTyped applies input.Operation to input.X and input.Y and
// returns the outcome as a JSON-encoded CalculateOutput text result.
//
// Failures that the caller can correct (division by zero, an unsupported
// operation, a marshalling failure) are reported as tool error results with a
// nil Go error.
func handleCalculateTyped(ctx context.Context, req mcp.CallToolRequest, input CalculateInput) (*mcp.CallToolResult, error) {
	var result float64

	switch input.Operation {
	case "add":
		result = input.X + input.Y
	case "subtract":
		result = input.X - input.Y
	case "multiply":
		result = input.X * input.Y
	case "divide":
		if input.Y == 0 {
			return mcp.NewToolResultError("division by zero"), nil
		}
		result = input.X / input.Y
	default:
		// Without this branch an unknown operation would fall through and
		// report a bogus result of 0.
		return mcp.NewToolResultError(fmt.Sprintf("unsupported operation: %q", input.Operation)), nil
	}

	output := CalculateOutput{
		Result:    result,
		Operation: input.Operation,
	}

	jsonData, err := json.Marshal(output)
	if err != nil {
		return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil
	}

	return mcp.NewToolResultText(string(jsonData)), nil
}
Complex Typed Tool
// UserCreateInput is the typed argument payload for the user-creation tool.
// The validate tags describe the constraints each field is expected to meet.
type UserCreateInput struct {
	// Name is required and limited to 1-100 characters.
	Name string `json:"name" validate:"required,min=1,max=100"`
	// Email is required and must be a well-formed address.
	Email string `json:"email" validate:"required,email"`
	// Age is optional and bounded to 0-150.
	Age int `json:"age" validate:"min=0,max=150"`
	// Tags is optional; each element must be non-empty.
	Tags []string `json:"tags" validate:"dive,min=1"`
	// Metadata carries free-form key/value pairs.
	Metadata map[string]string `json:"metadata"`
	// Active is stored on the created user as-is.
	Active bool `json:"active"`
}
// UserCreateOutput is the JSON payload returned after a user has been created.
type UserCreateOutput struct {
	// ID is the identifier assigned to the new user.
	ID string `json:"id"`
	// Name echoes the stored user name.
	Name string `json:"name"`
	// Email echoes the stored e-mail address.
	Email string `json:"email"`
	// CreatedAt is the creation timestamp recorded on the user.
	CreatedAt time.Time `json:"created_at"`
	// Status describes the outcome; the handler sets it to "created".
	Status string `json:"status"`
}
// handleCreateUser persists a new user built from input and returns a
// JSON-encoded UserCreateOutput text result.
//
// A database failure is returned as a Go error; a marshalling failure is
// reported as a tool error result, matching the other handlers in this guide.
func handleCreateUser(ctx context.Context, req mcp.CallToolRequest, input UserCreateInput) (*mcp.CallToolResult, error) {
	// NOTE(review): this example relies on the struct tags of UserCreateInput
	// being validated before the handler runs — confirm that the typed handler
	// wrapper actually enforces them.

	// Create user in database
	user := &User{
		ID:        generateID(),
		Name:      input.Name,
		Email:     input.Email,
		Age:       input.Age,
		Tags:      input.Tags,
		Metadata:  input.Metadata,
		Active:    input.Active,
		CreatedAt: time.Now(),
	}

	if err := db.Create(user); err != nil {
		return nil, fmt.Errorf("failed to create user: %w", err)
	}

	output := &UserCreateOutput{
		ID:        user.ID,
		Name:      user.Name,
		Email:     user.Email,
		CreatedAt: user.CreatedAt,
		Status:    "created",
	}

	// Serialize explicitly so that an encoding failure is surfaced rather than
	// being lost in a single-value return expression.
	jsonData, err := json.Marshal(output)
	if err != nil {
		return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil
	}

	return mcp.NewToolResultText(string(jsonData)), nil
}
Custom Validation
import (
"path/filepath"
"strings"
"github.com/go-playground/validator/v10"
)
// FileOperationInput is the typed argument payload for a file-operation tool.
type FileOperationInput struct {
	// Path is required and is checked by the custom "filepath" rule.
	Path string `json:"path" validate:"required,filepath"`
	// Operation is required and must be one of read, write or delete.
	Operation string `json:"operation" validate:"required,oneof=read write delete"`
	// Content is required only when Operation is "write".
	Content string `json:"content" validate:"required_if=Operation write"`
}
// validate is the shared validator instance that carries the custom rules
// registered in init. It is kept at package scope because a validator created
// inside init would be discarded as soon as init returns, taking the
// registration with it.
var validate = validator.New()

// init registers the custom "filepath" rule on the shared validator. A
// registration failure is a programming error, so it aborts start-up instead
// of being silently ignored.
func init() {
	if err := validate.RegisterValidation("filepath", validateFilePath); err != nil {
		panic("register filepath validation: " + err.Error())
	}
}
// validateFilePath reports whether the field value is a path that stays inside
// the allowed data directory.
//
// The value is rejected when it contains "..", when it cannot be made
// absolute, or when the absolute path is neither the allowed directory itself
// nor located beneath it.
func validateFilePath(fl validator.FieldLevel) bool {
	path := fl.Field().String()

	// Prevent directory traversal
	if strings.Contains(path, "..") {
		return false
	}

	// Ensure path is within allowed directory
	const allowedDir = "/app/data"
	absPath, err := filepath.Abs(path)
	if err != nil {
		return false
	}
	if absPath == allowedDir {
		return true
	}

	// Compare against the directory plus a trailing separator: a bare prefix
	// check would also accept siblings such as "/app/database/secret".
	return strings.HasPrefix(absPath, allowedDir+string(filepath.Separator))
}
Session Management
Handle multiple clients with per-session state and tools.
Per-Session State
// SessionState is the per-session record kept by SessionManager.
type SessionState struct {
	UserID      string
	Permissions []string
	Settings    map[string]any
	StartTime   time.Time
}

// SessionManager is a mutex-guarded registry of SessionState values keyed by
// session ID.
type SessionManager struct {
	sessions map[string]*SessionState
	mutex    sync.RWMutex
}

// NewSessionManager returns a manager with an empty, ready-to-use registry.
func NewSessionManager() *SessionManager {
	manager := new(SessionManager)
	manager.sessions = map[string]*SessionState{}
	return manager
}

// CreateSession stores a fresh state for sessionID, replacing any existing
// entry. The new state starts with empty settings and the current time.
func (sm *SessionManager) CreateSession(sessionID, userID string, permissions []string) {
	sm.mutex.Lock()
	defer sm.mutex.Unlock()

	state := new(SessionState)
	state.UserID = userID
	state.Permissions = permissions
	state.Settings = map[string]any{}
	state.StartTime = time.Now()

	sm.sessions[sessionID] = state
}

// GetSession looks up the state for sessionID; the boolean reports whether an
// entry was found.
func (sm *SessionManager) GetSession(sessionID string) (*SessionState, bool) {
	sm.mutex.RLock()
	state, found := sm.sessions[sessionID]
	sm.mutex.RUnlock()

	return state, found
}

// RemoveSession drops the entry for sessionID; unknown IDs are ignored.
func (sm *SessionManager) RemoveSession(sessionID string) {
	sm.mutex.Lock()
	delete(sm.sessions, sessionID)
	sm.mutex.Unlock()
}
Session-Aware Tools
// main wires a SessionManager into the server's session lifecycle hooks,
// registers the session-aware "get_user_data" tool and serves over stdio.
func main() {
	sessionManager := NewSessionManager()

	hooks := &server.Hooks{}

	hooks.AddOnRegisterSession(func(ctx context.Context, session server.ClientSession) {
		// Initialize session with default permissions
		sessionManager.CreateSession(session.ID(), "anonymous", []string{"read"})
		log.Printf("Session %s started", session.ID())
	})

	hooks.AddOnUnregisterSession(func(ctx context.Context, session server.ClientSession) {
		sessionManager.RemoveSession(session.ID())
		log.Printf("Session %s ended", session.ID())
	})

	s := server.NewMCPServer("Session Server", "1.0.0",
		server.WithToolCapabilities(true),
		server.WithHooks(hooks),
	)

	// Add session-aware tool
	s.AddTool(
		mcp.NewTool("get_user_data",
			mcp.WithDescription("Get user-specific data"),
			mcp.WithString("data_type", mcp.Required()),
		),
		createSessionAwareTool(sessionManager),
	)

	// Report a serve failure instead of silently discarding the error.
	if err := server.ServeStdio(s); err != nil {
		log.Fatalf("Server error: %v", err)
	}
}
// createSessionAwareTool returns a tool handler that serves data belonging to
// the user of the calling session.
//
// The handler fails with an error when the session is unknown, when the
// "data_type" argument is missing or not a string, when the session lacks the
// "read" permission, or when the data lookup fails. On success the data is
// returned as a JSON text result.
func createSessionAwareTool(sm *SessionManager) server.ToolHandler {
	return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		sessionID := server.GetSessionID(ctx)
		session, exists := sm.GetSession(sessionID)
		if !exists {
			return nil, fmt.Errorf("invalid session")
		}

		// Use the checked accessor: a bare type assertion on the argument map
		// panics when the client omits the argument or sends a non-string.
		dataType, err := req.RequireString("data_type")
		if err != nil {
			return nil, err
		}

		// Check permissions
		if !hasPermission(session.Permissions, "read") {
			return nil, fmt.Errorf("insufficient permissions")
		}

		// Get user-specific data
		data, err := getUserData(session.UserID, dataType)
		if err != nil {
			return nil, err
		}

		jsonData, err := json.Marshal(data)
		if err != nil {
			return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil
		}

		return mcp.NewToolResultText(string(jsonData)), nil
	}
}
Middleware
Add cross-cutting concerns like logging, authentication, and rate limiting.
Logging Middleware
// LoggingMiddleware logs the start and outcome of tool calls and resource
// reads to the configured logger.
type LoggingMiddleware struct {
	logger *log.Logger
}

// NewLoggingMiddleware returns a LoggingMiddleware that writes to logger.
func NewLoggingMiddleware(logger *log.Logger) *LoggingMiddleware {
	return &LoggingMiddleware{logger: logger}
}
// ToolMiddleware wraps next so that every tool call is logged when it starts
// and again when it completes or fails, together with the session ID, the
// tool name and the elapsed time. The result and error of next are returned
// unchanged.
func (m *LoggingMiddleware) ToolMiddleware(next server.ToolHandlerFunc) server.ToolHandlerFunc {
	return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		start := time.Now()
		sessionID := server.GetSessionID(ctx)

		// Include the session here as well so the start line can be correlated
		// with the matching completed/failed line.
		m.logger.Printf("Tool call started: session=%s tool=%s", sessionID, req.Params.Name)

		result, err := next(ctx, req)

		duration := time.Since(start)
		if err != nil {
			m.logger.Printf("Tool call failed: session=%s tool=%s duration=%v error=%v",
				sessionID, req.Params.Name, duration, err)
		} else {
			m.logger.Printf("Tool call completed: session=%s tool=%s duration=%v",
				sessionID, req.Params.Name, duration)
		}

		return result, err
	}
}
// ResourceMiddleware decorates next with logging: one line when a resource
// read begins and one when it completes or fails, each carrying the session
// ID and URI, plus the elapsed time on the closing line. The result and error
// of next pass through untouched.
func (m *LoggingMiddleware) ResourceMiddleware(next server.ResourceHandler) server.ResourceHandler {
	return func(ctx context.Context, req mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) {
		begin := time.Now()
		sid := server.GetSessionID(ctx)
		uri := req.Params.URI

		m.logger.Printf("Resource read started: session=%s uri=%s", sid, uri)

		res, readErr := next(ctx, req)
		elapsed := time.Since(begin)

		if readErr == nil {
			m.logger.Printf("Resource read completed: session=%s uri=%s duration=%v",
				sid, uri, elapsed)
			return res, nil
		}

		m.logger.Printf("Resource read failed: session=%s uri=%s duration=%v error=%v",
			sid, uri, elapsed, readErr)
		return res, readErr
	}
}
Rate Limiting Middleware
// RateLimitMiddleware keeps one rate limiter per session ID and uses it to
// throttle tool calls.
type RateLimitMiddleware struct {
	// limiters maps a session ID to its limiter; guarded by mutex.
	limiters map[string]*rate.Limiter
	mutex    sync.RWMutex
	// rate and burst are the settings applied to every newly created limiter.
	rate  rate.Limit
	burst int
}

// NewRateLimitMiddleware returns a middleware whose per-session limiters are
// created with the given requests-per-second rate and burst size.
func NewRateLimitMiddleware(requestsPerSecond float64, burst int) *RateLimitMiddleware {
	return &RateLimitMiddleware{
		limiters: make(map[string]*rate.Limiter),
		rate:     rate.Limit(requestsPerSecond),
		burst:    burst,
	}
}
// getLimiter returns the limiter for sessionID, creating and storing one on
// first use.
func (m *RateLimitMiddleware) getLimiter(sessionID string) *rate.Limiter {
	// Fast path: an existing limiter only needs the read lock.
	m.mutex.RLock()
	limiter, exists := m.limiters[sessionID]
	m.mutex.RUnlock()
	if exists {
		return limiter
	}

	m.mutex.Lock()
	defer m.mutex.Unlock()

	// Look again under the write lock: another goroutine may have stored a
	// limiter between the read unlock and this lock. Overwriting it would hand
	// out two limiters for one session and reset that session's budget.
	if limiter, exists = m.limiters[sessionID]; exists {
		return limiter
	}

	limiter = rate.NewLimiter(m.rate, m.burst)
	m.limiters[sessionID] = limiter
	return limiter
}
// ToolMiddleware guards next with the calling session's limiter: a call that
// the limiter does not allow is rejected with an error naming the session,
// and every other call is forwarded to next unchanged.
func (m *RateLimitMiddleware) ToolMiddleware(next server.ToolHandler) server.ToolHandler {
	return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		sessionID := server.GetSessionID(ctx)

		if allowed := m.getLimiter(sessionID).Allow(); allowed {
			return next(ctx, req)
		}

		return nil, fmt.Errorf("rate limit exceeded for session %s", sessionID)
	}
}
Authentication Middleware
// AuthMiddleware authenticates tool calls with the configured TokenValidator.
type AuthMiddleware struct {
	tokenValidator TokenValidator
}

// NewAuthMiddleware returns an AuthMiddleware that validates tokens with
// validator.
func NewAuthMiddleware(validator TokenValidator) *AuthMiddleware {
	return &AuthMiddleware{tokenValidator: validator}
}
// ToolMiddleware authenticates a tool call before handing it to next.
//
// A call without a token, or with a token the validator rejects, fails with
// an error and never reaches next. Otherwise the validated user is attached
// to the context under the key "user" and next is invoked with that context.
func (m *AuthMiddleware) ToolMiddleware(next server.ToolHandler) server.ToolHandler {
	return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		// The token may come from the context or from the request itself.
		token := extractToken(ctx, req)
		if len(token) == 0 {
			return nil, fmt.Errorf("authentication required")
		}

		user, validationErr := m.tokenValidator.Validate(token)
		if validationErr != nil {
			return nil, fmt.Errorf("invalid token: %w", validationErr)
		}

		// NOTE(review): a plain string context key can collide with keys set by
		// other packages; switching to an unexported key type would also require
		// updating every reader of ctx.Value("user").
		authedCtx := context.WithValue(ctx, "user", user)
		return next(authedCtx, req)
	}
}
Hooks
Implement lifecycle callbacks for telemetry, logging, and custom behavior.
Comprehensive Hooks
// TelemetryHooks reports server, session, tool and resource events to a
// metrics collector and a logger.
type TelemetryHooks struct {
	metrics MetricsCollector
	logger  *log.Logger
}

// NewTelemetryHooks returns TelemetryHooks that report to metrics and logger.
func NewTelemetryHooks(metrics MetricsCollector, logger *log.Logger) *TelemetryHooks {
	return &TelemetryHooks{
		metrics: metrics,
		logger:  logger,
	}
}
// OnServerStart logs the start-up and increments the "server.starts" counter.
func (h *TelemetryHooks) OnServerStart() {
	h.logger.Println("MCP Server starting")
	h.metrics.Increment("server.starts")
}

// OnServerStop logs the shutdown and increments the "server.stops" counter.
func (h *TelemetryHooks) OnServerStop() {
	h.logger.Println("MCP Server stopping")
	h.metrics.Increment("server.stops")
}
// OnSessionStart logs the new session, increments "sessions.started" and
// refreshes the "sessions.active" gauge.
func (h *TelemetryHooks) OnSessionStart(sessionID string) {
	h.logger.Printf("Session started: %s", sessionID)
	h.metrics.Increment("sessions.started")
	h.metrics.Gauge("sessions.active", h.getActiveSessionCount())
}

// OnSessionEnd logs the closed session, increments "sessions.ended" and
// refreshes the "sessions.active" gauge.
func (h *TelemetryHooks) OnSessionEnd(sessionID string) {
	h.logger.Printf("Session ended: %s", sessionID)
	h.metrics.Increment("sessions.ended")
	h.metrics.Gauge("sessions.active", h.getActiveSessionCount())
}
// OnToolCall records one tool invocation: a call counter tagged with tool and
// session, a duration histogram (in seconds) tagged with the tool, and, when
// err is non-nil, an error counter tagged with the tool.
func (h *TelemetryHooks) OnToolCall(sessionID, toolName string, duration time.Duration, err error) {
	callTags := map[string]string{"tool": toolName, "session": sessionID}
	h.metrics.Increment("tools.calls", callTags)

	durationTags := map[string]string{"tool": toolName}
	h.metrics.Histogram("tools.duration", duration.Seconds(), durationTags)

	if err == nil {
		return
	}

	errorTags := map[string]string{"tool": toolName}
	h.metrics.Increment("tools.errors", errorTags)
}
// OnResourceRead records one resource read: a read counter tagged with the
// session, an untagged duration histogram (in seconds) and, when err is
// non-nil, an untagged error counter. The uri argument is not recorded.
func (h *TelemetryHooks) OnResourceRead(sessionID, uri string, duration time.Duration, err error) {
	readTags := map[string]string{"session": sessionID}
	h.metrics.Increment("resources.reads", readTags)

	h.metrics.Histogram("resources.duration", duration.Seconds())

	if err == nil {
		return
	}
	h.metrics.Increment("resources.errors")
}
Custom Business Logic Hooks
// BusinessHooks applies business rules to server events: auditing, alerting
// and per-session resource management.
type BusinessHooks struct {
	auditLogger AuditLogger
	notifier    Notifier
}
// OnToolCall applies three independent checks to a finished tool call:
// sensitive tools are written to the audit log, failures raise an alert, and
// calls that ran longer than 30 seconds raise a slow-execution alert.
func (h *BusinessHooks) OnToolCall(sessionID, toolName string, duration time.Duration, err error) {
	const slowCallThreshold = 30 * time.Second

	// Audit trail for sensitive operations.
	if sensitive := isSensitiveTool(toolName); sensitive {
		h.auditLogger.LogToolCall(sessionID, toolName, err)
	}

	// Failure alert.
	if failed := err != nil; failed {
		failure := fmt.Sprintf("Tool %s failed for session %s: %v",
			toolName, sessionID, err)
		h.notifier.SendAlert(failure)
	}

	// Performance alert.
	if slow := duration > slowCallThreshold; slow {
		warning := fmt.Sprintf("Slow tool execution: %s took %v",
			toolName, duration)
		h.notifier.SendAlert(warning)
	}
}
// OnSessionStart prepares the resources for a new session and then sends the
// welcome notification.
func (h *BusinessHooks) OnSessionStart(sessionID string) {
	// Initialize user-specific resources
	h.initializeUserResources(sessionID)

	// Send welcome notification
	h.notifier.SendWelcome(sessionID)
}

// OnSessionEnd releases the session's resources and then records the session
// end in the audit log.
func (h *BusinessHooks) OnSessionEnd(sessionID string) {
	// Cleanup user resources
	h.cleanupUserResources(sessionID)

	// Log session summary
	h.auditLogger.LogSessionEnd(sessionID)
}
Tool Filtering
Conditionally expose tools based on context, permissions, or other criteria.
Permission-Based Filtering
// PermissionFilter limits the visible tools to those the calling session is
// permitted to use.
type PermissionFilter struct {
	sessionManager *SessionManager
}

// NewPermissionFilter returns a PermissionFilter that looks sessions up in sm.
func NewPermissionFilter(sm *SessionManager) *PermissionFilter {
	return &PermissionFilter{sessionManager: sm}
}
// FilterTools returns the subset of tools the session found in ctx may use,
// in their original order. An unknown session yields an empty, non-nil slice;
// a known session with no permitted tools yields nil.
func (f *PermissionFilter) FilterTools(ctx context.Context, tools []mcp.Tool) []mcp.Tool {
	session, known := f.sessionManager.GetSession(server.GetSessionID(ctx))
	if !known {
		return []mcp.Tool{} // No tools for invalid sessions
	}

	var permitted []mcp.Tool
	for i := range tools {
		if !f.hasPermissionForTool(session, tools[i].Name) {
			continue
		}
		permitted = append(permitted, tools[i])
	}
	return permitted
}
// hasPermissionForTool reports whether session may use the named tool. Tools
// without an entry in the table are open to everyone; for listed tools the
// session must hold at least one of the listed permissions.
func (f *PermissionFilter) hasPermissionForTool(session *SessionState, toolName string) bool {
	requiredPermissions := map[string][]string{
		"delete_user":   {"admin"},
		"modify_system": {"admin", "operator"},
		"read_data":     {"admin", "operator", "user"},
		"create_report": {"admin", "operator", "user"},
	}

	accepted, restricted := requiredPermissions[toolName]
	if !restricted {
		return true // Allow tools without specific requirements
	}

	for _, want := range accepted {
		for _, held := range session.Permissions {
			if held == want {
				return true
			}
		}
	}
	return false
}
Context-Based Filtering
// ContextFilter selects tools based on the current hour of day and the
// ENVIRONMENT environment variable.
type ContextFilter struct{}

// FilterTools returns, in their original order, the tools that
// shouldIncludeTool accepts for the current hour and environment. The result
// is nil when no tool is accepted.
func (f *ContextFilter) FilterTools(ctx context.Context, tools []mcp.Tool) []mcp.Tool {
	hour, env := time.Now().Hour(), os.Getenv("ENVIRONMENT")

	var kept []mcp.Tool
	for i := range tools {
		if candidate := tools[i]; f.shouldIncludeTool(candidate, hour, env) {
			kept = append(kept, candidate)
		}
	}
	return kept
}
// shouldIncludeTool reports whether tool should be exposed at the given hour
// (0-23) in the given environment.
//
// Maintenance tools are exposed only from 10 PM up to, but not including,
// 6 AM. Debug tools are exposed only when env is "development". All other
// tools are always exposed.
func (f *ContextFilter) shouldIncludeTool(tool mcp.Tool, hour int, env string) bool {
	// Maintenance tools only during off-hours
	maintenanceTools := map[string]bool{
		"backup_database": true,
		"cleanup_logs":    true,
		"restart_service": true,
	}

	if maintenanceTools[tool.Name] {
		// hour >= 22 so that 22:00-22:59 belongs to the window; "hour > 22"
		// would only open it at 11 PM.
		return hour < 6 || hour >= 22 // Only between 10 PM and 6 AM
	}

	// Debug tools only in development
	debugTools := map[string]bool{
		"debug_session": true,
		"dump_state":    true,
	}

	if debugTools[tool.Name] {
		return env == "development"
	}

	return true
}
Notifications
Send server-to-client messages for real-time updates.
Progress Notifications
// handleLongRunningTool simulates 100 work steps spaced 100ms apart and
// broadcasts a progress notification to all clients after each step.
//
// It returns a success text result once all steps are done, or ctx.Err() as
// soon as the request context is cancelled.
func handleLongRunningTool(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
	srv := server.ServerFromContext(ctx)

	// A ticker combined with select lets the loop stop on cancellation; a
	// plain time.Sleep would keep working for the full ten seconds after the
	// client has gone away.
	ticker := time.NewTicker(100 * time.Millisecond)
	defer ticker.Stop()

	// Simulate long-running work
	for i := 0; i < 100; i++ {
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case <-ticker.C:
		}

		// Send custom notification to all clients
		notification := map[string]interface{}{
			"type":     "progress",
			"progress": i + 1,
			"total":    100,
			"message":  fmt.Sprintf("Processing step %d/100", i+1),
		}

		srv.SendNotificationToAllClients("progress", notification)
	}

	return mcp.NewToolResultText("Long operation completed successfully"), nil
}
Custom Notifications
// CustomNotifier delivers notifications to sessions through one buffered
// channel per session ID.
type CustomNotifier struct {
	// sessions maps a session ID to its notification channel; guarded by mutex.
	sessions map[string]chan mcp.Notification
	mutex    sync.RWMutex
}

// NewCustomNotifier returns a notifier with no registered sessions.
func NewCustomNotifier() *CustomNotifier {
	return &CustomNotifier{
		sessions: make(map[string]chan mcp.Notification),
	}
}
// RegisterSession creates a notification channel with room for 100 pending
// notifications and stores it under sessionID.
// NOTE(review): an existing channel for the same ID is replaced without being
// closed — confirm that sessions are never registered twice.
func (n *CustomNotifier) RegisterSession(sessionID string) {
	n.mutex.Lock()
	defer n.mutex.Unlock()

	n.sessions[sessionID] = make(chan mcp.Notification, 100)
}

// UnregisterSession closes and removes the channel stored under sessionID;
// unknown IDs are ignored.
func (n *CustomNotifier) UnregisterSession(sessionID string) {
	n.mutex.Lock()
	defer n.mutex.Unlock()

	if ch, exists := n.sessions[sessionID]; exists {
		close(ch)
		delete(n.sessions, sessionID)
	}
}
// SendAlert queues an "alert" notification carrying message, severity and the
// current Unix timestamp for sessionID. The send never blocks: the alert is
// dropped when the session is not registered or its channel is full.
func (n *CustomNotifier) SendAlert(sessionID, message string, severity string) {
	n.mutex.RLock()
	defer n.mutex.RUnlock()

	target, registered := n.sessions[sessionID]
	if !registered {
		return
	}

	alert := mcp.Notification{
		Type: "alert",
		Data: map[string]interface{}{
			"message":   message,
			"severity":  severity,
			"timestamp": time.Now().Unix(),
		},
	}

	select {
	case target <- alert:
	default:
		// Channel full, drop notification
	}
}
// BroadcastSystemMessage offers one "system_message" notification, stamped
// with the current Unix time, to every registered session. Sends never block:
// a session whose channel is full simply misses the message.
func (n *CustomNotifier) BroadcastSystemMessage(message string) {
	n.mutex.RLock()
	defer n.mutex.RUnlock()

	payload := map[string]interface{}{
		"message":   message,
		"timestamp": time.Now().Unix(),
	}
	broadcast := mcp.Notification{Type: "system_message", Data: payload}

	for _, target := range n.sessions {
		select {
		case target <- broadcast:
		default:
			// Channel full, skip this session
		}
	}
}
Production Configuration
Complete Production Server
// main assembles the production server — logging, rate-limiting and
// authentication middleware, telemetry hooks and permission-based tool
// filtering — registers tools, resources and prompts, and runs the server
// until shutdown.
func main() {
	// Initialize components
	// Log to stderr: the server is served over stdio, so stdout has to stay
	// reserved for protocol traffic.
	logger := log.New(os.Stderr, "[MCP] ", log.LstdFlags)
	metrics := NewPrometheusMetrics()
	sessionManager := NewSessionManager()
	notifier := NewCustomNotifier()

	// Create middleware
	loggingMW := NewLoggingMiddleware(logger)
	rateLimitMW := NewRateLimitMiddleware(10.0, 20) // 10 req/sec, burst 20
	authMW := NewAuthMiddleware(NewJWTValidator())

	// Create hooks
	telemetryHooks := NewTelemetryHooks(metrics, logger)
	businessHooks := NewBusinessHooks(NewAuditLogger(), NewSlackNotifier())

	// NOTE(review): businessHooks and notifier are constructed but this example
	// never attaches them to the server. They are referenced here only so the
	// snippet compiles; wire them in once the intended integration is decided.
	_, _ = businessHooks, notifier

	// Create server with all features
	// NOTE(review): the order in which the server applies several tool
	// middlewares is not visible here — confirm that authentication runs where
	// intended relative to logging and rate limiting.
	s := server.NewMCPServer("Production Server", "1.0.0",
		server.WithToolCapabilities(true),
		server.WithResourceCapabilities(false, true),
		server.WithPromptCapabilities(true),
		server.WithRecovery(),
		server.WithHooks(telemetryHooks),
		server.WithToolHandlerMiddleware(loggingMW.ToolMiddleware),
		server.WithToolHandlerMiddleware(rateLimitMW.ToolMiddleware),
		server.WithToolHandlerMiddleware(authMW.ToolMiddleware),
		server.WithToolFilter(NewPermissionFilter(sessionManager)),
	)

	// Add tools and resources
	addProductionTools(s)
	addProductionResources(s)
	addProductionPrompts(s)

	// Start server with graceful shutdown
	startWithGracefulShutdown(s)
}
// startWithGracefulShutdown serves s over stdio and returns when the server
// stops.
//
// If the server goroutine ends on its own, the error (if any) is logged and
// the function returns. If SIGINT or SIGTERM arrives first, s.Shutdown is
// called with a 30-second timeout.
func startWithGracefulShutdown(s *server.MCPServer) {
	// Setup signal handling
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
	defer signal.Stop(sigChan)

	// Start server in goroutine. The buffered channel lets the goroutine
	// deliver its result and exit even if nobody is receiving any more.
	serveErr := make(chan error, 1)
	go func() {
		serveErr <- server.ServeStdio(s)
	}()

	// Wait for either the server to finish or a shutdown signal. Waiting on
	// the signal alone would block forever once the server has already
	// returned.
	select {
	case err := <-serveErr:
		if err != nil {
			log.Printf("Server error: %v", err)
		}
		log.Println("Server stopped")
		return
	case <-sigChan:
		log.Println("Shutting down server...")
	}

	// Graceful shutdown with timeout
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	if err := s.Shutdown(ctx); err != nil {
		log.Printf("Shutdown error: %v", err)
	}

	log.Println("Server stopped")
}
Client Capability Based Filtering
package main
import (
	"context"
	"fmt"
	"os"

	"github.com/mark3labs/mcp-go/mcp"
	"github.com/mark3labs/mcp-go/server"
)
// main registers the "calculate" tool, whose handler inspects the client's
// capabilities, and serves it over stdio.
func main() {
	s := server.NewMCPServer("Typed Server", "1.0.0",
		server.WithToolCapabilities(true),
	)

	s.AddTool(
		mcp.NewTool("calculate",
			mcp.WithDescription("Perform basic mathematical calculations"),
			mcp.WithString("operation",
				mcp.Required(),
				mcp.Enum("add", "subtract", "multiply", "divide"),
				mcp.Description("The operation to perform"),
			),
			mcp.WithNumber("x", mcp.Required(), mcp.Description("First number")),
			mcp.WithNumber("y", mcp.Required(), mcp.Description("Second number")),
		),
		handleCalculate,
	)

	// Report a serve failure on stderr and exit non-zero instead of silently
	// discarding the error.
	if err := server.ServeStdio(s); err != nil {
		fmt.Fprintf(os.Stderr, "Server error: %v\n", err)
		os.Exit(1)
	}
}
// handleCalculate shows how a tool handler can inspect the calling client's
// capabilities. It fails when the context carries no session, notes on stderr
// when the client did not advertise sampling, and currently always returns a
// "not implemented" tool error.
func handleCalculate(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
	session := server.ClientSessionFromContext(ctx)
	if session == nil {
		return nil, fmt.Errorf("no active session")
	}

	if clientSession, ok := session.(server.SessionWithClientInfo); ok {
		clientCapabilities := clientSession.GetClientCapabilities()
		if clientCapabilities.Sampling == nil {
			// Write diagnostics to stderr: this server is served over stdio,
			// so anything printed to stdout would be mixed into the protocol
			// stream.
			fmt.Fprintln(os.Stderr, "sampling is not enabled in client")
		}
	}

	// TODO: implement calculation logic
	return mcp.NewToolResultError("not implemented"), nil
}
Sampling (Advanced)
Sampling is an advanced feature that allows servers to request LLM completions from clients. This enables bidirectional communication where servers can leverage client-side LLM capabilities for content generation, reasoning, and question answering.
Note: Most servers don't need sampling. Only implement it if your server specifically needs to generate content using the client's LLM.
When to Use Sampling
Consider sampling when your server needs to:
- Generate content based on user input
- Answer questions using LLM reasoning
- Perform text analysis or summarization
- Create dynamic responses that require LLM capabilities
Basic Implementation
// Enable sampling capability
mcpServer.EnableSampling()
// Add a tool that uses sampling
mcpServer.AddTool(mcp.Tool{
Name: "ask_llm",
Description: "Ask the LLM a question using sampling",
InputSchema: mcp.ToolInputSchema{
Type: "object",
Properties: map[string]any{
"question": map[string]any{
"type": "string",
"description": "The question to ask the LLM",
},
},
Required: []string{"question"},
},
}, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
question, err := request.RequireString("question")
if err != nil {
return nil, err
}
// Create sampling request
samplingRequest := mcp.CreateMessageRequest{
CreateMessageParams: mcp.CreateMessageParams{
Messages: []mcp.SamplingMessage{
{
Role: mcp.RoleUser,
Content: mcp.TextContent{
Type: "text",
Text: question,
},
},
},
SystemPrompt: "You are a helpful assistant.",
MaxTokens: 1000,
Temperature: 0.7,
},
}
// Request sampling from client
result, err := mcpServer.RequestSampling(ctx, samplingRequest)
if err != nil {
return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{
Type: "text",
Text: fmt.Sprintf("Error: %v", err),
},
},
IsError: true,
}, nil
}
return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{
Type: "text",
Text: fmt.Sprintf("LLM Response: %s", result.Content),
},
},
}, nil
})
For complete sampling documentation, see Server Sampling Guide.
Next Steps
- Client Development - Learn to build MCP clients
- Server Basics - Review fundamental concepts