Rust By Example

Rust AI Inference Real-World Cases

Real-world examples of AI inference in Rust: LLM serving, image classification APIs, recommendation engines, fraud detection, and NLP pipelines at production scale.

Topic: AI Inference

Search intent: High-intent search: "rust ai inference real world examples production"

Rust AI Inference Real-World Cases

Case 1: LLM token streaming API

Problem: Serve a language model that generates tokens one by one; clients need real-time streaming responses.

rust
use tokio::sync::mpsc;
use std::time::Duration;

/// Simulated token stream from an LLM.
///
/// Sends a fixed, canned sequence of tokens over `tx`, sleeping 50 ms before
/// each one to imitate per-token compute. The loop stops as soon as a send
/// reports an error (treated here as "the client went away").
///
/// `prompt` is accepted so the signature mirrors a real generation API, but
/// this simulation does not condition its output on it.
async fn stream_tokens(
    prompt: &str,
    tx: mpsc::Sender<String>,
) {
    // The canned response never depends on the prompt; acknowledge the
    // parameter explicitly instead of tokenizing it into a Vec nobody reads.
    let _ = prompt;

    // Simulate autoregressive generation
    const RESPONSE_TOKENS: [&str; 9] =
        ["The", "answer", "is", "42", "because", "Rust", "is", "fast", "."];

    for token in RESPONSE_TOKENS {
        tokio::time::sleep(Duration::from_millis(50)).await; // Simulate compute per token
        if tx.send(token.to_string()).await.is_err() {
            break; // Client disconnected
        }
    }
}

/// Demo entry point: spawns the simulated generator as a background task and
/// prints each token the moment it is received from the channel.
#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel::<String>(32);
    let prompt = "Explain why Rust is good for AI";

    // Stream tokens in background
    tokio::spawn(async move {
        stream_tokens(prompt, tx).await;
    });

    // Consumer: print tokens as they arrive (SSE / WebSocket in real code)
    print!("Response: ");
    while let Some(token) = rx.recv().await {
        print!("{} ", token);
        // stdout is line-buffered, so without an explicit flush nothing would
        // show up until the final newline and the "streaming" would be invisible.
        // In production: flush to HTTP response via axum SSE
        std::io::Write::flush(&mut std::io::stdout()).expect("failed to flush stdout");
    }
    println!();
}

Metrics achieved (production example): 80 tokens/sec per instance, p99 first-token latency < 200ms, 500 concurrent streams on a single 8-core instance.

---

Case 2: Image classification REST API

Problem: Accept JPEG uploads, run a ResNet-50 classifier, return top-5 predictions.

rust
use std::collections::HashMap;

/// Simulated image classification pipeline.
///
/// Holds the class-index → label table and exposes a mock
/// preprocess → classify → top-k flow.
struct ImageClassifier {
    /// Class index → human-readable label.
    class_labels: HashMap<usize, String>,
}

impl ImageClassifier {
    /// Builds a classifier with five hard-coded demo labels (indices 0 through 4).
    fn new() -> Self {
        let class_labels = ["cat", "dog", "car", "tree", "bird"]
            .iter()
            .enumerate()
            .map(|(idx, label)| (idx, label.to_string()))
            .collect();
        Self { class_labels }
    }

    /// Preprocess image bytes → normalized f32 tensor [1, 3, 224, 224]
    ///
    /// This mock ignores `image_bytes` and returns `3 * 224 * 224` values
    /// forming a linear ramp that starts at -1.0.
    fn preprocess(&self, image_bytes: &[u8]) -> Vec<f32> {
        // Real: decode JPEG, resize to 224x224, normalize with ImageNet mean/std
        // Simplified: return mock tensor (the input bytes are deliberately unused)
        let _ = image_bytes;
        let size = 3 * 224 * 224;
        (0..size).map(|i| i as f32 / size as f32 * 2.0 - 1.0).collect()
    }

    /// Run inference and return every class index paired with its score,
    /// ordered from highest to lowest score.
    ///
    /// A NaN score (possible when `tensor` holds NaN or infinite values) is
    /// ranked after all real scores instead of aborting the sort with a panic.
    fn classify(&self, tensor: &[f32]) -> Vec<(usize, f32)> {
        // Simulate softmax output
        let num_classes = self.class_labels.len();
        let mut scores: Vec<(usize, f32)> = (0..num_classes)
            .map(|i| {
                let score = tensor.iter().take(i + 1).sum::<f32>().abs() % 1.0;
                (i, score)
            })
            .collect();
        scores.sort_by(|a, b| {
            // Descending by score. `partial_cmp` yields `None` only when a NaN is
            // involved; in that case order the NaN entry last.
            b.1.partial_cmp(&a.1)
                .unwrap_or_else(|| a.1.is_nan().cmp(&b.1.is_nan()))
        });
        scores
    }

    /// Returns up to `k` `(label, score)` pairs for the image, best first.
    ///
    /// A class index without an entry in the label table is reported as
    /// `"unknown"`.
    fn top_k(&self, image_bytes: &[u8], k: usize) -> Vec<(String, f32)> {
        let tensor = self.preprocess(image_bytes);
        self.classify(&tensor)
            .into_iter()
            .take(k)
            .map(|(idx, score)| {
                let label = self
                    .class_labels
                    .get(&idx)
                    .cloned()
                    // Build the fallback string only when it is actually needed.
                    .unwrap_or_else(|| "unknown".to_string());
                (label, score)
            })
            .collect()
    }
}

/// Demo: run the mock classifier on a placeholder buffer and print the three
/// highest-ranked labels with their scores as percentages.
fn main() {
    // Simulated JPEG bytes
    let fake_image = vec![0u8; 1024];
    let classifier = ImageClassifier::new();
    let predictions = classifier.top_k(&fake_image, 3);

    println!("Top-3 predictions:");
    let mut rank = 0;
    for (label, score) in predictions {
        rank += 1;
        println!("  {}. {} ({:.1}%)", rank, label, score * 100.0);
    }
}

---

Case 3: Real-time fraud detection

Problem: Score every payment transaction in < 5ms p99. Model is a gradient-boosted tree with 500 features.

rust
use std::collections::HashMap;

/// Per-transaction signals used as input to the fraud scorer.
struct FraudFeatures {
    amount: f32,
    merchant_risk_score: f32,
    hour_of_day: u8,
    day_of_week: u8,
    /// Transactions in the last hour.
    velocity_1h: u32,
    /// Transactions in the last 24 hours.
    velocity_24h: u32,
    /// Distance from home, in km.
    distance_from_home: f32,
    is_international: bool,
}

impl FraudFeatures {
    /// Flattens the record into an 8-element `f32` vector, in field
    /// declaration order. The boolean flag is encoded as 1.0 / 0.0.
    fn to_vec(&self) -> Vec<f32> {
        let international_flag: f32 = match self.is_international {
            true => 1.0,
            false => 0.0,
        };

        [
            self.amount,
            self.merchant_risk_score,
            f32::from(self.hour_of_day),
            f32::from(self.day_of_week),
            self.velocity_1h as f32,
            self.velocity_24h as f32,
            self.distance_from_home,
            international_flag,
        ]
        .to_vec()
    }
}

/// Simplified gradient-boosted tree scoring.
///
/// Converts `features` to its flat vector form and applies five threshold
/// rules; each rule that fires adds a fixed weight. The sum is capped at 1.0,
/// so the result lies in `0.0..=1.0`.
fn fraud_score(features: &FraudFeatures) -> f32 {
    // Positions of the inspected signals inside the vector produced by
    // `FraudFeatures::to_vec`.
    const AMOUNT: usize = 0;
    const MERCHANT_RISK: usize = 1;
    const VELOCITY_1H: usize = 4;
    const DISTANCE_FROM_HOME: usize = 6;
    const IS_INTERNATIONAL: usize = 7;

    let f = features.to_vec();
    // Rule-based approximation of GBT output
    let mut score = 0.0f32;
    if f[AMOUNT] > 1000.0 { score += 0.3; }            // High amount
    if f[MERCHANT_RISK] > 0.7 { score += 0.25; }       // High-risk merchant
    // The vector holds f32 values, so the threshold must be a float literal.
    if f[VELOCITY_1H] > 5.0 { score += 0.2; }          // High velocity
    if f[IS_INTERNATIONAL] > 0.5 { score += 0.15; }    // International
    if f[DISTANCE_FROM_HOME] > 500.0 { score += 0.1; } // Far from home
    score.min(1.0)
}

/// Demo: score one ordinary and one suspicious-looking transaction and print
/// both results with two decimal places.
fn main() {
    let normal_tx = FraudFeatures {
        amount: 45.0,
        merchant_risk_score: 0.1,
        hour_of_day: 14,
        day_of_week: 2,
        velocity_1h: 1,
        velocity_24h: 3,
        distance_from_home: 5.0,
        is_international: false,
    };
    let normal_score = fraud_score(&normal_tx);
    println!("Normal transaction score:     {:.2}", normal_score);

    let suspicious_tx = FraudFeatures {
        amount: 2500.0,
        merchant_risk_score: 0.9,
        hour_of_day: 3,
        day_of_week: 6,
        velocity_1h: 8,
        velocity_24h: 20,
        distance_from_home: 8000.0,
        is_international: true,
    };
    let suspicious_score = fraud_score(&suspicious_tx);
    println!("Suspicious transaction score: {:.2}", suspicious_score);
}

---

Case 4: Semantic search with embeddings

Problem: Encode documents and queries as vectors, find nearest neighbors for search.

rust
/// Cosine similarity of `a` and `b`.
///
/// The dot product runs over the paired prefix of the two slices, while each
/// norm is taken over the whole slice. When either norm is below `1e-8` the
/// function returns 0.0 rather than dividing by a near-zero value.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    /// Euclidean length of `v`.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a = l2_norm(a);
    let norm_b = l2_norm(b);

    // Keep the comparison in this exact form: a NaN norm must fall through to
    // the division below, not be treated as degenerate.
    let degenerate = norm_a < 1e-8 || norm_b < 1e-8;
    if degenerate {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}

/// Simple in-memory vector store.
///
/// Keeps `(text, embedding)` pairs in insertion order and answers queries
/// with a brute-force scan over every stored document.
struct VectorStore {
    /// Stored documents as `(text, embedding)` pairs.
    documents: Vec<(String, Vec<f32>)>,
}

impl VectorStore {
    /// Creates an empty store.
    fn new() -> Self { Self { documents: Vec::new() } }

    /// Appends a document and its embedding to the store.
    fn add(&mut self, text: String, embedding: Vec<f32>) {
        self.documents.push((text, embedding));
    }

    /// Returns up to `top_k` `(text, similarity)` pairs, most similar first.
    ///
    /// Every stored embedding is scored against `query_embedding` with
    /// `cosine_similarity`. A NaN similarity is ranked after all real scores
    /// instead of making the sort panic.
    fn search(&self, query_embedding: &[f32], top_k: usize) -> Vec<(&str, f32)> {
        let mut scored: Vec<(&str, f32)> = self.documents.iter()
            .map(|(text, emb)| (text.as_str(), cosine_similarity(query_embedding, emb)))
            .collect();
        scored.sort_by(|a, b| {
            // Descending by similarity. `partial_cmp` yields `None` only when a
            // NaN is involved; in that case order the NaN entry last.
            b.1.partial_cmp(&a.1)
                .unwrap_or_else(|| a.1.is_nan().cmp(&b.1.is_nan()))
        });
        scored.truncate(top_k);
        scored
    }
}

/// Produces a deterministic 8-dimensional mock embedding for `text`.
///
/// Each character's Unicode scalar value, divided by 100, is accumulated into
/// the slot `position % 8`. The vector is then divided by its Euclidean norm;
/// the norm is floored at `1e-8`, so empty input yields all zeros rather than
/// a division by zero.
fn mock_embed(text: &str) -> Vec<f32> {
    const DIM: usize = 8;

    // Simulate embedding: hash characters to f32 vector
    let mut v = vec![0.0f32; DIM];
    for (i, c) in text.chars().enumerate() {
        // A `char` cannot be cast straight to a float; go through its scalar
        // value first (at most 0x10FFFF, which f32 represents exactly).
        v[i % DIM] += u32::from(c) as f32 / 100.0;
    }
    // Normalize
    let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
    v.iter().map(|x| x / norm).collect()
}

/// Demo: index four sentences with mock embeddings, then print the two
/// entries most similar to a sample query.
fn main() {
    let docs = [
        "Rust is a systems programming language",
        "Python is popular for machine learning",
        "Rust has excellent async support with Tokio",
        "TensorFlow and PyTorch are deep learning frameworks",
    ];

    let mut store = VectorStore::new();
    for &doc in &docs {
        let embedding = mock_embed(doc);
        store.add(doc.to_string(), embedding);
    }

    let query = "async programming in Rust";
    let query_embedding = mock_embed(query);
    let results = store.search(&query_embedding, 2);

    println!("Query: '{}'", query);
    for hit in &results {
        println!("  [{:.3}] {}", hit.1, hit.0);
    }
}

Related reading

Related Guides

Continue in This Topic

More Rust Guides