Rust By Example

Rust AI Inference Production Guide

Complete guide to deploying AI inference services in Rust to production. Covers health checks, graceful shutdown, model hot-reload, observability, and zero-downtime deployments.

Topic: AI Inference

Search intent: High-intent search: "rust ai inference production deployment"

Rust AI Inference Production Guide

Pre-deployment checklist

  • [ ] Model files are checksummed and verified at startup.
  • [ ] All inference paths have timeouts configured.
  • [ ] Memory limits are set (RLIMIT_AS or container limits).
  • [ ] Health check endpoint returns model readiness, not just process liveness.
  • [ ] Graceful shutdown drains in-flight requests before exiting.
  • [ ] Structured logging with request IDs for distributed tracing.
  • [ ] Prometheus metrics exposed on /metrics.

Runnable example — production-ready inference server skeleton

rust
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::signal;
use tokio::sync::Semaphore;

/// Shared server state
///
/// A single instance is wrapped in an `Arc` by `AppState::new` and cloned into
/// every task; all fields are safe to touch through a shared reference.
struct AppState {
    /// `true` once the model has been loaded; written by `mark_ready` and read
    /// by `is_ready`.
    model_ready: AtomicBool,
    /// Count of requests that passed the readiness check in `handle_inference`.
    requests_total: AtomicU64,
    /// Gauge of requests currently inside `handle_inference`; polled by
    /// `graceful_shutdown` to decide when draining is finished.
    requests_in_flight: AtomicU64,
    /// Permits that bound how many inference calls may run at once; the permit
    /// count is the `max_concurrent` argument passed to `AppState::new`.
    concurrency_limit: Semaphore,
}

impl AppState {
    /// Builds the shared state with zeroed counters and the model flagged as
    /// not ready, handing it back already wrapped in an `Arc`.
    ///
    /// `max_concurrent` is the number of permits placed in the concurrency
    /// semaphore.
    fn new(max_concurrent: usize) -> Arc<Self> {
        let initial = Self {
            concurrency_limit: Semaphore::new(max_concurrent),
            requests_in_flight: AtomicU64::new(0),
            requests_total: AtomicU64::new(0),
            model_ready: AtomicBool::new(false),
        };
        Arc::new(initial)
    }

    /// Flags the model as loaded. The `Release` store pairs with the
    /// `Acquire` load in [`AppState::is_ready`].
    fn mark_ready(&self) {
        let flag = &self.model_ready;
        flag.store(true, Ordering::Release);
    }

    /// Reports whether [`AppState::mark_ready`] has been called.
    fn is_ready(&self) -> bool {
        let flag = &self.model_ready;
        flag.load(Ordering::Acquire)
    }
}

/// Stand-in for real model initialization: waits for a fixed delay, then
/// flips the readiness flag on `state`. Swap the sleep for the actual
/// weight-loading code.
async fn load_model(state: Arc<AppState>) {
    // Placeholder for the time real weight loading would take.
    let simulated_load_time = Duration::from_millis(500);

    println!("[startup] Loading model weights...");
    tokio::time::sleep(simulated_load_time).await;

    state.mark_ready();
    println!("[startup] Model ready.");
}

/// Handle one inference request with timeout and concurrency limiting
async fn handle_inference(
    state: Arc<AppState>,
    input: Vec<f32>,
    request_id: u64,
) -> Result<Vec<f32>, &'static str> {
    if !state.is_ready() {
        return Err("model not ready");
    }

    state.requests_in_flight.fetch_add(1, Ordering::Relaxed);
    state.requests_total.fetch_add(1, Ordering::Relaxed);

    let _permit = tokio::time::timeout(
        Duration::from_secs(10),
        state.concurrency_limit.acquire(),
    )
    .await
    .map_err(|_| "timeout waiting for slot")?
    .map_err(|_| "semaphore closed")?;

    let start = Instant::now();

    // Simulate inference with timeout
    let result = tokio::time::timeout(
        Duration::from_secs(5),
        tokio::task::spawn_blocking(move || {
            // CPU-bound inference work
            input.iter().map(|x| x * 2.0 + 0.1).collect::<Vec<f32>>()
        }),
    )
    .await
    .map_err(|_| "inference timeout")?
    .map_err(|_| "task panicked")?;

    let elapsed = start.elapsed();
    state.requests_in_flight.fetch_sub(1, Ordering::Relaxed);

    println!(
        "[req={}] inference completed in {:.2}ms",
        request_id,
        elapsed.as_secs_f64() * 1000.0
    );

    Ok(result)
}

/// Graceful shutdown: wait for in-flight requests to drain.
///
/// Polls `state.requests_in_flight` every 100 ms and returns as soon as it
/// reaches zero, or after a 30 s deadline, whichever comes first. The
/// "Drain complete." message is only printed when the counter actually
/// reached zero; a deadline expiry is reported separately with the number of
/// requests still outstanding.
async fn graceful_shutdown(state: Arc<AppState>) {
    println!("[shutdown] Waiting for in-flight requests to drain...");
    let deadline = Instant::now() + Duration::from_secs(30);
    loop {
        let in_flight = state.requests_in_flight.load(Ordering::Relaxed);
        if in_flight == 0 {
            println!("[shutdown] Drain complete.");
            return;
        }
        if Instant::now() > deadline {
            // Give up rather than block process exit indefinitely.
            println!("[shutdown] Deadline reached with {} requests in flight", in_flight);
            return;
        }
        tokio::time::sleep(Duration::from_millis(100)).await;
    }
}

#[tokio::main]
async fn main() {
    let state = AppState::new(16);

    // Load model asynchronously at startup
    load_model(state.clone()).await;

    // Simulate serving requests
    let handles: Vec<_> = (0..5).map(|i| {
        let state = state.clone();
        tokio::spawn(async move {
            let input = vec![1.0f32, 2.0, 3.0];
            match handle_inference(state, input, i).await {
                Ok(output) => println!("[req={}] output: {:?}", i, output),
                Err(e) => println!("[req={}] error: {}", i, e),
            }
        })
    }).collect();

    for h in handles { h.await.unwrap(); }

    // In a real server, wait for SIGTERM:
    // signal::ctrl_c().await.unwrap();
    graceful_shutdown(state.clone()).await;

    println!(
        "[stats] Total requests processed: {}",
        state.requests_total.load(Ordering::Relaxed)
    );
}

Model hot-reload pattern

rust
use std::sync::Arc;
use tokio::sync::watch;

/// Value published through the `watch` channel to describe the model that
/// workers should currently use.
///
/// `Clone` lets a receiver copy the value out of the channel instead of
/// holding the borrow.
#[derive(Clone)]
struct ModelVersion {
    /// Version counter; `model_reloader` increments it on every reload.
    version: u32,
    // weights: Arc<ModelWeights>,  // real weights would go here
}

/// Periodically publishes a new [`ModelVersion`] on the watch channel.
///
/// Every 60 s the version counter is incremented and sent to all receivers.
/// The task stops once every receiver has been dropped, because a failed
/// `send` means nothing would observe further reloads.
///
/// NOTE(review): the counter starts at 1 locally — presumably the channel was
/// created with version 1 as its initial value; confirm against the caller.
async fn model_reloader(tx: watch::Sender<ModelVersion>) {
    // Fully qualified: this snippet's `use` lines do not import `Duration`.
    let reload_interval = std::time::Duration::from_secs(60);
    let mut version = 1u32;
    loop {
        tokio::time::sleep(reload_interval).await;
        version += 1;
        println!("[reload] Loading model version {}", version);
        // Load new weights here, then swap atomically
        if tx.send(ModelVersion { version }).is_err() {
            // Do not claim the version is live when nobody received it.
            println!(
                "[reload] No receivers left; model version {} was not published",
                version
            );
            return;
        }
        println!("[reload] Model version {} is live", version);
    }
}

/// Worker loop that always serves with the most recently published model.
///
/// Each iteration copies the current [`ModelVersion`] out of the channel,
/// simulates 5 s of work with it, then waits for the next version. The loop
/// ends when the sending side is dropped, instead of spinning on a model that
/// can never be updated again.
async fn inference_worker(mut rx: watch::Receiver<ModelVersion>) {
    loop {
        // `borrow_and_update` marks this value as seen, so `changed()` below
        // only wakes for versions newer than the one just read. The borrow is
        // released at the end of the statement, before any `.await`.
        let model = rx.borrow_and_update().clone();
        println!("[worker] Using model version {}", model.version);
        // Fully qualified: this snippet's `use` lines do not import `Duration`.
        tokio::time::sleep(std::time::Duration::from_secs(5)).await;
        if rx.changed().await.is_err() {
            println!("[worker] Model channel closed; stopping worker");
            break;
        }
    }
}

Observability setup

rust
// Prometheus metrics pattern (with metrics crate)
// counter!("inference_requests_total", "model" => model_name);
// histogram!("inference_duration_seconds", duration.as_secs_f64());
// gauge!("inference_queue_depth", queue_len as f64);

Related reading

Related Guides

Continue in This Topic

More Rust Guides