Rust AI Inference Testing Strategy
Comprehensive testing strategies for AI inference systems in Rust. Covers unit testing model logic, integration testing APIs, regression testing outputs, and load testing inference throughput.
Topic: AI Inference
Search intent: High-intent search: "rust ai inference testing"
Rust AI Inference Testing Strategy
Testing pyramid for AI inference
┌─────────────────────┐
│ Load tests (few)    │ - throughput, p99, scaling
├─────────────────────┤
│ Integration tests   │ - API, batching, timeouts
├─────────────────────┤
│ Unit tests (many)   │ - preprocessing, postprocessing,
│                     │   numerical correctness
└─────────────────────┘
Unit tests: numerical correctness
// Model pre/post-processing functions exercised by the unit tests.

/// Z-score normalizes `input` in place: subtracts the mean and divides by the
/// population standard deviation.
///
/// The mean and variance are accumulated in `f64`, so long or large-magnitude
/// inputs do not lose precision in the running sums. Results are written back
/// as `f32`. The standard deviation is floored at `1e-8`, so exactly constant
/// input (zero variance) normalizes to zeros instead of NaN/inf. Empty input is
/// left untouched.
///
/// Accepts any mutable `f32` slice. `&mut Vec<f32>` callers still work via deref
/// coercion.
fn normalize_input(input: &mut [f32]) {
    if input.is_empty() {
        return;
    }
    let n = input.len() as f64;
    let mean = input.iter().map(|&x| f64::from(x)).sum::<f64>() / n;
    let var = input
        .iter()
        .map(|&x| (f64::from(x) - mean).powi(2))
        .sum::<f64>()
        / n;
    // Floor keeps the division finite when every element is identical.
    let std = var.sqrt().max(1e-8);
    for v in input.iter_mut() {
        *v = ((f64::from(*v) - mean) / std) as f32;
    }
}
/// Converts raw logits into a probability distribution.
///
/// The largest logit is subtracted from every entry before exponentiating, so
/// very large logits cannot overflow `exp`. An empty slice yields an empty `Vec`.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut weights: Vec<f32> = logits.iter().map(|&z| (z - peak).exp()).collect();
    let total: f32 = weights.iter().sum();
    for w in weights.iter_mut() {
        *w /= total;
    }
    weights
}
/// Returns the index of the largest non-NaN value in `probs`.
///
/// NaN entries are skipped, so a NaN cannot cause a panic. The previous
/// `partial_cmp(..).unwrap()` comparator panicked on NaN, which a softmax over an
/// all `-inf` row produces. Ties resolve to the *last* maximal index, matching
/// `Iterator::max_by`. Returns `0` for empty input or when every entry is NaN.
/// That result is indistinguishable from a genuine index 0, so callers that care
/// must check for these cases themselves.
fn argmax(probs: &[f32]) -> usize {
    probs
        .iter()
        .enumerate()
        .filter(|(_, p)| !p.is_nan())
        .max_by(|(_, a), (_, b)| a.total_cmp(b))
        .map(|(i, _)| i)
        .unwrap_or(0)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normalize_zero_mean() {
        let mut input = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        normalize_input(&mut input);
        let mean: f32 = input.iter().sum::<f32>() / input.len() as f32;
        assert!(mean.abs() < 1e-5, "Mean should be ~0, got {}", mean);
    }

    #[test]
    fn test_normalize_unit_variance() {
        let mut input = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        normalize_input(&mut input);
        let mean: f32 = input.iter().sum::<f32>() / input.len() as f32;
        let var: f32 = input.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / input.len() as f32;
        assert!((var - 1.0).abs() < 1e-4, "Variance should be ~1, got {}", var);
    }

    #[test]
    fn test_normalize_constant_input() {
        // Zero variance: the std floor must keep the output finite. Every
        // element equals the mean, so each normalized value must be exactly 0.
        // f32 division never panics, so "did not panic" alone proves nothing.
        let mut input = vec![5.0f32; 10];
        normalize_input(&mut input);
        assert!(input.iter().all(|v| v.is_finite()));
        assert!(
            input.iter().all(|&v| v == 0.0),
            "Constant input should normalize to 0, got {:?}",
            input
        );
    }

    #[test]
    fn test_normalize_empty_input() {
        // Empty input must be left untouched.
        let mut input: Vec<f32> = Vec::new();
        normalize_input(&mut input);
        assert!(input.is_empty());
    }

    #[test]
    fn test_softmax_sums_to_one() {
        let logits = vec![1.0f32, 2.0, 3.0, 4.0];
        let probs = softmax(&logits);
        let sum: f32 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6, "Softmax sum: {}", sum);
    }

    #[test]
    fn test_softmax_max_class() {
        let logits = vec![0.1f32, 0.2, 5.0, 0.3]; // class 2 is dominant
        let probs = softmax(&logits);
        assert_eq!(argmax(&probs), 2);
    }

    #[test]
    fn test_softmax_handles_large_values() {
        let logits = vec![1000.0f32, 1001.0, 1002.0];
        let probs = softmax(&logits);
        assert!(probs.iter().all(|p| p.is_finite()), "Should not overflow");
        assert!((probs.iter().sum::<f32>() - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_softmax_empty_input() {
        assert!(softmax(&[]).is_empty());
    }
}
/// Demo entry point: normalizes a small sample, turns it into a probability
/// distribution with `softmax`, and prints the most likely class.
fn main() {
    let mut sample = vec![2.0f32, 4.0, 6.0, 8.0];
    normalize_input(&mut sample);

    let distribution = softmax(&sample);
    println!("Normalized: {:?}", &sample);
    println!("Probs: {:?}", distribution);

    let winner = argmax(&distribution);
    println!("Predicted class: {}", winner);
}
Integration tests: API and batching
#[cfg(test)]
mod integration_tests {
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::sync::Semaphore;

    /// Number of `infer` calls that may hold a permit at the same time.
    const MAX_IN_FLIGHT: usize = 4;
    /// How long `infer` waits for a permit before reporting `"timeout"`.
    const ACQUIRE_TIMEOUT: Duration = Duration::from_secs(1);

    /// Weight-free stand-in for a model server. A semaphore gates admission,
    /// and the "model" simply doubles every input element.
    struct FakeInferenceEngine {
        limit: Arc<Semaphore>,
    }

    impl FakeInferenceEngine {
        fn new() -> Self {
            let limit = Arc::new(Semaphore::new(MAX_IN_FLIGHT));
            Self { limit }
        }

        /// Acquires a permit, waiting at most `ACQUIRE_TIMEOUT`, and returns
        /// `input` with every element multiplied by 2.
        ///
        /// # Errors
        /// Returns `"timeout"` if no permit became available in time, or
        /// `"closed"` if the semaphore has been closed.
        async fn infer(&self, input: Vec<f32>) -> Result<Vec<f32>, &'static str> {
            let _permit = match tokio::time::timeout(ACQUIRE_TIMEOUT, self.limit.acquire()).await {
                Err(_) => return Err("timeout"),
                Ok(Err(_)) => return Err("closed"),
                Ok(Ok(permit)) => permit,
            };
            // The permit is released when `_permit` drops at the end of this function.
            Ok(input.into_iter().map(|x| x * 2.0).collect())
        }
    }

    #[tokio::test]
    async fn test_concurrent_inference() {
        let engine = Arc::new(FakeInferenceEngine::new());
        let mut handles = Vec::new();
        for i in 0..20 {
            let worker = Arc::clone(&engine);
            handles.push(tokio::spawn(async move { worker.infer(vec![i as f32; 4]).await }));
        }
        let results = futures::future::join_all(handles).await;
        // A JoinError (panicked task) unwraps loudly; otherwise count Ok results.
        let successes = results
            .iter()
            .filter(|joined| joined.as_ref().unwrap().is_ok())
            .count();
        assert_eq!(successes, 20, "All 20 requests should succeed");
    }

    #[tokio::test]
    async fn test_output_shape() {
        let engine = FakeInferenceEngine::new();
        let input = vec![1.0f32; 8];
        let expected_len = input.len();
        let output = engine.infer(input).await.unwrap();
        assert_eq!(output.len(), expected_len, "Output shape should match input");
    }

    #[tokio::test]
    async fn test_output_correctness() {
        let output = FakeInferenceEngine::new()
            .infer(vec![1.0f32, 2.0, 3.0])
            .await
            .unwrap();
        assert_eq!(output, vec![2.0f32, 4.0, 6.0]);
    }
}
Regression tests: output stability
// Golden-output regression support: save outputs from a known-good model and
// compare against them on each run to catch unintended model changes.

/// Returns `true` when `a` and `b` have the same length and every pair of
/// elements differs by at most `tolerance` (absolute difference).
///
/// Identical values always match, so a `tolerance` of `0.0` means exact
/// equality and equal infinities compare equal. Any NaN makes the comparison
/// fail, which is the desired outcome for a regression check.
fn approx_eq(a: &[f32], b: &[f32], tolerance: f32) -> bool {
    a.len() == b.len()
        && a
            .iter()
            .zip(b)
            .all(|(&x, &y)| x == y || (x - y).abs() <= tolerance)
}
#[cfg(test)]
mod regression {
    use super::*;

    /// Reference output captured from a known-good model version. Drift beyond
    /// the test's tolerance is treated as a regression.
    const GOLDEN_OUTPUT: &[f32] = &[0.2, 0.4, 0.6, 0.8];

    #[test]
    fn test_output_regression() {
        // Stand-in for the model under test: scales each element by 2.
        let fake_model = |xs: &[f32]| -> Vec<f32> { xs.iter().map(|x| x * 2.0).collect() };
        let input = [0.1f32, 0.2, 0.3, 0.4];
        let output = fake_model(&input[..]);
        let matches_golden = approx_eq(&output, GOLDEN_OUTPUT, 1e-4);
        assert!(
            matches_golden,
            "Output changed: {:?} vs golden {:?}",
            output,
            GOLDEN_OUTPUT
        );
    }
}