Rust By Example
advanced

AI Inference Batch Processing

Dynamic batching for AI inference in Rust using tokio channels and semaphores for high-throughput model serving.

AI Inference Batch Processing

Dynamic batching collects multiple inference requests into a single batch to maximize GPU/CPU utilization.

Difficulty

Advanced

Code

rust
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, oneshot, Semaphore};

/// One queued inference call: the input vector paired with the one-shot
/// sender the batcher uses to deliver that input's output back to the caller.
type Request = (Vec<f32>, oneshot::Sender<Vec<f32>>);

/// Handle to a dynamic batcher.
///
/// After the first request of a batch arrives, the background task waits up to
/// `max_wait_ms` for more, dispatching once `max_size` requests are collected or
/// the window closes (see [`DynamicBatcher::new`]).
///
/// Cloning is cheap (it clones the channel sender). Every clone feeds the same
/// background task, which keeps running until all clones have been dropped.
#[derive(Clone)]
struct DynamicBatcher {
    /// Submission side of the request queue drained by the background task.
    tx: mpsc::Sender<Request>,
}

impl DynamicBatcher {
    fn new(max_size: usize, max_wait_ms: u64, max_concurrent: usize) -> Self {
        let (tx, mut rx) = mpsc::channel::<Request>(512);
        let sem = Arc::new(Semaphore::new(max_concurrent));

        tokio::spawn(async move {
            loop {
                let mut batch: Vec<Request> = Vec::with_capacity(max_size);

                // Wait for first request
                match rx.recv().await {
                    Some(r) => batch.push(r),
                    None => break,
                }

                // Fill batch within time window
                let wait = Duration::from_millis(max_wait_ms);
                let deadline = tokio::time::Instant::now() + wait;
                while batch.len() < max_size {
                    let rem = deadline.saturating_duration_since(tokio::time::Instant::now());
                    if rem.is_zero() { break; }
                    match tokio::time::timeout(rem, rx.recv()).await {
                        Ok(Some(r)) => batch.push(r),
                        _ => break,
                    }
                }

                // Execute batch with concurrency limit
                let sem = sem.clone();
                tokio::spawn(async move {
                    let _permit = sem.acquire().await.unwrap();
                    let (inputs, senders): (Vec<_>, Vec<_>) = batch.into_iter().unzip();

                    // Simulate batch inference
                    let results = run_batch(inputs);

                    for (sender, result) in senders.into_iter().zip(results) {
                        let _ = sender.send(result);
                    }
                });
            }
        });

        Self { tx }
    }

    async fn infer(&self, input: Vec<f32>) -> Vec<f32> {
        let (resp_tx, resp_rx) = oneshot::channel();
        self.tx.send((input, resp_tx)).await.unwrap();
        resp_rx.await.unwrap()
    }
}

/// Stand-in for a real model: maps every element `x` of each input row to `2x + 0.1`.
///
/// Produces exactly one output row per input row, in input order, so callers can
/// pair outputs with requests by position.
fn run_batch(inputs: Vec<Vec<f32>>) -> Vec<Vec<f32>> {
    let mut outputs = Vec::with_capacity(inputs.len());
    for row in inputs {
        let transformed: Vec<f32> = row.into_iter().map(|x| 2.0 * x + 0.1).collect();
        outputs.push(transformed);
    }
    outputs
}

#[tokio::main]
async fn main() {
    let batcher = Arc::new(DynamicBatcher::new(16, 5, 4));

    // Simulate 20 concurrent clients
    let handles: Vec<_> = (0..20u32).map(|i| {
        let b = batcher.clone();
        tokio::spawn(async move {
            let input = vec![i as f32; 3];
            let result = b.infer(input.clone()).await;
            println!("Client {}: {:?} → {:?}", i, &input, &result);
        })
    }).collect();

    for h in handles { h.await.unwrap(); }
}

Explanation

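Dynamic batching uses a background task that waits for a first request, then keeps collecting requests until either `max_size` of them are in the batch or `max_wait_ms` has elapsed since that first request arrived. It then processes the whole batch together and sends each caller its own result.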

Key Concepts

  • oneshot::channel pairs each request with its response
  • mpsc::channel with bounded capacity applies backpressure
  • Batch fills opportunistically within the time window
  • Semaphore limits concurrent batch executions

Related Topics

Browse more examples in the ai-inference category for production patterns.

More ai-inference Examples