Rust AI Inference Maintainability
Writing maintainable AI inference code in Rust. Covers module organization, versioning, documentation patterns, configuration management, and keeping inference code testable and readable.
Topic: AI Inference
Search intent: High-intent search: "rust ai inference maintainable code"
Rust AI Inference Maintainability
Module structure for inference projects
src/
├── main.rs — Server startup, config, signal handling
├── config.rs — All configuration in one place
├── model/
│ ├── mod.rs — Model trait definition
│ ├── registry.rs — Model loading and hot-reload
│ └── versions.rs — Version management
├── inference/
│ ├── mod.rs — Public API
│ ├── batcher.rs — Dynamic batching logic
│ ├── pipeline.rs — Preprocessing + inference + postprocessing
│ └── validator.rs — Input/output validation
├── api/
│ ├── mod.rs
│ ├── handlers.rs — HTTP request handlers
│ └── schemas.rs — Request/response types
└── metrics.rs — Prometheus metrics
Runnable example — configuration management
use std::time::Duration;
/// All inference configuration in one struct — no magic numbers scattered around.
///
/// Build it with [`InferenceConfig::from_env`] and check it with
/// [`InferenceConfig::validate`] before handing it to the rest of the service.
#[derive(Debug, Clone)]
struct InferenceConfig {
    /// Maximum number of concurrent inference requests (`MAX_CONCURRENT`, default 16)
    max_concurrent: usize,
    /// Maximum batch size for dynamic batching (`MAX_BATCH_SIZE`, default 32)
    max_batch_size: usize,
    /// How long to wait for a batch to fill before flushing (`BATCH_WAIT_MS`, default 10 ms)
    batch_wait_timeout: Duration,
    /// Per-request inference timeout (`INFERENCE_TIMEOUT_SECS`, default 5 s)
    inference_timeout: Duration,
    /// Maximum input tensor size in elements (`MAX_INPUT_LEN`, default 4096)
    max_input_len: usize,
    /// Expected input dimensionality (`INPUT_DIM`, default 768)
    expected_input_dim: usize,
    /// Model artifact path (`MODEL_PATH`, default `/models/default.onnx`)
    model_path: String,
    /// Enable warm-up inference at startup (`ENABLE_WARMUP`, default true)
    enable_warmup: bool,
    /// Number of warm-up iterations (`WARMUP_ITERATIONS`, default 10)
    warmup_iterations: usize,
}

impl InferenceConfig {
    /// Builds the configuration from environment variables.
    ///
    /// Each setting falls back to its documented default when the variable is
    /// unset. A variable that is set but cannot be parsed also falls back to the
    /// default, and a warning is written to stderr so the typo is not silently
    /// hidden. This function never fails; call [`InferenceConfig::validate`] to
    /// reject semantically invalid values.
    fn from_env() -> Self {
        Self {
            max_concurrent: Self::env_or("MAX_CONCURRENT", 16),
            max_batch_size: Self::env_or("MAX_BATCH_SIZE", 32),
            batch_wait_timeout: Duration::from_millis(Self::env_or("BATCH_WAIT_MS", 10)),
            inference_timeout: Duration::from_secs(Self::env_or("INFERENCE_TIMEOUT_SECS", 5)),
            max_input_len: Self::env_or("MAX_INPUT_LEN", 4096),
            expected_input_dim: Self::env_or("INPUT_DIM", 768),
            model_path: std::env::var("MODEL_PATH")
                .unwrap_or_else(|_| "/models/default.onnx".to_string()),
            // Accept "1" and any capitalisation of "true"; every other value disables warm-up.
            enable_warmup: std::env::var("ENABLE_WARMUP")
                .map(|raw| {
                    let flag = raw.trim();
                    flag == "1" || flag.eq_ignore_ascii_case("true")
                })
                .unwrap_or(true),
            warmup_iterations: Self::env_or("WARMUP_ITERATIONS", 10),
        }
    }

    /// Reads `key` from the environment and parses it, returning `default` when
    /// the variable is unset or its (whitespace-trimmed) value does not parse.
    fn env_or<T>(key: &str, default: T) -> T
    where
        T: std::str::FromStr,
    {
        match std::env::var(key) {
            Ok(raw) => raw.trim().parse::<T>().unwrap_or_else(|_| {
                eprintln!(
                    "warning: ignoring unparsable value {:?} for {}; using default",
                    raw, key
                );
                default
            }),
            Err(_) => default,
        }
    }

    /// Checks that the configuration is usable.
    ///
    /// # Errors
    /// Returns a human-readable message naming the first offending field when
    /// a count or size is zero, the inference timeout is zero, or the model
    /// path is empty.
    fn validate(&self) -> Result<(), String> {
        if self.max_concurrent == 0 {
            return Err("max_concurrent must be > 0".into());
        }
        if self.max_batch_size == 0 {
            return Err("max_batch_size must be > 0".into());
        }
        if self.expected_input_dim == 0 {
            return Err("expected_input_dim must be > 0".into());
        }
        if self.max_input_len == 0 {
            return Err("max_input_len must be > 0".into());
        }
        // A zero timeout would make every request time out immediately.
        if self.inference_timeout.is_zero() {
            return Err("inference_timeout must be > 0".into());
        }
        if self.model_path.trim().is_empty() {
            return Err("model_path must not be empty".into());
        }
        Ok(())
    }
}
fn main() {
let config = InferenceConfig::from_env();
match config.validate() {
Ok(()) => println!("Config valid: {:?}", config),
Err(e) => eprintln!("Invalid config: {}", e),
}
}
Versioned model trait
/// Common interface implemented by every model so callers can handle them uniformly.
trait InferenceModel: Send + Sync {
    /// Unique identifier of the model, e.g. `"sentiment-bert-v2"`.
    fn model_id(&self) -> &str;

    /// Semantic version of the model weights.
    fn version(&self) -> &str;

    /// Shape the model expects its input to have.
    fn input_shape(&self) -> Vec<usize>;

    /// Runs inference on an input that has already been validated.
    fn infer(&self, input: &[f32]) -> Result<Vec<f32>, String>;

    /// Human-friendly `"<model_id> (<version>)"` summary for the /info endpoint.
    fn description(&self) -> String {
        let mut summary = String::from(self.model_id());
        summary.push_str(" (");
        summary.push_str(self.version());
        summary.push(')');
        summary
    }
}
/// Stand-in embedding model for demos: its "inference" simply halves every input element.
struct MockEmbeddingModel;

impl InferenceModel for MockEmbeddingModel {
    fn model_id(&self) -> &str {
        "text-embed-v3"
    }

    fn version(&self) -> &str {
        "3.1.0"
    }

    fn input_shape(&self) -> Vec<usize> {
        [1usize, 512].to_vec()
    }

    /// Returns a vector of the same length as `input` with each value multiplied by 0.5.
    fn infer(&self, input: &[f32]) -> Result<Vec<f32>, String> {
        let mut halved = Vec::with_capacity(input.len());
        for &value in input {
            halved.push(value * 0.5);
        }
        Ok(halved)
    }
}
fn main() {
let model = MockEmbeddingModel;
println!("Model: {}", model.description());
let output = model.infer(&[1.0, 2.0, 3.0]).unwrap();
println!("Output: {:?}", output);
}
Documenting inference pipelines
/// Text classification pipeline.
///
/// # Pipeline stages
/// 1. **Tokenize**: Convert text to token IDs (max 512 tokens)
/// 2. **Embed**: Map token IDs to 768-dimensional vectors
/// 3. **Pool**: Mean-pool across token dimension → [batch, 768]
/// 4. **Classify**: Linear layer → [batch, num_classes]
/// 5. **Softmax**: Convert logits to probabilities
///
/// The implementation below is a mock that illustrates the documentation
/// pattern: tokens are Unicode code points, the "embedding" is one scalar per
/// token, and the class scores come from a deterministic formula rather than
/// a trained layer and softmax.
///
/// # Example
/// ```ignore
/// let pipeline = TextClassificationPipeline {
///     max_tokens: 512,
///     num_classes: 2,
///     class_names: vec!["negative".into(), "positive".into()],
/// };
/// let result = pipeline.predict("This Rust code is amazing!");
/// println!("{} ({:.1}%)", result.label, result.confidence * 100.0);
/// ```
struct TextClassificationPipeline {
    /// Upper bound on the number of tokens taken from the input text.
    max_tokens: usize,
    /// Number of classes to score; must match `class_names.len()`.
    num_classes: usize,
    /// Label for each class index.
    class_names: Vec<String>,
}

impl TextClassificationPipeline {
    /// Classifies `text` and returns the best-scoring label with its score.
    ///
    /// Empty input is handled explicitly: the pooled value is taken as `0.0`
    /// instead of dividing by zero, so every score is `0.0`. When several
    /// classes tie for the maximum, the one with the highest index wins.
    ///
    /// # Panics
    /// Panics if `num_classes` is zero or exceeds `class_names.len()`, since
    /// the winning index is then out of bounds.
    fn predict(&self, text: &str) -> ClassificationResult {
        // Stage 1: Tokenize
        let tokens = self.tokenize(text);
        // Stage 2: Mock embedding
        let embedding: Vec<f32> = tokens.iter().map(|&t| t as f32 / 1000.0).collect();
        // Stage 3: Pool — guard the empty case, otherwise 0.0 / 0.0 yields NaN
        let pooled = if embedding.is_empty() {
            0.0
        } else {
            embedding.iter().sum::<f32>() / embedding.len() as f32
        };
        // Stage 4 & 5: Classify
        let scores: Vec<f32> = (0..self.num_classes)
            .map(|i| (pooled * (i + 1) as f32).sin().abs())
            .collect();
        // `total_cmp` is a total order on f32, so the comparison itself can never panic.
        let class = scores
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.total_cmp(b))
            .map(|(i, _)| i)
            .unwrap_or(0);
        ClassificationResult {
            label: self.class_names[class].clone(),
            confidence: scores[class],
        }
    }

    /// Mock tokenizer: one token per `char` (its Unicode scalar value),
    /// truncated to at most `max_tokens` tokens.
    fn tokenize(&self, text: &str) -> Vec<u32> {
        text.chars().take(self.max_tokens).map(|c| c as u32).collect()
    }
}

/// Outcome of a single classification.
#[derive(Debug, Clone, PartialEq)]
struct ClassificationResult {
    /// Name of the winning class.
    label: String,
    /// Score of the winning class (in `[0, 1]` for this mock).
    confidence: f32,
}
fn main() {
let pipeline = TextClassificationPipeline {
max_tokens: 512, num_classes: 2,
class_names: vec!["negative".into(), "positive".into()],
};
let result = pipeline.predict("Rust is great for AI!");
println!("{} ({:.1}%)", result.label, result.confidence * 100.0);
}