Rust AI Inference Decision Matrix
How to choose the right AI inference framework for your Rust project. Compare candle, ort (ONNX Runtime), tch-rs (LibTorch), tract, burn, fastembed-rs, and custom implementations across key dimensions.
Framework comparison
| Framework | Backend | Pure Rust | GPU | WASM | Maturity | Best For |
|---|---|---|---|---|---|---|
| candle | Custom / CUDA / Metal | ✅ Yes | ✅ CUDA, Metal | ✅ Yes | Medium | Hugging Face models, LLMs |
| ort | ONNX Runtime (C++) | ❌ FFI | ✅ CUDA, ROCm, TensorRT | ❌ No | High | Production, any ONNX model |
| tch-rs | LibTorch (C++) | ❌ FFI | ✅ CUDA | ❌ No | High | PyTorch model portability |
| tract | Custom Rust | ✅ Yes | ❌ CPU only | ✅ Yes | Medium | Edge/embedded, no C++ deps |
| burn | Custom / WGPU | ✅ Yes | ✅ WGPU | ✅ Yes | Early | Training + inference, pure Rust |
| fastembed-rs | ort wrapper | ❌ FFI | ✅ Via ort | ❌ No | Medium | Embedding models specifically |
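The yes/no columns of the table can also be treated as data and filtered by hard requirements. The following is a minimal sketch under that assumption: the `Framework` and `Requirements` types and the `shortlist` function are hypothetical names introduced here for illustration, and the flags are copied from the Pure Rust, GPU, and WASM columns above. It ignores maturity and the "Best For" column, so it narrows the field rather than making the final choice.

```rust
/// One row of the comparison table (boolean columns only).
struct Framework {
    name: &'static str,
    pure_rust: bool,
    gpu: bool,
    wasm: bool,
}

/// Flags copied from the Pure Rust / GPU / WASM columns of the table.
static FRAMEWORKS: [Framework; 6] = [
    Framework { name: "candle", pure_rust: true, gpu: true, wasm: true },
    Framework { name: "ort", pure_rust: false, gpu: true, wasm: false },
    Framework { name: "tch-rs", pure_rust: false, gpu: true, wasm: false },
    Framework { name: "tract", pure_rust: true, gpu: false, wasm: true },
    Framework { name: "burn", pure_rust: true, gpu: true, wasm: true },
    Framework { name: "fastembed-rs", pure_rust: false, gpu: true, wasm: false },
];

/// Hard requirements. `false` means "don't care", not "must be absent".
#[derive(Default, Clone, Copy)]
struct Requirements {
    pure_rust: bool,
    gpu: bool,
    wasm: bool,
}

/// Returns the names of all frameworks that satisfy every set requirement,
/// in table order.
fn shortlist(req: Requirements) -> Vec<&'static str> {
    FRAMEWORKS
        .iter()
        .filter(|f| {
            (!req.pure_rust || f.pure_rust)
                && (!req.gpu || f.gpu)
                && (!req.wasm || f.wasm)
        })
        .map(|f| f.name)
        .collect()
}

fn main() {
    let edge = Requirements { pure_rust: true, wasm: true, ..Default::default() };
    // prints: pure Rust + WASM: ["candle", "tract", "burn"]
    println!("pure Rust + WASM: {:?}", shortlist(edge));
}
```

Treating an unset flag as "don't care" keeps the filter monotonic: adding a requirement can only shrink the shortlist.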
Decision flowchart
Need GPU acceleration?
├── Yes, CUDA required → ort or tch-rs (most stable GPU support)
├── Yes, multi-GPU / TensorRT → ort with TensorRT EP
└── No, CPU only →
    ├── Need WASM deployment? → candle or tract
    ├── Need PyTorch .pt models? → tch-rs
    ├── Need ONNX models? → ort
    └── Want pure Rust, no C++ deps? → tract or burn
Running LLMs (GPT, LLaMA, Mistral)?
└── candle (native support for transformer architectures)
Edge/embedded deployment?
└── tract (smallest binary, no system libs required)
Production serving with SLA?
└── ort (most battle-tested, TensorRT EP for maximum throughput)
Runnable example: candle-style inference sketch
// Dependency-free sketch of the tensor-op pattern that candle exposes
// (matmul, relu, softmax). It does not call candle itself; to use the
// real crate, add candle-core = "0.7" to Cargo.toml.

/// Minimal row-major f32 tensor standing in for a candle tensor.
struct Tensor {
    data: Vec<f32>,
    shape: Vec<usize>,
}

impl Tensor {
    fn new(data: Vec<f32>, shape: Vec<usize>) -> Self {
        let expected: usize = shape.iter().product();
        assert_eq!(data.len(), expected, "data length must match shape product");
        Self { data, shape }
    }

    /// Naive 2D matrix multiply: [m, k] x [k, n] -> [m, n].
    fn matmul(&self, other: &Tensor) -> Tensor {
        assert_eq!(self.shape.len(), 2);
        assert_eq!(other.shape.len(), 2);
        let (m, k) = (self.shape[0], self.shape[1]);
        let (k2, n) = (other.shape[0], other.shape[1]);
        assert_eq!(k, k2, "inner dimensions must match");
        let mut out = vec![0.0f32; m * n];
        for i in 0..m {
            for j in 0..n {
                for p in 0..k {
                    out[i * n + j] += self.data[i * k + p] * other.data[p * n + j];
                }
            }
        }
        Tensor::new(out, vec![m, n])
    }

    /// Element-wise max(x, 0).
    fn relu(&self) -> Tensor {
        Tensor::new(
            self.data.iter().map(|&x| x.max(0.0)).collect(),
            self.shape.clone(),
        )
    }

    /// Softmax over all elements (sufficient for the single-row output here).
    /// Subtracting the max first keeps exp() from overflowing.
    fn softmax(&self) -> Tensor {
        let max = self.data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exp: Vec<f32> = self.data.iter().map(|&x| (x - max).exp()).collect();
        let sum: f32 = exp.iter().sum();
        Tensor::new(exp.iter().map(|&x| x / sum).collect(), self.shape.clone())
    }
}

/// Simple 2-layer MLP: input → hidden → output
struct MLP {
    w1: Tensor, // [input_dim, hidden_dim]
    w2: Tensor, // [hidden_dim, output_dim]
}

impl MLP {
    fn forward(&self, input: &Tensor) -> Tensor {
        input.matmul(&self.w1).relu().matmul(&self.w2).softmax()
    }
}

fn main() {
    let mlp = MLP {
        w1: Tensor::new(vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6], vec![3, 2]),
        w2: Tensor::new(vec![0.5, 0.5, 0.5, 0.5], vec![2, 2]),
    };
    let input = Tensor::new(vec![1.0, 0.5, -0.3], vec![1, 3]);
    let output = mlp.forward(&input);
    println!("MLP output: {:?}", output.data);
    println!("Probabilities sum: {:.4}", output.data.iter().sum::<f32>());
}
ONNX Runtime (ort) quickstart
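# Cargo.toml
[dependencies]
# The sketch below uses 1.x-style names (Environment, SessionBuilder).
# The 2.x line reworks this API, so pin the version the code targets.
ort = "1.16"
ndarray = "0.15"

// Example ort usage pattern (illustrative; follows the ort 1.x API and is left commented out)
// use ndarray::{Array2, ArrayD, CowArray};
// use ort::{Environment, SessionBuilder, Value};
//
// fn run_onnx_model(model_path: &str, input: Array2<f32>) -> ArrayD<f32> {
//     // SessionBuilder::new expects a shared (Arc) environment.
//     let env = Environment::builder().build().unwrap().into_arc();
//     let session = SessionBuilder::new(&env)
//         .unwrap()
//         .with_model_from_file(model_path)
//         .unwrap();
//
//     // Inputs are passed as dynamic-dimension, copy-on-write arrays.
//     let input = CowArray::from(input.into_dyn());
//     let input_value = Value::from_array(session.allocator(), &input).unwrap();
//     let outputs = session.run(vec![input_value]).unwrap();
//
//     // Copy the first output out of the runtime-owned buffer.
//     let tensor = outputs[0].try_extract::<f32>().unwrap();
//     let owned = tensor.view().to_owned();
//     owned
// }
When to build custom inference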
Build custom inference (no framework) when:
- Model architecture is extremely simple (linear regression, small MLP).
- The deployment target cannot host a C++ runtime.
- Total binary size must stay under 1 MB.
- You need exact control over numerical precision.
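As an illustration of the criteria above, here is a minimal sketch of framework-free inference for a linear model. The `LinearModel` type, its weights, and the input values are hypothetical and chosen only for this example. It uses fixed-size arrays with no heap allocation, and it accumulates in `f64` so that rounding to `f32` happens once at the end, which is one way to keep numerical precision under explicit control.

```rust
/// Hypothetical dependency-free linear model: y = w · x + b.
/// N is the number of input features, fixed at compile time.
struct LinearModel<const N: usize> {
    weights: [f32; N],
    bias: f32,
}

impl<const N: usize> LinearModel<N> {
    /// Computes the dot product in f64 and casts to f32 only at the end,
    /// so intermediate sums are not rounded to single precision.
    fn predict(&self, x: &[f32; N]) -> f32 {
        let mut acc = f64::from(self.bias);
        for (w, xi) in self.weights.iter().zip(x.iter()) {
            acc += f64::from(*w) * f64::from(*xi);
        }
        acc as f32
    }
}

fn main() {
    // Illustrative parameters, not taken from any trained model.
    let model = LinearModel { weights: [0.5, -1.0, 2.0], bias: 0.25 };
    let y = model.predict(&[2.0, 1.0, 0.5]);
    // 0.25 + (0.5 * 2.0) + (-1.0 * 1.0) + (2.0 * 0.5) = 1.25
    println!("prediction: {y}");
}
```

Because the feature count is a const generic, passing an input of the wrong length is a compile-time error rather than a runtime check.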