Rust By Example

Rust AI Inference Maintainability

Writing maintainable AI inference code in Rust. Covers module organization, versioning, documentation patterns, configuration management, and keeping inference code testable and readable.

Topic: AI Inference

Search intent: high-intent query "rust ai inference maintainable code"

Rust AI Inference Maintainability

Module structure for inference projects

text
src/
├── main.rs              — Server startup, config, signal handling
├── config.rs            — All configuration in one place
├── model/
│   ├── mod.rs           — Model trait definition
│   ├── registry.rs      — Model loading and hot-reload
│   └── versions.rs      — Version management
├── inference/
│   ├── mod.rs           — Public API
│   ├── batcher.rs       — Dynamic batching logic
│   ├── pipeline.rs      — Preprocessing + inference + postprocessing
│   └── validator.rs     — Input/output validation
├── api/
│   ├── mod.rs
│   ├── handlers.rs      — HTTP request handlers
│   └── schemas.rs       — Request/response types
└── metrics.rs           — Prometheus metrics

Runnable example — configuration management

rust
use std::time::Duration;

/// All inference configuration in one struct — no magic numbers scattered around.
///
/// Build it with [`InferenceConfig::from_env`], then call
/// [`InferenceConfig::validate`] before using it: `from_env` never fails and
/// silently falls back to defaults, so `validate` is where bad values surface.
#[derive(Debug, Clone)]
struct InferenceConfig {
    /// Maximum number of concurrent inference requests (`MAX_CONCURRENT`, default 16)
    max_concurrent: usize,
    /// Maximum batch size for dynamic batching (`MAX_BATCH_SIZE`, default 32)
    max_batch_size: usize,
    /// How long to wait for a batch to fill before flushing (`BATCH_WAIT_MS`, default 10 ms)
    batch_wait_timeout: Duration,
    /// Per-request inference timeout (`INFERENCE_TIMEOUT_SECS`, default 5 s)
    inference_timeout: Duration,
    /// Maximum input tensor size in elements (`MAX_INPUT_LEN`, default 4096)
    max_input_len: usize,
    /// Expected input dimensionality (`INPUT_DIM`, default 768)
    expected_input_dim: usize,
    /// Model artifact path (`MODEL_PATH`, default `/models/default.onnx`)
    model_path: String,
    /// Enable warm-up inference at startup (`ENABLE_WARMUP`, default enabled)
    enable_warmup: bool,
    /// Number of warm-up iterations (`WARMUP_ITERATIONS`, default 10)
    warmup_iterations: usize,
}

impl InferenceConfig {
    /// Reads every setting from environment variables.
    ///
    /// A variable that is unset, not valid Unicode, or fails to parse falls
    /// back to the default listed on the corresponding field. Numeric values
    /// are trimmed before parsing. `ENABLE_WARMUP` counts as enabled for
    /// `1`, `true`, `yes` or `on` (case-insensitive), disabled for any other
    /// value, and enabled when unset.
    fn from_env() -> Self {
        Self {
            max_concurrent: Self::env_or("MAX_CONCURRENT", 16),
            max_batch_size: Self::env_or("MAX_BATCH_SIZE", 32),
            batch_wait_timeout: Duration::from_millis(Self::env_or("BATCH_WAIT_MS", 10)),
            inference_timeout: Duration::from_secs(Self::env_or("INFERENCE_TIMEOUT_SECS", 5)),
            max_input_len: Self::env_or("MAX_INPUT_LEN", 4096),
            expected_input_dim: Self::env_or("INPUT_DIM", 768),
            model_path: std::env::var("MODEL_PATH")
                .unwrap_or_else(|_| "/models/default.onnx".to_string()),
            enable_warmup: Self::parse_flag(std::env::var("ENABLE_WARMUP").ok().as_deref(), true),
            // Previously hard-coded; now configurable like every other knob.
            warmup_iterations: Self::env_or("WARMUP_ITERATIONS", 10),
        }
    }

    /// Reads `key` from the environment and parses it as `T`.
    ///
    /// Returns `default` when the variable is unset, not valid Unicode, or
    /// does not parse after trimming surrounding whitespace.
    fn env_or<T: std::str::FromStr>(key: &str, default: T) -> T {
        std::env::var(key)
            .ok()
            .and_then(|raw| raw.trim().parse().ok())
            .unwrap_or(default)
    }

    /// Interprets an optional boolean flag value.
    ///
    /// `None` (variable unset) yields `default`. Otherwise the value is
    /// trimmed and compared case-insensitively: `1`, `true`, `yes`, `on` are
    /// `true`; anything else is `false`.
    fn parse_flag(raw: Option<&str>, default: bool) -> bool {
        match raw {
            None => default,
            Some(value) => matches!(
                value.trim().to_ascii_lowercase().as_str(),
                "1" | "true" | "yes" | "on"
            ),
        }
    }

    /// Checks that the configuration is usable.
    ///
    /// # Errors
    /// Returns a message naming the first invalid field found.
    fn validate(&self) -> Result<(), String> {
        if self.max_concurrent == 0 {
            return Err("max_concurrent must be > 0".into());
        }
        if self.max_batch_size == 0 {
            return Err("max_batch_size must be > 0".into());
        }
        if self.expected_input_dim == 0 {
            return Err("expected_input_dim must be > 0".into());
        }
        if self.max_input_len == 0 {
            return Err("max_input_len must be > 0".into());
        }
        // An input with `expected_input_dim` elements could never fit under the
        // element limit, so every request would be rejected.
        if self.expected_input_dim > self.max_input_len {
            return Err("expected_input_dim must be <= max_input_len".into());
        }
        if self.inference_timeout.is_zero() {
            return Err("inference_timeout must be > 0".into());
        }
        if self.model_path.trim().is_empty() {
            return Err("model_path must not be empty".into());
        }
        if self.enable_warmup && self.warmup_iterations == 0 {
            return Err("warmup_iterations must be > 0 when warm-up is enabled".into());
        }
        Ok(())
    }
}

/// Loads the inference configuration from the environment and reports whether
/// it is usable.
///
/// Exits with status 1 when validation fails, so a caller checking the exit
/// code sees the failure instead of a successful run.
fn main() {
    let config = InferenceConfig::from_env();
    if let Err(e) = config.validate() {
        eprintln!("Invalid config: {}", e);
        std::process::exit(1);
    }
    println!("Config valid: {:?}", config);
}

Versioned model trait

rust
/// All models implement this trait for uniform handling.
///
/// The `Send + Sync` supertraits allow one loaded model to be shared across
/// threads.
trait InferenceModel: Send + Sync {
    /// Unique model identifier (e.g., "sentiment-bert-v2")
    fn model_id(&self) -> &str;

    /// Semantic version of model weights
    fn version(&self) -> &str;

    /// Expected input shape
    fn input_shape(&self) -> Vec<usize>;

    /// Run inference on an input that has already passed
    /// [`validate_input`](Self::validate_input).
    fn infer(&self, input: &[f32]) -> Result<Vec<f32>, String>;

    /// Number of `f32` elements one input must contain: the product of
    /// [`input_shape`](Self::input_shape) (1 for an empty, scalar shape).
    fn expected_input_len(&self) -> usize {
        self.input_shape().iter().product()
    }

    /// Checks that `input` has exactly
    /// [`expected_input_len`](Self::expected_input_len) elements before it is
    /// passed to [`infer`](Self::infer).
    ///
    /// # Errors
    /// Returns a message naming the model, the expected shape and the actual
    /// element count when they disagree.
    fn validate_input(&self, input: &[f32]) -> Result<(), String> {
        let expected = self.expected_input_len();
        if input.len() == expected {
            Ok(())
        } else {
            Err(format!(
                "{}: expected {} input elements for shape {:?}, got {}",
                self.model_id(),
                expected,
                self.input_shape(),
                input.len()
            ))
        }
    }

    /// Friendly description for /info endpoint
    fn description(&self) -> String {
        format!("{} ({})", self.model_id(), self.version())
    }
}

/// Stand-in embedding model: every input element is multiplied by one half.
///
/// It advertises a `[1, 512]` input shape but does not enforce it, so slices
/// of any length are accepted.
struct MockEmbeddingModel;

impl InferenceModel for MockEmbeddingModel {
    fn model_id(&self) -> &str {
        "text-embed-v3"
    }

    fn version(&self) -> &str {
        "3.1.0"
    }

    fn input_shape(&self) -> Vec<usize> {
        let (batch, width) = (1, 512);
        vec![batch, width]
    }

    fn infer(&self, input: &[f32]) -> Result<Vec<f32>, String> {
        let mut scaled = Vec::with_capacity(input.len());
        for &value in input {
            scaled.push(value * 0.5);
        }
        Ok(scaled)
    }
}

/// Demonstrates the trait: prints the model description, then runs the mock
/// model on a three-element sample and prints its output.
fn main() {
    let model = MockEmbeddingModel;
    let summary = model.description();
    println!("Model: {}", summary);

    let sample = [1.0_f32, 2.0, 3.0];
    let output = model.infer(&sample).unwrap();
    println!("Output: {:?}", output);
}

Documenting inference pipelines

rust
/// Text classification pipeline.
///
/// This is a *mock* pipeline: each stage has the shape of a real classifier
/// but uses toy arithmetic so the example runs without model weights.
///
/// # Pipeline stages
/// 1. **Tokenize**: one token per `char` (its Unicode scalar value), keeping
///    at most `max_tokens` tokens
/// 2. **Embed** (mock): scale each token ID to a single `f32` (`id / 1000`);
///    a real model would produce one vector per token instead
/// 3. **Pool**: mean over the token dimension (0.0 for empty input)
/// 4. **Classify** (mock): score class `i` as `|sin(pooled * (i + 1))|`
/// 5. **Softmax**: convert the scores into probabilities that sum to 1
///
/// # Invariants
/// `num_classes` must be at least 1 and `class_names` must hold at least
/// `num_classes` entries; [`predict`](Self::predict) panics otherwise.
///
/// # Example
/// ```ignore
/// // `ignore`: the type is private to this example binary, so rustdoc cannot
/// // compile it as a doctest.
/// let pipeline = TextClassificationPipeline {
///     max_tokens: 512,
///     num_classes: 2,
///     class_names: vec!["negative".into(), "positive".into()],
/// };
/// let result = pipeline.predict("This Rust code is amazing!");
/// assert!(pipeline.class_names.contains(&result.label));
/// assert!(result.confidence > 0.0 && result.confidence <= 1.0);
/// ```
struct TextClassificationPipeline {
    /// Maximum number of tokens kept from the input text
    max_tokens: usize,
    /// Number of classes scored by the classifier stage
    num_classes: usize,
    /// Human-readable label for each class index
    class_names: Vec<String>,
}

impl TextClassificationPipeline {
    /// Runs every pipeline stage on `text` and returns the most probable class.
    ///
    /// Ties between equally probable classes go to the later class index.
    ///
    /// # Panics
    /// Panics if `num_classes` is 0 or `class_names` has fewer than
    /// `num_classes` entries.
    fn predict(&self, text: &str) -> ClassificationResult {
        assert!(self.num_classes > 0, "num_classes must be > 0");
        assert!(
            self.class_names.len() >= self.num_classes,
            "class_names has {} entries but num_classes is {}",
            self.class_names.len(),
            self.num_classes
        );

        // Stage 1: Tokenize
        let tokens = self.tokenize(text);
        // Stage 2: Mock embedding — one scalar per token
        let embedding: Vec<f32> = tokens.iter().map(|&t| t as f32 / 1000.0).collect();
        // Stage 3: Mean-pool. Empty input pools to 0.0 instead of 0/0 = NaN,
        // which would poison every score and break the arg-max below.
        let pooled = if embedding.is_empty() {
            0.0
        } else {
            embedding.iter().sum::<f32>() / embedding.len() as f32
        };
        // Stage 4: Mock classifier scores
        let scores: Vec<f32> = (0..self.num_classes)
            .map(|i| (pooled * (i + 1) as f32).sin().abs())
            .collect();
        // Stage 5: Softmax, so `confidence` is a real probability
        let probabilities = Self::softmax(&scores);
        // `total_cmp` is a total order, so the comparison can never fail.
        let class = probabilities
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.total_cmp(b))
            .map(|(i, _)| i)
            .unwrap_or(0);
        ClassificationResult {
            label: self.class_names[class].clone(),
            confidence: probabilities[class],
        }
    }

    /// Splits `text` into one token per `char` (its code point), truncated to
    /// `max_tokens` tokens.
    fn tokenize(&self, text: &str) -> Vec<u32> {
        text.chars().take(self.max_tokens).map(|c| c as u32).collect()
    }

    /// Numerically stable softmax: subtracts the maximum before exponentiating
    /// so large scores cannot overflow. Returns an empty vector for empty input.
    fn softmax(scores: &[f32]) -> Vec<f32> {
        let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = scores.iter().map(|&s| (s - max).exp()).collect();
        let total: f32 = exps.iter().sum();
        exps.iter().map(|&e| e / total).collect()
    }
}

/// Result of [`TextClassificationPipeline::predict`].
#[derive(Debug, Clone, PartialEq)]
struct ClassificationResult {
    /// Name of the winning class
    label: String,
    /// Softmax probability of the winning class, in `(0, 1]`
    confidence: f32,
}

fn main() {
    let pipeline = TextClassificationPipeline {
        max_tokens: 512, num_classes: 2,
        class_names: vec!["negative".into(), "positive".into()],
    };
    let result = pipeline.predict("Rust is great for AI!");
    println!("{} ({:.1}%)", result.label, result.confidence * 100.0);
}

Related reading

Related Guides

Continue in This Topic

More Rust Guides