Rust By Example
advanced

LLM Token Streaming in Rust

Stream LLM tokens in real-time using Rust async channels, compatible with Server-Sent Events (SSE) and WebSocket APIs.

LLM Token Streaming in Rust

Implement real-time token streaming for language model responses using async channels.

Difficulty

Advanced

Code

rust
use std::time::Duration;
use tokio::sync::mpsc;

/// A single message on the token stream.
///
/// Everything the generator hands to the consumer is one of these variants,
/// so tokens, completion, and failures all travel over the same channel.
#[derive(Clone, Debug)]
enum StreamEvent {
    /// One generated token of text.
    Token(String),
    /// Generation finished; carries the number of tokens that were sent.
    Done { total_tokens: u32 },
    /// Generation failed; carries a human-readable reason.
    Error(String),
}

/// Simulate an LLM generating tokens one by one.
///
/// Sends each token of a fixed, canned response to `tx` as a
/// `StreamEvent::Token`, sleeping 30 ms before each one to mimic per-token
/// compute time. It then finishes with `StreamEvent::Done`, carrying the number
/// of tokens that were successfully sent.
///
/// If a send fails because the receiving side of `tx` has been dropped,
/// generation stops immediately and no `Done` event is emitted.
///
/// The prompt is ignored by this simulation (hence the leading underscore);
/// a real implementation would feed it to the model.
async fn generate_tokens(_prompt: &str, tx: mpsc::Sender<StreamEvent>) {
    // A plain array is enough: it is only iterated once, so no `Vec` allocation.
    let response_tokens = [
        "Rust", "is", "an", "excellent", "choice", "for", "AI",
        "inference", "due", "to", "its", "zero-cost", "abstractions", ".",
    ];

    let mut total = 0u32;
    for token in response_tokens {
        // Simulate per-token compute time
        tokio::time::sleep(Duration::from_millis(30)).await;

        if tx.send(StreamEvent::Token(token.to_string())).await.is_err() {
            // Receiver dropped: nobody is listening any more, so stop generating.
            return;
        }
        total += 1;
    }

    // Best-effort: the receiver may have gone away right after the last token.
    let _ = tx.send(StreamEvent::Done { total_tokens: total }).await;
}

/// Streaming inference with timeout and cancellation
async fn stream_inference(
    prompt: String,
    output_tx: mpsc::Sender<StreamEvent>,
) {
    let (gen_tx, mut gen_rx) = mpsc::channel::<StreamEvent>(64);

    // Run generation in background
    let prompt_clone = prompt.clone();
    let gen_task = tokio::spawn(async move {
        generate_tokens(&prompt_clone, gen_tx).await;
    });

    // Forward with overall timeout
    let result = tokio::time::timeout(Duration::from_secs(10), async {
        while let Some(event) = gen_rx.recv().await {
            if output_tx.send(event).await.is_err() {
                break; // Client disconnected
            }
        }
    }).await;

    if result.is_err() {
        let _ = output_tx.send(StreamEvent::Error("generation timeout".to_string())).await;
        gen_task.abort();
    }
}

/// Demo entry point.
///
/// Starts a streaming inference for a fixed prompt on a background task and
/// prints tokens to stdout as they arrive. It stops at the first `Done` or
/// `Error` event, or when the stream closes.
#[tokio::main]
async fn main() {
    // Needed for `flush`: `print!` does not flush, and Rust's stdout is
    // line-buffered. Without an explicit flush, every token would appear at
    // once when the final newline is printed, defeating real-time streaming.
    use std::io::Write;

    let (tx, mut rx) = mpsc::channel::<StreamEvent>(64);
    let prompt = "Why is Rust good for AI?".to_string();

    // Start streaming
    tokio::spawn(async move {
        stream_inference(prompt, tx).await;
    });

    // Consumer (in production: HTTP SSE or WebSocket)
    let mut stdout = std::io::stdout();
    print!("Response: ");
    // Flush failures only affect display timing, so they are ignored.
    let _ = stdout.flush();

    let mut token_count = 0u32;
    // Tracks whether the stream ended with an explicit Done/Error event.
    let mut terminated = false;
    while let Some(event) = rx.recv().await {
        match event {
            StreamEvent::Token(t) => {
                print!("{} ", t);
                let _ = stdout.flush();
                token_count += 1;
            }
            StreamEvent::Done { total_tokens } => {
                println!("\n[Done: {} tokens generated]", total_tokens);
                terminated = true;
                break;
            }
            StreamEvent::Error(e) => {
                eprintln!("\n[Error: {}]", e);
                terminated = true;
                break;
            }
        }
    }
    if !terminated {
        // The channel closed without a final event (e.g. the producer task
        // died); terminate the response line so the summary stays readable.
        println!("\n[Stream ended without a Done or Error event]");
    }
    println!("Client received: {} tokens", token_count);
}

Explanation

mpsc::channel decouples the token generator from the consumer. The generator produces tokens at its own pace. The consumer (HTTP SSE, WebSocket) forwards each token as soon as it arrives instead of buffering the full response.

Key Concepts

  • StreamEvent enum handles tokens, completion, and errors uniformly
  • tokio::time::timeout prevents runaway generation
  • Channel backpressure: the channels are bounded (capacity 64), so send().await waits while a buffer is full, which slows the producer when the consumer is slow
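  • gen_task.abort() cancels generation when the overall timeout expires. When the client disconnects, forwarding stops and the generator is shut down too, so no more tokens are delivered to a client that is gone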

Related Topics

Browse more examples in the ai-inference category to see production patterns.

More ai-inference Examples