AI Inference Batch Processing
Dynamic batching for AI inference in Rust using tokio channels and semaphores for high-throughput model serving.
AI Inference Batch Processing
Dynamic batching collects multiple inference requests into a single batch to maximize GPU/CPU utilization.
Difficulty
Advanced
Code
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, oneshot, Semaphore};
/// One queued inference call: the input vector, plus the one-shot channel on which
/// the batcher returns that input's output.
type Request = (Vec<f32>, oneshot::Sender<Vec<f32>>);

/// Dynamic batcher: groups concurrent `infer` calls into batches of up to `max_size`
/// requests, waiting at most `max_wait_ms` after a batch's first request for more to arrive.
///
/// The struct itself is only the submission handle; the batching is performed by a
/// background task started in `new`.
struct DynamicBatcher {
    /// Sending half of the bounded request queue read by the background task.
    tx: mpsc::Sender<Request>,
}
impl DynamicBatcher {
fn new(max_size: usize, max_wait_ms: u64, max_concurrent: usize) -> Self {
let (tx, mut rx) = mpsc::channel::<Request>(512);
let sem = Arc::new(Semaphore::new(max_concurrent));
tokio::spawn(async move {
loop {
let mut batch: Vec<Request> = Vec::with_capacity(max_size);
// Wait for first request
match rx.recv().await {
Some(r) => batch.push(r),
None => break,
}
// Fill batch within time window
let wait = Duration::from_millis(max_wait_ms);
let deadline = tokio::time::Instant::now() + wait;
while batch.len() < max_size {
let rem = deadline.saturating_duration_since(tokio::time::Instant::now());
if rem.is_zero() { break; }
match tokio::time::timeout(rem, rx.recv()).await {
Ok(Some(r)) => batch.push(r),
_ => break,
}
}
// Execute batch with concurrency limit
let sem = sem.clone();
tokio::spawn(async move {
let _permit = sem.acquire().await.unwrap();
let (inputs, senders): (Vec<_>, Vec<_>) = batch.into_iter().unzip();
// Simulate batch inference
let results = run_batch(inputs);
for (sender, result) in senders.into_iter().zip(results) {
let _ = sender.send(result);
}
});
}
});
Self { tx }
}
async fn infer(&self, input: Vec<f32>) -> Vec<f32> {
let (resp_tx, resp_rx) = oneshot::channel();
self.tx.send((input, resp_tx)).await.unwrap();
resp_rx.await.unwrap()
}
}
/// Stand-in for a real model: maps every feature `x` to `x * 2.0 + 0.1`.
///
/// Returns one output vector per input vector, in the same order and with the same
/// length as its input. The input buffers are reused for the outputs.
fn run_batch(inputs: Vec<Vec<f32>>) -> Vec<Vec<f32>> {
    let mut outputs = inputs;
    for row in &mut outputs {
        for x in row.iter_mut() {
            *x = *x * 2.0 + 0.1;
        }
    }
    outputs
}
#[tokio::main]
async fn main() {
    // Batches of at most 16 requests, a 5 ms fill window, up to 4 batches in flight.
    let batcher = Arc::new(DynamicBatcher::new(16, 5, 4));

    // Launch 20 concurrent clients; each submits a 3-element input filled with its id.
    let mut handles = Vec::with_capacity(20);
    for i in 0..20u32 {
        let client = Arc::clone(&batcher);
        handles.push(tokio::spawn(async move {
            let input = vec![i as f32; 3];
            let result = client.infer(input.clone()).await;
            println!("Client {}: {:?} → {:?}", i, &input, &result);
        }));
    }

    // Wait for every client; a panic inside any client task is re-raised here.
    for handle in handles {
        handle.await.unwrap();
    }
}
Explanation
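Dynamic batching uses a background task that waits for a first request and then keeps collecting further requests until either `max_size` is reached or `max_wait_ms` has elapsed. It then runs the whole batch through a single `run_batch` call and sends each result back to the caller that submitted it.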
Key Concepts
- `oneshot::channel` pairs each request with its response
- `mpsc::channel` with bounded capacity applies backpressure
- The batch fills opportunistically within the time window
- `Semaphore` limits concurrent batch executions
Related Topics
Browse more examples in the ai-inference category for production patterns.