Rust By Example

Rust AI Inference Scaling

Strategies for scaling AI inference services in Rust from single-node to distributed deployments. Covers horizontal scaling, load balancing, model sharding, and auto-scaling patterns.

Topic: AI Inference

Search intent: High-intent search: "rust ai inference scaling distributed"

Rust AI Inference Scaling

Scaling dimensions

AI inference scales along four dimensions:

| Dimension | Technique | Rust tool |

|---|---|---|

| Request throughput | Horizontal scaling, load balancing | axum + reverse proxy |

| Batch efficiency | Dynamic batching, request coalescing | tokio::sync::mpsc |

| Model size | Model sharding, quantization | candle, ort |

| Latency under load | Queue management, priority routing | tokio semaphores |

Runnable example — dynamic batching with timeout

rust
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, oneshot};
use tokio::time::timeout;

/// One queued request: the input vector plus the one-shot sender used to return its output.
type BatchItem = (Vec<f32>, oneshot::Sender<Vec<f32>>);

/// Handle for submitting inference requests to the background batching task.
struct DynamicBatcher {
    /// Sending side of the request queue that the batching task consumes.
    tx: mpsc::Sender<BatchItem>,
}

impl DynamicBatcher {
    /// Creates a batcher and spawns its background worker task.
    ///
    /// The worker waits for a first request, then keeps collecting requests until
    /// either `max_batch` items are gathered or `max_wait_ms` milliseconds have
    /// passed since the first item arrived. The batch is handed to
    /// `run_batch_inference`, and each result is sent back to its caller.
    ///
    /// The worker exits once the request channel is closed, i.e. when every
    /// `DynamicBatcher` holding the sender has been dropped.
    ///
    /// # Panics
    ///
    /// Must be called from within a Tokio runtime, because it calls `tokio::spawn`.
    fn new(max_batch: usize, max_wait_ms: u64) -> Self {
        let (tx, mut rx) = mpsc::channel::<BatchItem>(1024);
        // Loop-invariant: computed once rather than per batch.
        let wait = Duration::from_millis(max_wait_ms);

        tokio::spawn(async move {
            // Wait for at least one item; `None` means the channel is closed.
            while let Some(first) = rx.recv().await {
                let mut batch: Vec<BatchItem> = Vec::with_capacity(max_batch);
                batch.push(first);

                // Try to fill the batch within the wait window, which starts
                // when the first item of the batch has been received.
                let deadline = tokio::time::Instant::now() + wait;
                while batch.len() < max_batch {
                    let remaining =
                        deadline.saturating_duration_since(tokio::time::Instant::now());
                    if remaining.is_zero() {
                        break;
                    }
                    match timeout(remaining, rx.recv()).await {
                        Ok(Some(item)) => batch.push(item),
                        // Channel closed or window elapsed: run what we have.
                        Ok(None) | Err(_) => break,
                    }
                }

                // Split the batch into model inputs and reply channels.
                let (inputs, senders): (Vec<_>, Vec<_>) = batch.into_iter().unzip();

                // NOTE(review): this runs synchronously on the async task. For a
                // real, CPU-heavy model consider `tokio::task::spawn_blocking`.
                let results = run_batch_inference(inputs);

                for (sender, result) in senders.into_iter().zip(results) {
                    // A send error means the caller stopped waiting; nothing to do.
                    let _ = sender.send(result);
                }
            }
        });

        Self { tx }
    }

    /// Submits one input for inference and waits for its result.
    ///
    /// # Panics
    ///
    /// Panics if the background worker has stopped, or if it drops the reply
    /// channel without sending a result.
    async fn infer(&self, input: Vec<f32>) -> Vec<f32> {
        let (resp_tx, resp_rx) = oneshot::channel();
        self.tx
            .send((input, resp_tx))
            .await
            .expect("batch worker has stopped: request channel is closed");
        resp_rx
            .await
            .expect("batch worker dropped the reply channel without sending a result")
    }
}

/// Stand-in for a batched forward pass: multiplies every element of every input by 1.5.
///
/// Returns one output vector per input vector, in the same order and with the
/// same number of elements.
fn run_batch_inference(batch: Vec<Vec<f32>>) -> Vec<Vec<f32>> {
    let mut outputs = Vec::with_capacity(batch.len());
    for row in batch {
        let mut scaled = Vec::with_capacity(row.len());
        for value in &row {
            scaled.push(*value * 1.5);
        }
        outputs.push(scaled);
    }
    outputs
}

#[tokio::main]
async fn main() {
    let batcher = Arc::new(DynamicBatcher::new(32, 10));

    // Simulate concurrent clients
    let handles: Vec<_> = (0..20u32).map(|i| {
        let batcher = batcher.clone();
        tokio::spawn(async move {
            let input = vec![i as f32; 4];
            let result = batcher.infer(input).await;
            println!("Client {}: {:?}", i, &result[..2]);
        })
    }).collect();

    for h in handles { h.await.unwrap(); }
}

Horizontal scaling pattern

rust
/// Routes requests to inference workers by hash modulo the pool size.
///
/// This is NOT consistent hashing: adding or removing a worker changes the
/// modulus, so most keys move to a different worker. It is a simplified
/// illustration — use a real consistent hash ring in production.
struct WorkerPool {
    workers: Vec<String>, // worker addresses
}

impl WorkerPool {
    /// Returns the address of the worker selected by `request_hash`.
    ///
    /// # Panics
    ///
    /// Panics if the pool contains no workers.
    fn route(&self, request_hash: u64) -> &str {
        // Fail with a clear message instead of an opaque remainder-by-zero panic.
        assert!(
            !self.workers.is_empty(),
            "WorkerPool::route called with no workers configured"
        );
        // The remainder is strictly less than `len`, so converting back to `usize` is lossless.
        let idx = (request_hash % self.workers.len() as u64) as usize;
        &self.workers[idx]
    }
}

/// Hashes a tenant identifier to a `u64` routing key using the standard library's
/// `DefaultHasher`.
///
/// NOTE(review): the standard library does not specify `DefaultHasher`'s algorithm
/// across Rust releases, so binaries built with different toolchains may disagree
/// on the key — confirm this is acceptable before routing across separately built nodes.
fn hash_tenant(tenant_id: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut state = DefaultHasher::new();
    Hash::hash(tenant_id, &mut state);
    Hasher::finish(&state)
}

/// Demo driver: builds a three-worker pool and prints the worker chosen for each tenant.
fn main() {
    let workers: Vec<String> = ["worker-1:8080", "worker-2:8080", "worker-3:8080"]
        .iter()
        .map(|addr| addr.to_string())
        .collect();
    let pool = WorkerPool { workers };

    ["acme", "globex", "initech"].iter().for_each(|tenant| {
        let key = hash_tenant(tenant);
        println!("Tenant {} → {}", tenant, pool.route(key));
    });
}

Auto-scaling signal: queue depth

rust
use std::sync::atomic::{AtomicUsize, Ordering};

/// Number of requests currently counted as queued (process-wide).
static QUEUE_DEPTH: AtomicUsize = AtomicUsize::new(0);

/// Queue depth above which `should_scale_out` reports that more capacity is needed.
const SCALE_OUT_THRESHOLD: usize = 100;

/// Records that one request entered the queue.
fn enqueue() {
    QUEUE_DEPTH.fetch_add(1, Ordering::Relaxed);
}

/// Records that one request left the queue.
///
/// Saturates at zero: an unmatched `dequeue` would otherwise wrap the counter to
/// `usize::MAX` and make `should_scale_out` fire spuriously.
fn dequeue() {
    // `checked_sub` yields `None` at zero, so `fetch_update` leaves the value untouched.
    let _ = QUEUE_DEPTH.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |depth| {
        depth.checked_sub(1)
    });
}

/// Returns `true` when the queue depth is strictly greater than `SCALE_OUT_THRESHOLD`.
fn should_scale_out() -> bool {
    QUEUE_DEPTH.load(Ordering::Relaxed) > SCALE_OUT_THRESHOLD
}

/// Demo driver: enqueues 150 requests, reports the scale-out signal, then drains the queue.
fn main() {
    (0..150).for_each(|_| enqueue());

    if should_scale_out() {
        let depth = QUEUE_DEPTH.load(Ordering::Relaxed);
        println!("Queue depth {}: trigger scale-out", depth);
    }

    (0..150).for_each(|_| dequeue());
}

Scaling checklist

  • [ ] Stateless workers — model weights loaded from shared storage at startup.
  • [ ] Health check endpoint signals readiness (model loaded + warm).
  • [ ] Queue depth metric exported to autoscaler (KEDA / HPA).
  • [ ] Graceful draining on SIGTERM before pod termination.
  • [ ] Session affinity disabled — all workers are equivalent.
  • [ ] Model version pinned per worker — no mixed versions in flight.
  • [ ] Circuit breaker to shed load when all workers are saturated.

Related reading

Related Guides

Continue in This Topic

More Rust Guides