Rust By Example

Rust AI Inference Testing Strategy

Comprehensive testing strategies for AI inference systems in Rust. Covers unit testing model logic, integration testing APIs, regression testing outputs, and load testing inference throughput.

Topic: AI Inference

Search intent: High-intent search: "rust ai inference testing"

Rust AI Inference Testing Strategy

Testing pyramid for AI inference

text
                    ┌────────────────────┐
                    │  Load tests (few)  │ - throughput, p99, scaling
                    ├────────────────────┤
                    │  Integration tests │ - API, batching, timeouts
                    ├────────────────────┤
                    │  Unit tests (many) │ - preprocessing, postprocessing,
                    │                    │   numerical correctness
                    └────────────────────┘

Unit tests: numerical correctness

rust
// Model pre/post-processing functions.

/// Standardizes `input` in place to zero mean and unit (population) variance.
///
/// The standard deviation is clamped to at least `1e-8`, so constant input maps to all
/// zeros instead of dividing by zero. Empty input is left untouched.
///
/// Takes `&mut [f32]` rather than `&mut Vec<f32>`: callers holding a `Vec` still pass
/// `&mut v` (deref coercion), and arrays or sub-slices can be normalized too.
fn normalize_input(input: &mut [f32]) {
    if input.is_empty() {
        // Nothing to normalize; also avoids computing 0.0 / 0.0 for the mean.
        return;
    }
    let n = input.len() as f32;
    let mean = input.iter().sum::<f32>() / n;
    let var = input.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / n;
    // Clamp so zero-variance input does not produce inf/NaN.
    let std = var.sqrt().max(1e-8);
    for v in input.iter_mut() {
        *v = (*v - mean) / std;
    }
}

/// Converts raw logits into a probability distribution that sums to 1.
///
/// The largest logit is subtracted before exponentiating so `exp` cannot overflow on
/// large inputs; mathematically the result is unchanged. Empty input yields an empty
/// `Vec`. If every logit is `-inf`, or any logit is `+inf`, the shift produces `NaN` and
/// the output is all `NaN`.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    // Unnormalized weights, rescaled in place below.
    let mut probs: Vec<f32> = logits.iter().map(|&z| (z - peak).exp()).collect();
    let total: f32 = probs.iter().sum();
    for p in &mut probs {
        *p /= total;
    }
    probs
}

/// Returns the index of the largest value in `probs`.
///
/// `NaN` entries are skipped rather than causing a panic (a `partial_cmp(..).unwrap()`
/// comparator aborts on the first `NaN`, which a misbehaving model can easily emit).
/// On ties the last maximal index wins, as with [`Iterator::max_by`]. Returns `0` when
/// `probs` is empty or contains only `NaN`s.
fn argmax(probs: &[f32]) -> usize {
    probs
        .iter()
        .enumerate()
        .filter(|&(_, p)| !p.is_nan())
        // With NaNs removed, `total_cmp` agrees with `partial_cmp` except that it orders
        // -0.0 below +0.0.
        .max_by(|&(_, a), &(_, b)| a.total_cmp(b))
        .map(|(i, _)| i)
        .unwrap_or(0)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normalize_zero_mean() {
        let mut input = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        normalize_input(&mut input);
        let mean: f32 = input.iter().sum::<f32>() / input.len() as f32;
        assert!(mean.abs() < 1e-5, "Mean should be ~0, got {}", mean);
    }

    #[test]
    fn test_normalize_unit_variance() {
        let mut input = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        normalize_input(&mut input);
        let mean: f32 = input.iter().sum::<f32>() / input.len() as f32;
        let var: f32 = input.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / input.len() as f32;
        assert!((var - 1.0).abs() < 1e-4, "Variance should be ~1, got {}", var);
    }

    #[test]
    fn test_softmax_sums_to_one() {
        let logits = vec![1.0f32, 2.0, 3.0, 4.0];
        let probs = softmax(&logits);
        let sum: f32 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6, "Softmax sum: {}", sum);
    }

    #[test]
    fn test_softmax_max_class() {
        let logits = vec![0.1f32, 0.2, 5.0, 0.3]; // class 2 is dominant
        let probs = softmax(&logits);
        assert_eq!(argmax(&probs), 2);
    }

    #[test]
    fn test_softmax_handles_large_values() {
        let logits = vec![1000.0f32, 1001.0, 1002.0];
        let probs = softmax(&logits);
        assert!(probs.iter().all(|p| p.is_finite()), "Should not overflow");
        assert!((probs.iter().sum::<f32>() - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_normalize_constant_input() {
        // All values are equal, so the variance is 0. The clamped std must prevent a
        // divide-by-zero, and every element should come out as exactly 0.
        let mut input = vec![5.0f32; 10];
        normalize_input(&mut input);
        assert!(input.iter().all(|v| v.is_finite()));
        assert!(
            input.iter().all(|&v| v == 0.0),
            "Constant input should normalize to zeros, got {:?}",
            input
        );
    }

    #[test]
    fn test_empty_inputs() {
        // Degenerate (empty) inputs must not panic.
        let mut input: Vec<f32> = Vec::new();
        normalize_input(&mut input);
        assert!(input.is_empty());
        assert!(softmax(&[]).is_empty());
        assert_eq!(argmax(&[]), 0);
    }
}

/// Demo pipeline: normalize raw features, turn them into class probabilities, and
/// report the most likely class.
fn main() {
    let mut features = vec![2.0f32, 4.0, 6.0, 8.0];
    normalize_input(&mut features);
    println!("Normalized: {:?}", &features);

    let probs = softmax(&features);
    println!("Probs: {:?}", probs);

    let predicted = argmax(&probs);
    println!("Predicted class: {}", predicted);
}

Integration tests: API and batching

rust
#[cfg(test)]
mod integration_tests {
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::sync::Semaphore;

    /// Maximum number of inferences the fake engine runs at once.
    const MAX_CONCURRENT: usize = 4;
    /// How long a request waits for a free slot before giving up.
    const ACQUIRE_TIMEOUT: Duration = Duration::from_secs(1);

    /// Stand-in for a real inference backend. It bounds concurrency with a semaphore
    /// and "infers" by doubling every input element.
    struct FakeInferenceEngine {
        limit: Arc<Semaphore>,
    }

    impl FakeInferenceEngine {
        fn new() -> Self {
            let limit = Arc::new(Semaphore::new(MAX_CONCURRENT));
            Self { limit }
        }

        /// Waits up to `ACQUIRE_TIMEOUT` for a free slot, then returns `input * 2`.
        ///
        /// Returns `Err("timeout")` if no slot frees up in time, or `Err("closed")` if
        /// the semaphore has been closed.
        async fn infer(&self, input: Vec<f32>) -> Result<Vec<f32>, &'static str> {
            let acquired = tokio::time::timeout(ACQUIRE_TIMEOUT, self.limit.acquire()).await;
            // The permit is held until the end of this function.
            let _permit = match acquired {
                Err(_elapsed) => return Err("timeout"),
                Ok(Err(_closed)) => return Err("closed"),
                Ok(Ok(permit)) => permit,
            };

            Ok(input.into_iter().map(|x| x * 2.0).collect())
        }
    }

    #[tokio::test]
    async fn test_concurrent_inference() {
        let engine = Arc::new(FakeInferenceEngine::new());

        // Spawn every request up front so they all contend for the semaphore at once.
        let mut handles = Vec::new();
        for i in 0..20 {
            let worker = Arc::clone(&engine);
            handles.push(tokio::spawn(async move {
                worker.infer(vec![i as f32; 4]).await
            }));
        }

        let mut successes = 0;
        for joined in futures::future::join_all(handles).await {
            // A panicked task surfaces as a JoinError and fails the test here.
            if joined.unwrap().is_ok() {
                successes += 1;
            }
        }
        assert_eq!(successes, 20, "All 20 requests should succeed");
    }

    #[tokio::test]
    async fn test_output_shape() {
        let engine = FakeInferenceEngine::new();
        let input = vec![1.0f32; 8];
        let expected_len = input.len();
        let output = engine.infer(input).await.unwrap();
        assert_eq!(output.len(), expected_len, "Output shape should match input");
    }

    #[tokio::test]
    async fn test_output_correctness() {
        let output = FakeInferenceEngine::new()
            .infer(vec![1.0f32, 2.0, 3.0])
            .await
            .unwrap();
        assert_eq!(output, vec![2.0f32, 4.0, 6.0]);
    }
}

Regression tests: output stability

rust
// Regression strategy: record "golden" outputs from a known-good model version and
// compare against them on every run to catch unintended model changes.

/// Returns `true` when `a` and `b` have the same length and each pair of elements
/// differs by less than `tolerance`. The tolerance is absolute, not relative.
///
/// Elements that compare exactly equal always match. As a result, identical infinities
/// agree, and a `tolerance` of `0.0` means "exactly equal" rather than "never equal".
/// `NaN` never matches anything, including another `NaN`.
fn approx_eq(a: &[f32], b: &[f32], tolerance: f32) -> bool {
    a.len() == b.len()
        && a
            .iter()
            .zip(b)
            // `x == y` covers matching infinities, where `x - y` would be NaN.
            .all(|(&x, &y)| x == y || (x - y).abs() < tolerance)
}

#[cfg(test)]
mod regression {
    use super::*;

    // Golden output recorded from a known-good model version
    const GOLDEN_OUTPUT: &[f32] = &[0.2, 0.4, 0.6, 0.8];

    /// Stand-in for the model under test: doubles every element.
    fn run_model(input: &[f32]) -> Vec<f32> {
        input.iter().map(|&x| 2.0 * x).collect()
    }

    #[test]
    fn test_output_regression() {
        let output = run_model(&[0.1f32, 0.2, 0.3, 0.4]);
        let matches_golden = approx_eq(&output, GOLDEN_OUTPUT, 1e-4);
        assert!(
            matches_golden,
            "Output changed: {:?} vs golden {:?}",
            output,
            GOLDEN_OUTPUT
        );
    }
}

Related reading

Related Guides

Continue in This Topic

More Rust Guides