Rust AI Inference Scaling
Strategies for scaling AI inference services in Rust from a single node to distributed deployments. Covers dynamic batching, hash-based request routing, queue-depth auto-scaling signals, and a deployment checklist.
Topic: AI Inference
Scaling dimensions
AI inference scales along four dimensions:
| Dimension | Technique | Rust tool |
|---|---|---|
| Request throughput | Horizontal scaling, load balancing | axum + reverse proxy |
| Batch efficiency | Dynamic batching, request coalescing | tokio::sync::mpsc |
| Model size | Model sharding, quantization | candle, ort |
| Latency under load | Queue management, priority routing | tokio semaphores |
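
The "Latency under load" row mentions queue management and priority routing. One way to picture priority routing is a max-heap that orders pending requests by priority and then by arrival order. The sketch below uses only the standard library; `Request`, `PriorityQueue`, and the numeric priority levels are hypothetical names chosen for illustration, and in an async service the same idea would sit behind a Tokio channel or semaphore.

```rust
use std::cmp::Ordering;
use std::collections::BinaryHeap;

/// Hypothetical request record: a higher `priority` is served first;
/// `seq` preserves arrival order among requests of equal priority.
#[derive(Debug, PartialEq, Eq)]
struct Request {
    priority: u8,
    seq: u64,
    tenant: String,
}

impl Ord for Request {
    fn cmp(&self, other: &Self) -> Ordering {
        // BinaryHeap is a max-heap: compare priority first, then invert
        // the sequence comparison so the earlier request wins a tie.
        self.priority
            .cmp(&other.priority)
            .then_with(|| other.seq.cmp(&self.seq))
    }
}

impl PartialOrd for Request {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

#[derive(Default)]
struct PriorityQueue {
    heap: BinaryHeap<Request>,
    next_seq: u64,
}

impl PriorityQueue {
    fn push(&mut self, tenant: &str, priority: u8) {
        let seq = self.next_seq;
        self.next_seq += 1;
        self.heap.push(Request { priority, seq, tenant: tenant.to_string() });
    }

    fn pop(&mut self) -> Option<Request> {
        self.heap.pop()
    }
}

fn main() {
    let mut q = PriorityQueue::default();
    q.push("batch-job", 0);
    q.push("interactive-a", 2);
    q.push("interactive-b", 2);
    q.push("standard", 1);
    // Drains highest priority first, FIFO within a priority level.
    while let Some(r) = q.pop() {
        println!("priority {} seq {} tenant {}", r.priority, r.seq, r.tenant);
    }
}
```

Strict priority ordering can starve low-priority requests under sustained load; raising the priority of requests that have waited too long is a common mitigation.
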
Runnable example — dynamic batching with timeout
```rust
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, oneshot};
use tokio::time::{timeout, Instant};

/// One queued request: the input tensor plus a channel for returning the result.
type BatchItem = (Vec<f32>, oneshot::Sender<Vec<f32>>);

struct DynamicBatcher {
    tx: mpsc::Sender<BatchItem>,
}

impl DynamicBatcher {
    /// Spawns the batching task. Must be called from inside a Tokio runtime.
    fn new(max_batch: usize, max_wait_ms: u64) -> Self {
        let (tx, mut rx) = mpsc::channel::<BatchItem>(1024);
        let wait = Duration::from_millis(max_wait_ms);
        tokio::spawn(async move {
            loop {
                let mut batch: Vec<BatchItem> = Vec::with_capacity(max_batch);
                // Wait for at least one item; exit once every sender is dropped.
                match rx.recv().await {
                    Some(item) => batch.push(item),
                    None => break,
                }
                // Try to fill the batch within the wait window.
                let deadline = Instant::now() + wait;
                while batch.len() < max_batch {
                    let remaining = deadline.saturating_duration_since(Instant::now());
                    if remaining.is_zero() {
                        break;
                    }
                    match timeout(remaining, rx.recv()).await {
                        Ok(Some(item)) => batch.push(item),
                        _ => break, // timed out or channel closed
                    }
                }
                // Execute the batch and return each result to its caller.
                let (inputs, senders): (Vec<_>, Vec<_>) = batch.into_iter().unzip();
                let results = run_batch_inference(inputs);
                for (sender, result) in senders.into_iter().zip(results) {
                    // The caller may have stopped waiting; ignore send errors.
                    let _ = sender.send(result);
                }
            }
        });
        Self { tx }
    }

    async fn infer(&self, input: Vec<f32>) -> Vec<f32> {
        let (resp_tx, resp_rx) = oneshot::channel();
        // These unwraps panic only if the batching task has stopped.
        self.tx.send((input, resp_tx)).await.unwrap();
        resp_rx.await.unwrap()
    }
}

/// Stand-in for a real model call: scales every element by 1.5.
fn run_batch_inference(batch: Vec<Vec<f32>>) -> Vec<Vec<f32>> {
    batch
        .into_iter()
        .map(|v| v.iter().map(|x| x * 1.5).collect())
        .collect()
}

#[tokio::main]
async fn main() {
    let batcher = Arc::new(DynamicBatcher::new(32, 10));
    // Simulate concurrent clients.
    let handles: Vec<_> = (0..20u32)
        .map(|i| {
            let batcher = batcher.clone();
            tokio::spawn(async move {
                let input = vec![i as f32; 4];
                let result = batcher.infer(input).await;
                println!("Client {}: {:?}", i, &result[..2]);
            })
        })
        .collect();
    for h in handles {
        h.await.unwrap();
    }
}
```

Horizontal scaling pattern
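
Routing by a hash of the tenant ID keeps each tenant on one worker. Plain modulo routing, as in the simplified router in this section, reassigns most tenants whenever a worker joins or leaves; a consistent hash ring limits that churn to the keys owned by the affected worker. The following is a minimal std-only sketch of such a ring, not a production implementation. `HashRing`, `add_worker`, `remove_worker`, the `addr#i` virtual-node naming, and the choice of 64 virtual nodes are illustrative assumptions, and `DefaultHasher` is used for brevity even though its algorithm is not guaranteed to stay the same across Rust releases.

```rust
use std::collections::hash_map::DefaultHasher;
use std::collections::BTreeMap;
use std::hash::{Hash, Hasher};

fn hash_str(s: &str) -> u64 {
    let mut h = DefaultHasher::new();
    s.hash(&mut h);
    h.finish()
}

/// Minimal consistent hash ring with virtual nodes (illustrative only).
struct HashRing {
    ring: BTreeMap<u64, String>, // ring position -> worker address
    vnodes: usize,               // virtual nodes per worker (assumed tunable)
}

impl HashRing {
    fn new(vnodes: usize) -> Self {
        Self { ring: BTreeMap::new(), vnodes }
    }

    /// Places `vnodes` points for this worker on the ring.
    fn add_worker(&mut self, addr: &str) {
        for i in 0..self.vnodes {
            let pos = hash_str(&format!("{}#{}", addr, i));
            self.ring.insert(pos, addr.to_string());
        }
    }

    /// Removes every point that belongs to this worker.
    fn remove_worker(&mut self, addr: &str) {
        self.ring.retain(|_, w| w.as_str() != addr);
    }

    /// First worker at or after the key's position, wrapping to the start.
    fn route(&self, key: &str) -> Option<&str> {
        let pos = hash_str(key);
        self.ring
            .range(pos..)
            .next()
            .or_else(|| self.ring.iter().next())
            .map(|(_, w)| w.as_str())
    }
}

fn main() {
    let mut ring = HashRing::new(64);
    for w in ["worker-1:8080", "worker-2:8080", "worker-3:8080"] {
        ring.add_worker(w);
    }
    let tenants = ["acme", "globex", "initech", "umbrella", "hooli"];
    let before: Vec<String> = tenants
        .iter()
        .map(|t| ring.route(t).unwrap().to_string())
        .collect();

    // Only tenants that were on worker-2 should move.
    ring.remove_worker("worker-2:8080");
    for (t, old) in tenants.iter().zip(&before) {
        println!("{}: {} -> {}", t, old, ring.route(t).unwrap());
    }
}
```

Removing a worker moves only the tenants that were mapped to it; every other tenant keeps its assignment, which is the property the modulo router lacks.
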
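```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

/// Hash-based routing of requests to inference workers.
/// Simplified illustration: plain modulo hashing remaps most keys whenever the
/// pool size changes, so use a real consistent hash ring in production.
struct WorkerPool {
    workers: Vec<String>, // worker addresses
}

impl WorkerPool {
    /// Returns `None` for an empty pool instead of panicking on `% 0`.
    fn route(&self, request_hash: u64) -> Option<&str> {
        if self.workers.is_empty() {
            return None;
        }
        let idx = (request_hash % self.workers.len() as u64) as usize;
        Some(self.workers[idx].as_str())
    }
}

/// Hashes a tenant ID. `DefaultHasher`'s algorithm is unspecified and may change
/// between Rust releases; pick an explicit hash if routers must agree across builds.
fn hash_tenant(tenant_id: &str) -> u64 {
    let mut h = DefaultHasher::new();
    tenant_id.hash(&mut h);
    h.finish()
}

fn main() {
    let pool = WorkerPool {
        workers: vec![
            "worker-1:8080".to_string(),
            "worker-2:8080".to_string(),
            "worker-3:8080".to_string(),
        ],
    };
    for tenant in ["acme", "globex", "initech"] {
        let h = hash_tenant(tenant);
        match pool.route(h) {
            Some(worker) => println!("Tenant {} → {}", tenant, worker),
            None => println!("Tenant {}: no workers available", tenant),
        }
    }
}
```

Auto-scaling signal: queue depth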
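
An autoscaler needs a rule for turning a queue-depth reading into a replica count. The Kubernetes HPA documents the rule `desiredReplicas = ceil(currentReplicas * currentMetricValue / desiredMetricValue)`; when the metric is the average queue depth per replica, this reduces to the total queue depth divided by the per-replica target, rounded up. The sketch below shows that arithmetic only. `desired_replicas` and its parameters are hypothetical names, and the sketch leaves out the tolerance band and stabilization window that a real autoscaler applies.

```rust
/// Hypothetical helper: the number of replicas that keeps per-replica queue
/// depth at or below `target_per_replica`, clamped to `[min, max]`.
fn desired_replicas(
    total_queue_depth: usize,
    target_per_replica: usize,
    min: usize,
    max: usize,
) -> usize {
    assert!(target_per_replica > 0, "target must be positive");
    assert!(min <= max, "min must not exceed max");
    // Integer ceiling division: ceil(total / target).
    let raw = (total_queue_depth + target_per_replica - 1) / target_per_replica;
    raw.clamp(min, max)
}

fn main() {
    // Assumed settings: 50 queued requests per replica, 1 to 20 replicas.
    for depth in [0, 40, 150, 5_000] {
        println!(
            "queue depth {:>5} -> {} replicas",
            depth,
            desired_replicas(depth, 50, 1, 20)
        );
    }
}
```

With these assumed settings the loop yields 1, 1, 3, and 20 replicas: an empty queue is held at the minimum, and a depth of 5,000 is capped at the maximum.
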
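```rust
use std::sync::atomic::{AtomicUsize, Ordering};

/// Number of requests currently waiting for a worker.
static QUEUE_DEPTH: AtomicUsize = AtomicUsize::new(0);

/// Depth above which the service asks for more replicas.
const SCALE_OUT_THRESHOLD: usize = 100;

fn enqueue() {
    QUEUE_DEPTH.fetch_add(1, Ordering::Relaxed);
}

/// Must be paired with an earlier `enqueue`; an unpaired call wraps the counter.
fn dequeue() {
    QUEUE_DEPTH.fetch_sub(1, Ordering::Relaxed);
}

fn should_scale_out() -> bool {
    QUEUE_DEPTH.load(Ordering::Relaxed) > SCALE_OUT_THRESHOLD
}

fn main() {
    for _ in 0..150 {
        enqueue();
    }
    if should_scale_out() {
        println!(
            "Queue depth {}: trigger scale-out",
            QUEUE_DEPTH.load(Ordering::Relaxed)
        );
    }
    for _ in 0..150 {
        dequeue();
    }
}
```

Scaling checklist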
- [ ] Stateless workers — model weights loaded from shared storage at startup.
- [ ] Health check endpoint signals readiness (model loaded + warm).
- [ ] Queue depth metric exported to autoscaler (KEDA / HPA).
- [ ] Graceful draining on SIGTERM before pod termination.
- [ ] Session affinity disabled where workers are equivalent; keep hash-based routing only when per-tenant locality (for example, warm caches) justifies it.
- [ ] Model version pinned per worker — no mixed versions in flight.
- [ ] Circuit breaker to shed load when all workers are saturated.
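
The last checklist item can be sketched as a small state machine: closed while calls succeed, open (rejecting requests) after repeated failures, and half-open once a cooldown has elapsed so that probe traffic can test recovery. The version below is a minimal std-only illustration, not a hardened implementation. `CircuitBreaker`, `State`, the failure threshold, and the cooldown are hypothetical names and values, and the caller passes the current time in so the logic stays easy to test.

```rust
use std::time::{Duration, Instant};

#[derive(Debug, Clone, Copy, PartialEq)]
enum State {
    Closed,
    Open { since: Instant },
    HalfOpen,
}

struct CircuitBreaker {
    state: State,
    failures: u32,
    failure_threshold: u32, // consecutive failures before opening (assumed)
    cooldown: Duration,     // how long to shed load before probing (assumed)
}

impl CircuitBreaker {
    fn new(failure_threshold: u32, cooldown: Duration) -> Self {
        Self { state: State::Closed, failures: 0, failure_threshold, cooldown }
    }

    /// Returns true if a request may be forwarded to the workers at `now`.
    fn allow(&mut self, now: Instant) -> bool {
        if let State::Open { since } = self.state {
            if now.saturating_duration_since(since) >= self.cooldown {
                // Cooldown elapsed: admit probe traffic until the next
                // recorded outcome decides whether to close or reopen.
                self.state = State::HalfOpen;
            } else {
                return false; // still open: shed load
            }
        }
        true
    }

    fn record_success(&mut self) {
        self.failures = 0;
        self.state = State::Closed;
    }

    fn record_failure(&mut self, now: Instant) {
        self.failures += 1;
        // A failed probe reopens immediately; otherwise open at the threshold.
        if self.state == State::HalfOpen || self.failures >= self.failure_threshold {
            self.state = State::Open { since: now };
        }
    }
}

fn main() {
    let mut cb = CircuitBreaker::new(3, Duration::from_secs(5));
    let t0 = Instant::now();
    for _ in 0..3 {
        cb.record_failure(t0);
    }
    println!("right after tripping: allow = {}", cb.allow(t0));
    println!("after cooldown: allow = {}", cb.allow(t0 + Duration::from_secs(5)));
    cb.record_success();
    println!("state after a successful probe: {:?}", cb.state);
}
```

In a service, `allow` would run before a request is queued and the two `record_*` methods after the worker responds. Rejected requests typically receive an immediate "try again later" response rather than waiting in the queue.
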