Skip to content

Advanced Server Features

Explore powerful features that make MCP-Go servers production-ready: typed tools, session management, middleware, hooks, and more.

Typed Tools

Typed tools provide compile-time type safety and automatic parameter validation, reducing boilerplate and preventing runtime errors.

Basic Typed Tool

package main
 
import (
    "context"
    "encoding/json"
    "fmt"
    "time"
 
    "github.com/mark3labs/mcp-go/mcp"
    "github.com/mark3labs/mcp-go/server"
)
 
// Define input and output types
// CalculateInput is the typed argument payload for the "calculate" tool.
// The json tags name the tool arguments each field is decoded from.
// NOTE(review): with go-playground/validator, `required` rejects zero values,
// so x=0 or y=0 would fail validation if these tags are enforced — confirm.
type CalculateInput struct {
    Operation string  `json:"operation" validate:"required,oneof=add subtract multiply divide"`
    X         float64 `json:"x" validate:"required"`
    Y         float64 `json:"y" validate:"required"`
}
 
// CalculateOutput is the JSON document returned by the "calculate" tool:
// the computed value together with the operation that produced it.
type CalculateOutput struct {
    Result    float64 `json:"result"`
    Operation string  `json:"operation"`
}
 
// main builds a server exposing the typed "calculate" tool and serves it
// over stdio until the transport stops.
func main() {
    s := server.NewMCPServer("Typed Server", "1.0.0",
        server.WithToolCapabilities(true),
    )

    // Create typed tool
    tool := mcp.NewTool("calculate",
        mcp.WithDescription("Perform arithmetic operations"),
        mcp.WithString("operation", mcp.Required()),
        mcp.WithNumber("x", mcp.Required()),
        mcp.WithNumber("y", mcp.Required()),
    )

    // Add tool with typed handler
    s.AddTool(tool, mcp.NewTypedToolHandler(handleCalculateTyped))

    // Report transport failures instead of silently dropping them.
    if err := server.ServeStdio(s); err != nil {
        fmt.Printf("Server error: %v\n", err)
    }
}
 
// handleCalculateTyped applies input.Operation to input.X and input.Y and
// returns the result as a JSON-encoded CalculateOutput in a text tool result.
//
// Division by zero and unknown operations are reported as tool-level errors
// (an error result with a nil Go error) rather than as a bogus zero result.
func handleCalculateTyped(ctx context.Context, req mcp.CallToolRequest, input CalculateInput) (*mcp.CallToolResult, error) {
    var result float64

    switch input.Operation {
    case "add":
        result = input.X + input.Y
    case "subtract":
        result = input.X - input.Y
    case "multiply":
        result = input.X * input.Y
    case "divide":
        if input.Y == 0 {
            return mcp.NewToolResultError("division by zero"), nil
        }
        result = input.X / input.Y
    default:
        // Without this branch an unrecognized operation would yield result 0.
        return mcp.NewToolResultError(fmt.Sprintf("unsupported operation: %q", input.Operation)), nil
    }

    output := CalculateOutput{
        Result:    result,
        Operation: input.Operation,
    }

    jsonData, err := json.Marshal(output)
    if err != nil {
        return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil
    }

    return mcp.NewToolResultText(string(jsonData)), nil
}

Complex Typed Tool

// UserCreateInput is the typed argument payload for creating a user.
// The json tags name the tool arguments; the validate tags declare per-field
// constraints (length bounds, email format, age range, non-empty tags).
type UserCreateInput struct {
    Name     string            `json:"name" validate:"required,min=1,max=100"`
    Email    string            `json:"email" validate:"required,email"`
    Age      int               `json:"age" validate:"min=0,max=150"`
    Tags     []string          `json:"tags" validate:"dive,min=1"`
    Metadata map[string]string `json:"metadata"`
    Active   bool              `json:"active"`
}
 
// UserCreateOutput is the JSON document returned after a user is created.
type UserCreateOutput struct {
    ID        string    `json:"id"`
    Name      string    `json:"name"`
    Email     string    `json:"email"`
    CreatedAt time.Time `json:"created_at"`
    Status    string    `json:"status"` // set to "created" by handleCreateUser
}
 
// handleCreateUser persists a new user built from input and returns a
// summary of the stored record as a JSON tool result. A persistence failure
// is returned as a wrapped Go error.
func handleCreateUser(ctx context.Context, req mcp.CallToolRequest, input UserCreateInput) (*mcp.CallToolResult, error) {
    // Input constraints are declared through the struct tags on UserCreateInput.

    // Assemble the record to store.
    record := &User{
        ID:        generateID(),
        Name:      input.Name,
        Email:     input.Email,
        Age:       input.Age,
        Tags:      input.Tags,
        Metadata:  input.Metadata,
        Active:    input.Active,
        CreatedAt: time.Now(),
    }

    err := db.Create(record)
    if err != nil {
        return nil, fmt.Errorf("failed to create user: %w", err)
    }

    // Report only the fields callers need back.
    summary := &UserCreateOutput{
        ID:        record.ID,
        Name:      record.Name,
        Email:     record.Email,
        CreatedAt: record.CreatedAt,
        Status:    "created",
    }
    return mcp.NewToolResultJSON(summary), nil
}

Custom Validation

import (
    "path/filepath"
    "strings"
 
    "github.com/go-playground/validator/v10"
)
 
// FileOperationInput is the typed argument payload for a file tool.
// Path uses the custom "filepath" validation tag; Content is only required
// when Operation is "write" (required_if).
type FileOperationInput struct {
    Path      string `json:"path" validate:"required,filepath"`
    Operation string `json:"operation" validate:"required,oneof=read write delete"`
    Content   string `json:"content" validate:"required_if=Operation write"`
}
 
// validate is the shared validator instance carrying the custom "filepath"
// rule. It is kept at package level so the registration made in init is
// still in effect wherever the validator is used.
var validate = validator.New()

// init registers the custom "filepath" rule on the shared validator.
// A registration failure is a programming error, so it aborts start-up
// instead of being silently ignored.
func init() {
    if err := validate.RegisterValidation("filepath", validateFilePath); err != nil {
        panic("registering filepath validator: " + err.Error())
    }
}
 
// validateFilePath reports whether the validated field holds a path that
// stays inside the allowed data directory.
//
// It rejects any path containing "..", resolves the path to an absolute one,
// and then requires it to be the allowed directory itself or a descendant.
// The separator-terminated prefix check prevents sibling directories such as
// "/app/database" from matching "/app/data".
func validateFilePath(fl validator.FieldLevel) bool {
    const allowedDir = "/app/data"

    path := fl.Field().String()

    // Prevent directory traversal
    if strings.Contains(path, "..") {
        return false
    }

    absPath, err := filepath.Abs(path)
    if err != nil {
        return false
    }

    // Ensure path is within allowed directory
    if absPath == allowedDir {
        return true
    }
    return strings.HasPrefix(absPath, allowedDir+string(filepath.Separator))
}

Session Management

Handle multiple clients with per-session state and tools.

Per-Session State

// SessionState holds the per-session data tracked by SessionManager.
type SessionState struct {
    UserID      string
    Permissions []string
    Settings    map[string]interface{}
    StartTime   time.Time // set to time.Now() when the session is created
}
 
// SessionManager stores SessionState values keyed by session ID.
// Its methods take mutex around every access to sessions.
type SessionManager struct {
    sessions map[string]*SessionState
    mutex    sync.RWMutex
}
 
// NewSessionManager returns a SessionManager with an empty session table.
func NewSessionManager() *SessionManager {
    sm := new(SessionManager)
    sm.sessions = make(map[string]*SessionState)
    return sm
}
 
// CreateSession stores a fresh SessionState for sessionID, replacing any
// existing entry. The new state starts with empty Settings and StartTime
// set to the current time.
func (sm *SessionManager) CreateSession(sessionID, userID string, permissions []string) {
    sm.mutex.Lock()
    defer sm.mutex.Unlock()

    state := &SessionState{
        UserID:      userID,
        Permissions: permissions,
        Settings:    make(map[string]interface{}),
        StartTime:   time.Now(),
    }
    sm.sessions[sessionID] = state
}
 
// GetSession looks up the state stored for sessionID under the read lock.
// The boolean reports whether an entry was found.
func (sm *SessionManager) GetSession(sessionID string) (*SessionState, bool) {
    sm.mutex.RLock()
    state, found := sm.sessions[sessionID]
    sm.mutex.RUnlock()

    return state, found
}
 
// RemoveSession deletes the entry for sessionID under the write lock.
// Removing an unknown ID is a no-op.
func (sm *SessionManager) RemoveSession(sessionID string) {
    sm.mutex.Lock()
    delete(sm.sessions, sessionID)
    sm.mutex.Unlock()
}

Session-Aware Tools

// main wires session lifecycle hooks into a SessionManager, registers a
// session-aware tool, and serves the server over stdio.
func main() {
    sessionManager := NewSessionManager()

    hooks := &server.Hooks{}

    hooks.AddOnRegisterSession(func(ctx context.Context, session server.ClientSession) {
        // Initialize session with default permissions
        sessionManager.CreateSession(session.ID(), "anonymous", []string{"read"})
        log.Printf("Session %s started", session.ID())
    })

    hooks.AddOnUnregisterSession(func(ctx context.Context, session server.ClientSession) {
        sessionManager.RemoveSession(session.ID())
        log.Printf("Session %s ended", session.ID())
    })

    s := server.NewMCPServer("Session Server", "1.0.0",
        server.WithToolCapabilities(true),
        server.WithHooks(hooks),
    )

    // Add session-aware tool
    s.AddTool(
        mcp.NewTool("get_user_data",
            mcp.WithDescription("Get user-specific data"),
            mcp.WithString("data_type", mcp.Required()),
        ),
        createSessionAwareTool(sessionManager),
    )

    // Surface transport failures instead of silently dropping them.
    if err := server.ServeStdio(s); err != nil {
        log.Fatalf("Server error: %v", err)
    }
}
 
// createSessionAwareTool returns a tool handler that serves data for the
// user bound to the calling session.
//
// The handler resolves the session through sm, requires the "read"
// permission, fetches the data named by the "data_type" argument and
// returns it JSON-encoded. The handler type matches server.ToolHandlerFunc,
// the type used by the tool middleware elsewhere on this page.
func createSessionAwareTool(sm *SessionManager) server.ToolHandlerFunc {
    return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
        sessionID := server.GetSessionID(ctx)
        session, exists := sm.GetSession(sessionID)
        if !exists {
            return nil, fmt.Errorf("invalid session")
        }

        // Use the checked assertion form: a missing or non-string argument
        // must produce a tool error, not a panic.
        dataType, ok := req.Params.Arguments["data_type"].(string)
        if !ok {
            return mcp.NewToolResultError("data_type must be a string"), nil
        }

        // Check permissions
        if !hasPermission(session.Permissions, "read") {
            return nil, fmt.Errorf("insufficient permissions")
        }

        // Get user-specific data
        data, err := getUserData(session.UserID, dataType)
        if err != nil {
            return nil, err
        }

        jsonData, err := json.Marshal(data)
        if err != nil {
            return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil
        }

        return mcp.NewToolResultText(string(jsonData)), nil
    }
}

Middleware

Add cross-cutting concerns like logging, authentication, and rate limiting.

Logging Middleware

// LoggingMiddleware wraps tool and resource handlers with start/finish
// log lines written to logger.
type LoggingMiddleware struct {
    logger *log.Logger
}
 
// NewLoggingMiddleware returns a LoggingMiddleware that writes to logger.
func NewLoggingMiddleware(logger *log.Logger) *LoggingMiddleware {
    mw := new(LoggingMiddleware)
    mw.logger = logger
    return mw
}
 
// ToolMiddleware wraps next so that every tool call is logged when it
// starts and when it finishes (with duration, and the error on failure).
// All log lines carry the session ID so a call can be correlated end to end.
// The result and error of next are passed through unchanged.
func (m *LoggingMiddleware) ToolMiddleware(next server.ToolHandlerFunc) server.ToolHandlerFunc {
    return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
        start := time.Now()
        sessionID := server.GetSessionID(ctx)

        m.logger.Printf("Tool call started: session=%s tool=%s", sessionID, req.Params.Name)

        result, err := next(ctx, req)

        duration := time.Since(start)
        if err != nil {
            m.logger.Printf("Tool call failed: session=%s tool=%s duration=%v error=%v",
                sessionID, req.Params.Name, duration, err)
        } else {
            m.logger.Printf("Tool call completed: session=%s tool=%s duration=%v",
                sessionID, req.Params.Name, duration)
        }

        return result, err
    }
}
 
// ResourceMiddleware wraps next so that every resource read is logged when
// it starts and when it finishes, including the session ID, URI, duration
// and (on failure) the error. The outcome of next is returned unchanged.
func (m *LoggingMiddleware) ResourceMiddleware(next server.ResourceHandler) server.ResourceHandler {
    return func(ctx context.Context, req mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) {
        began := time.Now()
        sid := server.GetSessionID(ctx)
        uri := req.Params.URI

        m.logger.Printf("Resource read started: session=%s uri=%s", sid, uri)

        res, readErr := next(ctx, req)
        elapsed := time.Since(began)

        if readErr == nil {
            m.logger.Printf("Resource read completed: session=%s uri=%s duration=%v",
                sid, uri, elapsed)
            return res, readErr
        }

        m.logger.Printf("Resource read failed: session=%s uri=%s duration=%v error=%v",
            sid, uri, elapsed, readErr)
        return res, readErr
    }
}

Rate Limiting Middleware

// RateLimitMiddleware keeps one rate.Limiter per session ID; every limiter
// is created with the same rate and burst.
type RateLimitMiddleware struct {
    limiters map[string]*rate.Limiter
    mutex    sync.RWMutex // guards limiters
    rate     rate.Limit
    burst    int
}
 
// NewRateLimitMiddleware returns a middleware whose per-session limiters
// allow requestsPerSecond events per second with the given burst.
func NewRateLimitMiddleware(requestsPerSecond float64, burst int) *RateLimitMiddleware {
    mw := new(RateLimitMiddleware)
    mw.limiters = make(map[string]*rate.Limiter)
    mw.rate = rate.Limit(requestsPerSecond)
    mw.burst = burst
    return mw
}
 
// getLimiter returns the limiter for sessionID, creating it on first use.
//
// The lookup is first tried under the read lock. On a miss the write lock
// is taken and the map is checked again, because another goroutine may have
// inserted a limiter between the two locks; without the second check two
// callers could each get a different limiter for the same session.
func (m *RateLimitMiddleware) getLimiter(sessionID string) *rate.Limiter {
    m.mutex.RLock()
    limiter, exists := m.limiters[sessionID]
    m.mutex.RUnlock()
    if exists {
        return limiter
    }

    m.mutex.Lock()
    defer m.mutex.Unlock()

    if limiter, exists = m.limiters[sessionID]; exists {
        return limiter
    }

    limiter = rate.NewLimiter(m.rate, m.burst)
    m.limiters[sessionID] = limiter
    return limiter
}
 
// ToolMiddleware wraps next with a per-session rate limit. A call that
// exceeds the session's limiter is rejected with an error and next is not
// invoked. The handler type is server.ToolHandlerFunc, matching
// LoggingMiddleware.ToolMiddleware, so both can be registered the same way.
func (m *RateLimitMiddleware) ToolMiddleware(next server.ToolHandlerFunc) server.ToolHandlerFunc {
    return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
        sessionID := server.GetSessionID(ctx)
        limiter := m.getLimiter(sessionID)

        if !limiter.Allow() {
            return nil, fmt.Errorf("rate limit exceeded for session %s", sessionID)
        }

        return next(ctx, req)
    }
}

Authentication Middleware

// AuthMiddleware authenticates tool calls using tokenValidator.
type AuthMiddleware struct {
    tokenValidator TokenValidator
}
 
// NewAuthMiddleware returns an AuthMiddleware that checks tokens with validator.
func NewAuthMiddleware(validator TokenValidator) *AuthMiddleware {
    mw := new(AuthMiddleware)
    mw.tokenValidator = validator
    return mw
}
 
// ToolMiddleware wraps next with token authentication. A call without a
// token, or with a token the validator rejects, fails with an error and
// next is not invoked. On success the validated user is stored in the
// context under the key "user" before next runs. The handler type is
// server.ToolHandlerFunc, matching LoggingMiddleware.ToolMiddleware.
func (m *AuthMiddleware) ToolMiddleware(next server.ToolHandlerFunc) server.ToolHandlerFunc {
    return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
        // Extract token from context or request
        token := extractToken(ctx, req)
        if token == "" {
            return nil, fmt.Errorf("authentication required")
        }

        // Validate token
        user, err := m.tokenValidator.Validate(token)
        if err != nil {
            return nil, fmt.Errorf("invalid token: %w", err)
        }

        // Add user to context.
        // NOTE(review): a plain string key can collide with other packages;
        // an unexported key type is safer, but readers of ctx.Value("user")
        // would need to change at the same time.
        ctx = context.WithValue(ctx, "user", user)

        return next(ctx, req)
    }
}

Hooks

Implement lifecycle callbacks for telemetry, logging, and custom behavior.

Comprehensive Hooks

// TelemetryHooks records server, session, tool and resource events as
// metrics and log lines.
type TelemetryHooks struct {
    metrics MetricsCollector
    logger  *log.Logger
}
 
// NewTelemetryHooks returns TelemetryHooks reporting to metrics and logger.
func NewTelemetryHooks(metrics MetricsCollector, logger *log.Logger) *TelemetryHooks {
    h := new(TelemetryHooks)
    h.metrics = metrics
    h.logger = logger
    return h
}
 
// OnServerStart logs the start-up and bumps the "server.starts" counter.
func (h *TelemetryHooks) OnServerStart() {
    const counter = "server.starts"

    h.logger.Println("MCP Server starting")
    h.metrics.Increment(counter)
}
 
// OnServerStop logs the shutdown and bumps the "server.stops" counter.
func (h *TelemetryHooks) OnServerStop() {
    const counter = "server.stops"

    h.logger.Println("MCP Server stopping")
    h.metrics.Increment(counter)
}
 
// OnSessionStart logs the new session, counts it, and refreshes the
// active-session gauge.
func (h *TelemetryHooks) OnSessionStart(sessionID string) {
    h.logger.Printf("Session started: %s", sessionID)
    h.metrics.Increment("sessions.started")

    active := h.getActiveSessionCount()
    h.metrics.Gauge("sessions.active", active)
}
 
// OnSessionEnd logs the finished session, counts it, and refreshes the
// active-session gauge.
func (h *TelemetryHooks) OnSessionEnd(sessionID string) {
    h.logger.Printf("Session ended: %s", sessionID)
    h.metrics.Increment("sessions.ended")

    active := h.getActiveSessionCount()
    h.metrics.Gauge("sessions.active", active)
}
 
// OnToolCall records one tool invocation: a call counter tagged with tool
// and session, a duration histogram (in seconds) tagged with tool, and an
// error counter tagged with tool when err is non-nil.
func (h *TelemetryHooks) OnToolCall(sessionID, toolName string, duration time.Duration, err error) {
    callTags := map[string]string{
        "tool":    toolName,
        "session": sessionID,
    }
    h.metrics.Increment("tools.calls", callTags)

    seconds := duration.Seconds()
    h.metrics.Histogram("tools.duration", seconds, map[string]string{
        "tool": toolName,
    })

    if err == nil {
        return
    }
    h.metrics.Increment("tools.errors", map[string]string{
        "tool": toolName,
    })
}
 
// OnResourceRead records one resource read: a read counter tagged with the
// session, a duration histogram (in seconds), and an error counter when err
// is non-nil. The uri argument is not used in the metrics.
func (h *TelemetryHooks) OnResourceRead(sessionID, uri string, duration time.Duration, err error) {
    readTags := map[string]string{
        "session": sessionID,
    }
    h.metrics.Increment("resources.reads", readTags)

    seconds := duration.Seconds()
    h.metrics.Histogram("resources.duration", seconds)

    if err == nil {
        return
    }
    h.metrics.Increment("resources.errors")
}

Custom Business Logic Hooks

// BusinessHooks applies auditing and alerting to tool calls and session
// lifecycle events.
type BusinessHooks struct {
    auditLogger AuditLogger
    notifier    Notifier
}
 
// OnToolCall audits sensitive tools, alerts on failed calls, and alerts on
// calls that ran longer than 30 seconds.
func (h *BusinessHooks) OnToolCall(sessionID, toolName string, duration time.Duration, err error) {
    const slowThreshold = 30 * time.Second

    // Sensitive operations always leave an audit record.
    if sensitive := isSensitiveTool(toolName); sensitive {
        h.auditLogger.LogToolCall(sessionID, toolName, err)
    }

    // Failures are reported to the notifier.
    if failed := err != nil; failed {
        alert := fmt.Sprintf("Tool %s failed for session %s: %v",
            toolName, sessionID, err)
        h.notifier.SendAlert(alert)
    }

    // So are unusually slow executions.
    if duration > slowThreshold {
        alert := fmt.Sprintf("Slow tool execution: %s took %v",
            toolName, duration)
        h.notifier.SendAlert(alert)
    }
}
 
// OnSessionStart prepares per-session resources and then sends the welcome
// notification for sessionID.
func (h *BusinessHooks) OnSessionStart(sessionID string) {
    // Resources first, so they exist before the client is greeted.
    h.initializeUserResources(sessionID)

    notifier := h.notifier
    notifier.SendWelcome(sessionID)
}
 
// OnSessionEnd releases per-session resources and then writes the
// session-end audit record for sessionID.
func (h *BusinessHooks) OnSessionEnd(sessionID string) {
    // Release resources before the summary is logged.
    h.cleanupUserResources(sessionID)

    audit := h.auditLogger
    audit.LogSessionEnd(sessionID)
}

Tool Filtering

Conditionally expose tools based on context, permissions, or other criteria.

Permission-Based Filtering

// PermissionFilter filters tools by the permissions of the calling session,
// which it looks up in sessionManager.
type PermissionFilter struct {
    sessionManager *SessionManager
}
 
// NewPermissionFilter returns a PermissionFilter backed by sm.
func NewPermissionFilter(sm *SessionManager) *PermissionFilter {
    filter := new(PermissionFilter)
    filter.sessionManager = sm
    return filter
}
 
// FilterTools returns the subset of tools the calling session may use.
//
// The session is resolved from ctx; an unknown session gets no tools. The
// result is always a non-nil slice, so the "no session" and "no matching
// tool" cases look the same to callers and encode as an empty JSON array
// rather than null.
func (f *PermissionFilter) FilterTools(ctx context.Context, tools []mcp.Tool) []mcp.Tool {
    sessionID := server.GetSessionID(ctx)
    session, exists := f.sessionManager.GetSession(sessionID)
    if !exists {
        return []mcp.Tool{} // No tools for invalid sessions
    }

    // At most len(tools) entries survive, so size the result once.
    filtered := make([]mcp.Tool, 0, len(tools))
    for _, tool := range tools {
        if f.hasPermissionForTool(session, tool.Name) {
            filtered = append(filtered, tool)
        }
    }

    return filtered
}
 
// hasPermissionForTool reports whether session may use toolName.
// Tools missing from the requirements table are open to everyone; listed
// tools need at least one of the session's permissions to appear in the
// tool's list.
func (f *PermissionFilter) hasPermissionForTool(session *SessionState, toolName string) bool {
    toolPerms := map[string][]string{
        "delete_user":   {"admin"},
        "modify_system": {"admin", "operator"},
        "read_data":     {"admin", "operator", "user"},
        "create_report": {"admin", "operator", "user"},
    }

    needed, restricted := toolPerms[toolName]
    if !restricted {
        // No entry means no specific requirement.
        return true
    }

    // Index the accepted permissions for direct lookup.
    accepted := make(map[string]struct{}, len(needed))
    for _, p := range needed {
        accepted[p] = struct{}{}
    }

    for _, granted := range session.Permissions {
        if _, ok := accepted[granted]; ok {
            return true
        }
    }
    return false
}

Context-Based Filtering

// ContextFilter filters tools by time of day and deployment environment.
type ContextFilter struct{}
 
// FilterTools returns the tools that should be exposed right now, based on
// the current local hour and the ENVIRONMENT environment variable.
// The result is always a non-nil slice, so an empty outcome encodes as an
// empty JSON array rather than null.
func (f *ContextFilter) FilterTools(ctx context.Context, tools []mcp.Tool) []mcp.Tool {
    timeOfDay := time.Now().Hour()
    environment := os.Getenv("ENVIRONMENT")

    // At most len(tools) entries survive, so size the result once.
    filtered := make([]mcp.Tool, 0, len(tools))
    for _, tool := range tools {
        if f.shouldIncludeTool(tool, timeOfDay, environment) {
            filtered = append(filtered, tool)
        }
    }

    return filtered
}
 
// shouldIncludeTool reports whether tool should be exposed at the given
// hour (0-23) in the given environment.
//
// Maintenance tools are exposed only from 22:00 up to (not including)
// 06:00, debug tools only when env is "development"; every other tool is
// always exposed.
func (f *ContextFilter) shouldIncludeTool(tool mcp.Tool, hour int, env string) bool {
    // Maintenance tools only during off-hours
    maintenanceTools := map[string]bool{
        "backup_database": true,
        "cleanup_logs":    true,
        "restart_service": true,
    }

    if maintenanceTools[tool.Name] {
        // hour >= 22 so the window really opens at 10 PM; "hour > 22" would
        // leave out the whole 22:00-22:59 hour.
        return hour < 6 || hour >= 22 // Only between 10 PM and 6 AM
    }

    // Debug tools only in development
    debugTools := map[string]bool{
        "debug_session": true,
        "dump_state":    true,
    }

    if debugTools[tool.Name] {
        return env == "development"
    }

    return true
}

Notifications

Send server-to-client messages for real-time updates.

Progress Notifications

// handleLongRunningTool simulates a 100-step job, broadcasting a "progress"
// notification to all clients after each step.
//
// The loop stops early and returns ctx.Err() when the request context is
// cancelled, so an abandoned call does not keep notifying for the full ten
// seconds. If no server is attached to ctx, a tool error is returned
// instead of dereferencing a nil server.
func handleLongRunningTool(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
    const (
        totalSteps = 100
        stepDelay  = 100 * time.Millisecond
    )

    srv := server.ServerFromContext(ctx)
    if srv == nil {
        return mcp.NewToolResultError("server not available in context"), nil
    }

    ticker := time.NewTicker(stepDelay)
    defer ticker.Stop()

    // Simulate long-running work
    for step := 1; step <= totalSteps; step++ {
        select {
        case <-ctx.Done():
            return nil, ctx.Err()
        case <-ticker.C:
        }

        // Send custom notification to all clients
        notification := map[string]interface{}{
            "type":     "progress",
            "progress": step,
            "total":    totalSteps,
            "message":  fmt.Sprintf("Processing step %d/%d", step, totalSteps),
        }

        srv.SendNotificationToAllClients("progress", notification)
    }

    return mcp.NewToolResultText("Long operation completed successfully"), nil
}

Custom Notifications

// CustomNotifier keeps one notification channel per session ID.
type CustomNotifier struct {
    sessions map[string]chan mcp.Notification
    mutex    sync.RWMutex // guards sessions
}
 
// NewCustomNotifier returns a CustomNotifier with no registered sessions.
func NewCustomNotifier() *CustomNotifier {
    n := new(CustomNotifier)
    n.sessions = make(map[string]chan mcp.Notification)
    return n
}
 
// RegisterSession installs a new notification channel (buffer of 100) for
// sessionID, replacing any channel already stored under that ID.
func (n *CustomNotifier) RegisterSession(sessionID string) {
    n.mutex.Lock()
    defer n.mutex.Unlock()

    queue := make(chan mcp.Notification, 100)
    n.sessions[sessionID] = queue
}
 
// UnregisterSession closes and removes the channel for sessionID.
// Unknown IDs are ignored.
func (n *CustomNotifier) UnregisterSession(sessionID string) {
    n.mutex.Lock()
    defer n.mutex.Unlock()

    queue, registered := n.sessions[sessionID]
    if !registered {
        return
    }

    close(queue)
    delete(n.sessions, sessionID)
}
 
// SendAlert queues an "alert" notification carrying message, severity and
// the current Unix timestamp for sessionID. The send never blocks: if the
// session is unknown or its channel is full, the alert is dropped.
func (n *CustomNotifier) SendAlert(sessionID, message string, severity string) {
    n.mutex.RLock()
    defer n.mutex.RUnlock()

    queue, registered := n.sessions[sessionID]
    if !registered {
        return
    }

    alert := mcp.Notification{
        Type: "alert",
        Data: map[string]interface{}{
            "message":   message,
            "severity":  severity,
            "timestamp": time.Now().Unix(),
        },
    }

    select {
    case queue <- alert:
    default:
        // Channel full, drop notification
    }
}
 
// BroadcastSystemMessage queues one "system_message" notification, stamped
// with the current Unix time, on every registered session channel. Sends
// never block: sessions whose channel is full are skipped.
func (n *CustomNotifier) BroadcastSystemMessage(message string) {
    n.mutex.RLock()
    defer n.mutex.RUnlock()

    payload := map[string]interface{}{
        "message":   message,
        "timestamp": time.Now().Unix(),
    }
    broadcast := mcp.Notification{
        Type: "system_message",
        Data: payload,
    }

    for _, queue := range n.sessions {
        select {
        case queue <- broadcast:
        default:
            // Channel full, skip this session
        }
    }
}

Production Configuration

Complete Production Server

// main assembles the production server: logging, rate-limit and auth
// middleware, telemetry and business hooks, permission-based tool
// filtering, and graceful shutdown.
//
// Session lifecycle callbacks are registered on a server.Hooks value so that
// every connecting client gets a SessionManager entry (without one the
// permission filter would hide all tools), a notifier channel, and the
// telemetry/business session callbacks.
func main() {
    // Initialize components
    logger := log.New(os.Stdout, "[MCP] ", log.LstdFlags)
    metrics := NewPrometheusMetrics()
    sessionManager := NewSessionManager()
    notifier := NewCustomNotifier()

    // Create middleware
    loggingMW := NewLoggingMiddleware(logger)
    rateLimitMW := NewRateLimitMiddleware(10.0, 20) // 10 req/sec, burst 20
    authMW := NewAuthMiddleware(NewJWTValidator())

    // Create hooks
    telemetryHooks := NewTelemetryHooks(metrics, logger)
    businessHooks := NewBusinessHooks(NewAuditLogger(), NewSlackNotifier())

    // Bridge the server's session lifecycle to the components above.
    hooks := &server.Hooks{}
    hooks.AddOnRegisterSession(func(ctx context.Context, session server.ClientSession) {
        id := session.ID()
        // Initialize session with default permissions
        sessionManager.CreateSession(id, "anonymous", []string{"read"})
        notifier.RegisterSession(id)
        telemetryHooks.OnSessionStart(id)
        businessHooks.OnSessionStart(id)
    })
    hooks.AddOnUnregisterSession(func(ctx context.Context, session server.ClientSession) {
        id := session.ID()
        telemetryHooks.OnSessionEnd(id)
        businessHooks.OnSessionEnd(id)
        notifier.UnregisterSession(id)
        sessionManager.RemoveSession(id)
    })

    // Create server with all features
    s := server.NewMCPServer("Production Server", "1.0.0",
        server.WithToolCapabilities(true),
        server.WithResourceCapabilities(false, true),
        server.WithPromptCapabilities(true),
        server.WithRecovery(),
        server.WithHooks(hooks),
        server.WithToolHandlerMiddleware(loggingMW.ToolMiddleware),
        server.WithToolHandlerMiddleware(rateLimitMW.ToolMiddleware),
        server.WithToolHandlerMiddleware(authMW.ToolMiddleware),
        server.WithToolFilter(NewPermissionFilter(sessionManager)),
    )

    // Add tools and resources
    addProductionTools(s)
    addProductionResources(s)
    addProductionPrompts(s)

    // Start server with graceful shutdown
    startWithGracefulShutdown(s)
}
 
// startWithGracefulShutdown serves s over stdio and returns when either the
// transport stops by itself or an interrupt/SIGTERM arrives.
//
// On a signal the server is given up to 30 seconds to shut down. If the
// transport ends first (for example the client closes stdin) the function
// returns right away instead of waiting for a signal that may never come.
func startWithGracefulShutdown(s *server.MCPServer) {
    // Setup signal handling
    sigChan := make(chan os.Signal, 1)
    signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
    defer signal.Stop(sigChan)

    // Start server in goroutine; the buffered channel lets the goroutine
    // finish even when nobody is left to receive its result.
    serveErr := make(chan error, 1)
    go func() {
        serveErr <- server.ServeStdio(s)
    }()

    // Wait for shutdown signal or for the transport to end
    select {
    case err := <-serveErr:
        if err != nil {
            log.Printf("Server error: %v", err)
        }
        log.Println("Server stopped")
        return
    case <-sigChan:
        log.Println("Shutting down server...")
    }

    // Graceful shutdown with timeout
    ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
    defer cancel()

    if err := s.Shutdown(ctx); err != nil {
        log.Printf("Shutdown error: %v", err)
    }

    log.Println("Server stopped")
}

Next Steps