Rust By Example

Rust AI Inference Error Playbook

Error handling patterns for AI inference services in Rust. Covers typed errors, retry logic, circuit breakers, fallback models, and graceful degradation under failure.

Topic: AI Inference

Search intent: high-intent query "rust ai inference error handling"

Rust AI Inference Error Playbook

Error taxonomy

rust
use std::fmt;

/// Every way an inference request can fail, from request validation through
/// model faults to the surrounding infrastructure.
#[derive(Clone, Debug, PartialEq)]
pub enum InferenceError {
    /// The request was rejected by input validation and never reached the model.
    InvalidInput(String),
    /// The model has not been loaded or is not yet ready to serve.
    ModelNotReady,
    /// Inference ran past the configured deadline, given in milliseconds.
    Timeout { duration_ms: u64 },
    /// The model itself misbehaved (e.g. produced NaNs or hit a shape mismatch).
    ModelError(String),
    /// A dependency underneath the model failed (e.g. GPU out of memory, driver error).
    InfrastructureError(String),
    /// The service has more concurrent work than it accepts; carries the queue depth.
    Overloaded { queue_depth: usize },
}

impl fmt::Display for InferenceError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidInput(msg) => write!(f, "invalid input: {}", msg),
            Self::ModelNotReady => write!(f, "model is not ready yet"),
            Self::Timeout { duration_ms } =>
                write!(f, "inference timed out after {}ms", duration_ms),
            Self::ModelError(msg) => write!(f, "model error: {}", msg),
            Self::InfrastructureError(msg) => write!(f, "infrastructure error: {}", msg),
            Self::Overloaded { queue_depth } =>
                write!(f, "service overloaded (queue depth: {})", queue_depth),
        }
    }
}

impl InferenceError {
    /// Maps the error to the HTTP status code to respond with.
    ///
    /// 400 for invalid input, 429 when overloaded, 500 for model faults,
    /// 503 when the model is not ready or infrastructure failed, and 504 on timeout.
    pub fn http_status(&self) -> u16 {
        match self {
            Self::InvalidInput(_) => 400,
            Self::Overloaded { .. } => 429,
            Self::ModelError(_) => 500,
            Self::ModelNotReady | Self::InfrastructureError(_) => 503,
            Self::Timeout { .. } => 504,
        }
    }

    /// Reports whether retrying the same request is considered safe.
    ///
    /// Overload, timeout, and a not-yet-ready model return `true`; invalid input,
    /// model errors, and infrastructure errors return `false`. The match is
    /// exhaustive so adding a variant forces a decision here.
    pub fn is_retryable(&self) -> bool {
        match self {
            Self::Overloaded { .. } | Self::Timeout { .. } | Self::ModelNotReady => true,
            Self::InvalidInput(_) | Self::ModelError(_) | Self::InfrastructureError(_) => false,
        }
    }
}

Retry with exponential backoff

rust
use std::time::Duration;

/// Runs `do_infer` on `input`, retrying transient failures with exponential backoff.
///
/// Attempt 0 runs immediately. Retry `n` (1-based) first sleeps `50ms * 2^(n-1)`,
/// so `max_retries = 3` waits 50ms, 100ms and 200ms. Errors for which
/// `is_retryable()` is `false` are returned immediately without further attempts.
///
/// # Errors
/// Returns the first non-retryable error, or the last retryable error once all
/// `max_retries + 1` attempts have failed.
async fn infer_with_retry(
    input: Vec<f32>,
    max_retries: u32,
) -> Result<Vec<f32>, InferenceError> {
    const BASE_DELAY_MS: u64 = 50;

    let mut last_error = None;

    for attempt in 0..=max_retries {
        if attempt > 0 {
            // Saturate rather than overflow: plain `pow`/`*` panic in debug builds
            // (and wrap in release) once `attempt` reaches about 60.
            let factor = 2u64.saturating_pow(attempt - 1);
            let delay = Duration::from_millis(BASE_DELAY_MS.saturating_mul(factor));
            tokio::time::sleep(delay).await;
            println!("Retry attempt {} after {:?}", attempt, delay);
        }

        match do_infer(&input).await {
            Ok(result) => return Ok(result),
            Err(e) if e.is_retryable() => last_error = Some(e),
            Err(e) => return Err(e), // Non-retryable: fail fast
        }
    }

    // `0..=max_retries` always runs at least one attempt, and every iteration that
    // does not return stores its error, so `last_error` is always `Some` here.
    Err(last_error.expect("at least one attempt ran and failed with a retryable error"))
}

/// Simulated inference backend that doubles every input element.
///
/// To simulate occasional overload, it fails with
/// `Overloaded { queue_depth: 150 }` whenever the sub-second nanoseconds of the
/// current wall-clock time are divisible by 3.
async fn do_infer(input: &[f32]) -> Result<Vec<f32>, InferenceError> {
    use std::time::{SystemTime, UNIX_EPOCH};

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .subsec_nanos();
    let overloaded = nanos % 3 == 0;
    if overloaded {
        return Err(InferenceError::Overloaded { queue_depth: 150 });
    }

    let doubled: Vec<f32> = input.iter().map(|&x| 2.0 * x).collect();
    Ok(doubled)
}

/// Retry demo: runs one inference with up to three retries and prints the outcome,
/// including the HTTP status the final error maps to.
#[tokio::main]
async fn main() {
    let max_retries = 3;
    let outcome = infer_with_retry(vec![1.0, 2.0, 3.0], max_retries).await;
    match outcome {
        Err(e) => println!("Failed after retries: {} (HTTP {})", e, e.http_status()),
        Ok(out) => println!("Success: {:?}", out),
    }
}

Circuit breaker pattern

rust
use std::sync::atomic::{AtomicU32, AtomicBool, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use std::sync::Mutex;

/// A minimal circuit breaker guarding calls that return `Result<Vec<f32>, InferenceError>`.
///
/// * **Closed**: calls pass through; failures since the last success are counted.
/// * **Open**: once that count reaches `failure_threshold`, calls are rejected
///   immediately with an `InfrastructureError("circuit breaker open")`.
/// * **Half-open**: after `recovery_timeout` has elapsed, exactly one probe call is
///   admitted and the timer is re-armed. A successful probe closes the breaker; a
///   failed probe re-opens it with a fresh timestamp.
struct CircuitBreaker {
    /// Failures recorded since the last success.
    failure_count: AtomicU32,
    failure_threshold: u32,
    /// `true` while open or half-open. Only written while holding `opened_at`'s lock.
    open: AtomicBool,
    /// When the breaker opened, or when the latest half-open probe was admitted.
    /// `None` while closed.
    opened_at: Mutex<Option<Instant>>,
    recovery_timeout: Duration,
}

impl CircuitBreaker {
    /// Creates a closed breaker that opens after `threshold` failures without an
    /// intervening success and admits a probe after `recovery_timeout`.
    fn new(threshold: u32, recovery_timeout: Duration) -> Arc<Self> {
        Arc::new(Self {
            failure_count: AtomicU32::new(0),
            failure_threshold: threshold,
            open: AtomicBool::new(false),
            opened_at: Mutex::new(None),
            recovery_timeout,
        })
    }

    /// Locks `opened_at`, recovering the guard if a previous holder panicked.
    ///
    /// The protected value is a plain `Option<Instant>` that is always valid, so a
    /// poisoned lock carries no broken invariant.
    fn lock_opened_at(&self) -> std::sync::MutexGuard<'_, Option<Instant>> {
        self.opened_at
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Returns `true` if the next call should be rejected.
    ///
    /// When the recovery timeout has elapsed, the first caller to notice is admitted
    /// as the half-open probe and the timer is re-armed under the lock, so concurrent
    /// callers keep being rejected until the probe resolves. If the probe future is
    /// dropped before completing, another probe is admitted after a further
    /// `recovery_timeout`.
    fn is_open(&self) -> bool {
        if !self.open.load(Ordering::Acquire) {
            return false;
        }
        let mut opened_at = self.lock_opened_at();
        match *opened_at {
            Some(t) if t.elapsed() <= self.recovery_timeout => true,
            Some(_) => {
                // Half-open: admit only this caller as the probe.
                *opened_at = Some(Instant::now());
                false
            }
            // A concurrent success just closed the breaker.
            None => false,
        }
    }

    /// Resets the failure count and closes the breaker if it was open.
    fn record_success(&self) {
        self.failure_count.store(0, Ordering::Release);
        // Fast path: nothing else to reset while closed, so avoid the lock.
        if !self.open.load(Ordering::Acquire) {
            return;
        }
        let mut opened_at = self.lock_opened_at();
        self.open.store(false, Ordering::Release);
        *opened_at = None;
    }

    /// Counts a failure and (re-)opens the breaker once the threshold is reached.
    fn record_failure(&self) {
        let count = self
            .failure_count
            .fetch_add(1, Ordering::AcqRel)
            .saturating_add(1);
        if count >= self.failure_threshold {
            {
                // Stamp the time and raise the flag together so `is_open` can never
                // pair `open == true` with a stale timestamp from an earlier trip.
                let mut opened_at = self.lock_opened_at();
                *opened_at = Some(Instant::now());
                self.open.store(true, Ordering::Release);
            }
            println!("⚡ Circuit breaker OPEN after {} failures", count);
        }
    }

    /// Runs `f` unless the breaker is open, recording the outcome.
    ///
    /// # Errors
    /// Returns `InfrastructureError("circuit breaker open")` without calling `f` while
    /// open; otherwise returns whatever `f` produced.
    async fn call<F, Fut>(&self, f: F) -> Result<Vec<f32>, InferenceError>
    where
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<Vec<f32>, InferenceError>>,
    {
        if self.is_open() {
            return Err(InferenceError::InfrastructureError("circuit breaker open".into()));
        }
        // No lock is held across this await.
        match f().await {
            Ok(r) => { self.record_success(); Ok(r) }
            Err(e) => { self.record_failure(); Err(e) }
        }
    }
}

/// Circuit-breaker demo: six calls against a breaker with threshold 3. The first
/// four calls' closures fail, so the breaker opens after call 2 and rejects the rest
/// within the 30-second recovery window.
#[tokio::main]
async fn main() {
    let breaker = CircuitBreaker::new(3, Duration::from_secs(30));

    for call_index in 0..6 {
        let outcome = breaker
            .call(|| async move {
                if call_index >= 4 {
                    Ok(vec![1.0, 2.0])
                } else {
                    Err(InferenceError::InfrastructureError("GPU OOM".into()))
                }
            })
            .await;
        let printable = outcome.map_err(|err| err.to_string());
        println!("Call {}: {:?}", call_index, printable);
    }
}

Fallback model pattern

rust
/// Runs the primary (large, accurate) model. If it fails for any reason, logs the
/// error to stderr and returns the lightweight fallback model's output instead.
/// Always produces a result.
async fn infer_with_fallback(input: Vec<f32>) -> Vec<f32> {
    let primary_error = match run_primary_model(&input).await {
        Ok(output) => return output,
        Err(err) => err,
    };
    eprintln!("Primary model failed ({}), falling back to lightweight model", primary_error);
    run_fallback_model(&input).await
}

/// Simulated primary model that always fails with a 5000ms timeout, so the
/// fallback path is exercised.
///
/// The input is deliberately ignored; the leading underscore silences the
/// `unused_variables` lint without changing the call signature.
async fn run_primary_model(_input: &[f32]) -> Result<Vec<f32>, InferenceError> {
    Err(InferenceError::Timeout { duration_ms: 5000 })
}

/// Fast, less accurate stand-in model: scales every input element by 1.5.
/// It cannot fail, so it always returns a vector the same length as `input`.
async fn run_fallback_model(input: &[f32]) -> Vec<f32> {
    let mut scaled = Vec::with_capacity(input.len());
    for &value in input {
        scaled.push(value * 1.5);
    }
    scaled
}

/// Fallback demo: the primary model always times out, so the printed result comes
/// from the lightweight fallback model.
#[tokio::main]
async fn main() {
    let input = vec![1.0, 2.0, 3.0];
    println!("Final result: {:?}", infer_with_fallback(input).await);
}

Related reading

Related Guides

Continue in This Topic

More Rust Guides