Custom Embedders

This guide explains how to implement custom embedding providers and how to compose them with batching, caching, and fallback wrappers.

The Embedder Trait

All embedders implement the Embedder trait, which takes a slice of text chunks and returns one embedding vector per chunk:

use anyhow::Result;
use async_trait::async_trait;

#[async_trait]
pub trait Embedder: Send + Sync {
    async fn embed(&self, chunks: &[String]) -> Result<Vec<Vec<f32>>>;
}

Basic Example: OpenAI Embedder

use embedcache::Embedder;
use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};

pub struct OpenAIEmbedder {
    api_key: String,
    model: String,
    client: reqwest::Client,
}

impl OpenAIEmbedder {
    pub fn new(api_key: String, model: String) -> Self {
        Self {
            api_key,
            model,
            client: reqwest::Client::new(),
        }
    }
}

#[derive(Serialize)]
struct OpenAIRequest {
    model: String,
    input: Vec<String>,
}

#[derive(Deserialize)]
struct OpenAIResponse {
    data: Vec<EmbeddingData>,
}

#[derive(Deserialize)]
struct EmbeddingData {
    embedding: Vec<f32>,
}

#[async_trait]
impl Embedder for OpenAIEmbedder {
    async fn embed(&self, chunks: &[String]) -> Result<Vec<Vec<f32>>> {
        let request = OpenAIRequest {
            model: self.model.clone(),
            input: chunks.to_vec(),
        };

        let response = self.client
            .post("https://api.openai.com/v1/embeddings")
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&request)
            .send()
            .await?
            // Surface HTTP errors (e.g. 401 or 429) instead of failing later on JSON parsing.
            .error_for_status()?
            .json::<OpenAIResponse>()
            .await?;

        Ok(response.data.into_iter().map(|d| d.embedding).collect())
    }
}
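
Constructing and calling the embedder looks like this (a minimal sketch: the API key is read from the OPENAI_API_KEY environment variable, and the model name is just one example of an OpenAI embedding model):

use embedcache::Embedder;
use anyhow::Result;

#[tokio::main]
async fn main() -> Result<()> {
    let api_key = std::env::var("OPENAI_API_KEY")?;
    let embedder = OpenAIEmbedder::new(api_key, "text-embedding-3-small".to_string());

    let chunks = vec!["Hello, world!".to_string()];
    let embeddings = embedder.embed(&chunks).await?;
    println!("{} embeddings of dimension {}", embeddings.len(), embeddings[0].len());

    Ok(())
}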

Batch Processing Embedder

For APIs with rate limits or maximum batch sizes, wrap another embedder and split requests into fixed-size batches:

use embedcache::Embedder;
use anyhow::Result;
use async_trait::async_trait;

pub struct BatchEmbedder {
    inner: Box<dyn Embedder>,
    batch_size: usize,
}

impl BatchEmbedder {
    pub fn new(inner: Box<dyn Embedder>, batch_size: usize) -> Self {
        Self { inner, batch_size }
    }
}

#[async_trait]
impl Embedder for BatchEmbedder {
    async fn embed(&self, chunks: &[String]) -> Result<Vec<Vec<f32>>> {
        let mut all_embeddings = Vec::new();

        for batch in chunks.chunks(self.batch_size) {
            let batch_embeddings = self.inner.embed(batch).await?;
            all_embeddings.extend(batch_embeddings);

            // Optional: add a short delay between batches to stay under rate limits.
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        }

        Ok(all_embeddings)
    }
}
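
For example, wrapping the OpenAIEmbedder above so that at most 100 chunks are sent per request (the batch size is illustrative; use whatever limit your provider documents):

use embedcache::Embedder;
use anyhow::Result;

async fn embed_in_batches(api_key: String, chunks: &[String]) -> Result<Vec<Vec<f32>>> {
    let openai = OpenAIEmbedder::new(api_key, "text-embedding-3-small".to_string());
    let embedder = BatchEmbedder::new(Box::new(openai), 100);
    embedder.embed(chunks).await
}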

Caching Embedder

Wrap any embedder with an in-memory cache keyed by a SHA-256 hash of each chunk, so repeated chunks are embedded only once:

use embedcache::Embedder;
use anyhow::Result;
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::RwLock;
use sha2::{Sha256, Digest};

pub struct CachingEmbedder {
    inner: Box<dyn Embedder>,
    cache: RwLock<HashMap<String, Vec<f32>>>,
}

impl CachingEmbedder {
    pub fn new(inner: Box<dyn Embedder>) -> Self {
        Self {
            inner,
            cache: RwLock::new(HashMap::new()),
        }
    }

    fn hash_text(text: &str) -> String {
        let mut hasher = Sha256::new();
        hasher.update(text.as_bytes());
        format!("{:x}", hasher.finalize())
    }
}

#[async_trait]
impl Embedder for CachingEmbedder {
    async fn embed(&self, chunks: &[String]) -> Result<Vec<Vec<f32>>> {
        let mut results = Vec::with_capacity(chunks.len());
        let mut to_embed = Vec::new();
        let mut indices = Vec::new();

        // Check cache
        {
            let cache = self.cache.read().unwrap();
            for (i, chunk) in chunks.iter().enumerate() {
                let hash = Self::hash_text(chunk);
                if let Some(embedding) = cache.get(&hash) {
                    results.push(Some(embedding.clone()));
                } else {
                    results.push(None);
                    to_embed.push(chunk.clone());
                    indices.push(i);
                }
            }
        }

        // Embed cache misses
        if !to_embed.is_empty() {
            let new_embeddings = self.inner.embed(&to_embed).await?;

            let mut cache = self.cache.write().unwrap();
            for (chunk, embedding) in to_embed.iter().zip(new_embeddings.iter()) {
                let hash = Self::hash_text(chunk);
                cache.insert(hash, embedding.clone());
            }

            for (idx, embedding) in indices.into_iter().zip(new_embeddings) {
                results[idx] = Some(embedding);
            }
        }

        Ok(results.into_iter().map(|r| r.unwrap()).collect())
    }
}
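
Because every wrapper takes a Box<dyn Embedder>, caching composes with the other wrappers. Note that this cache is a plain HashMap behind a std::sync::RwLock and never evicts entries; for long-running services you may prefer an LRU or an external store. A sketch of stacking the wrappers defined above:

use embedcache::Embedder;
use anyhow::Result;

async fn embed_cached(api_key: String, chunks: &[String]) -> Result<Vec<Vec<f32>>> {
    let openai = OpenAIEmbedder::new(api_key, "text-embedding-3-small".to_string());
    let batched = BatchEmbedder::new(Box::new(openai), 100);
    let embedder = CachingEmbedder::new(Box::new(batched));
    embedder.embed(chunks).await
}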

Fallback Embedder

Chain two embedders so the fallback is only called when the primary fails:

use embedcache::Embedder;
use anyhow::Result;
use async_trait::async_trait;

pub struct FallbackEmbedder {
    primary: Box<dyn Embedder>,
    fallback: Box<dyn Embedder>,
}

impl FallbackEmbedder {
    pub fn new(primary: Box<dyn Embedder>, fallback: Box<dyn Embedder>) -> Self {
        Self { primary, fallback }
    }
}

#[async_trait]
impl Embedder for FallbackEmbedder {
    async fn embed(&self, chunks: &[String]) -> Result<Vec<Vec<f32>>> {
        match self.primary.embed(chunks).await {
            Ok(embeddings) => Ok(embeddings),
            Err(e) => {
                eprintln!("Primary embedder failed: {}, using fallback", e);
                self.fallback.embed(chunks).await
            }
        }
    }
}
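
For example, preferring the OpenAI embedder and falling back to any other implementation when the API call fails (the second embedder is passed in here as a Box<dyn Embedder> placeholder):

use embedcache::Embedder;
use anyhow::Result;

async fn embed_with_fallback(
    api_key: String,
    local_fallback: Box<dyn Embedder>,
    chunks: &[String],
) -> Result<Vec<Vec<f32>>> {
    let primary = Box::new(OpenAIEmbedder::new(api_key, "text-embedding-3-small".to_string()));
    let embedder = FallbackEmbedder::new(primary, local_fallback);
    embedder.embed(chunks).await
}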

Using Custom Embedders

Replace the default embedder in handlers:

use embedcache::{Embedder, InputDataText, get_default_config};

async fn embed_with_custom(
    input: InputDataText,
    embedder: &dyn Embedder,
) -> Result<Vec<Vec<f32>>, anyhow::Error> {
    // Fall back to the default config if the caller did not provide one
    // (unused in this minimal example; the underscore avoids a warning).
    let _config = input.config.unwrap_or_else(get_default_config);

    // Use custom embedder
    let embeddings = embedder.embed(&input.text).await?;

    Ok(embeddings)
}
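
Because every embedder in this guide implements the same trait, any of them (or any stack of wrappers) can be handed to the same handler. A sketch:

use embedcache::InputDataText;
use anyhow::Result;

async fn run(input: InputDataText, api_key: String) -> Result<Vec<Vec<f32>>> {
    let embedder = CachingEmbedder::new(Box::new(OpenAIEmbedder::new(
        api_key,
        "text-embedding-3-small".to_string(),
    )));
    embed_with_custom(input, &embedder).await
}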

Best Practices

1. Handle Empty Input

async fn embed(&self, chunks: &[String]) -> Result<Vec<Vec<f32>>> {
    if chunks.is_empty() {
        return Ok(vec![]);
    }
    // ... rest of implementation
}

2. Validate Dimensions

async fn embed(&self, chunks: &[String]) -> Result<Vec<Vec<f32>>> {
    let embeddings = self.inner_embed(chunks).await?;

    // Verify all embeddings have the same dimension
    if let Some(first) = embeddings.first() {
        let expected_dim = first.len();
        for (i, emb) in embeddings.iter().enumerate() {
            if emb.len() != expected_dim {
                return Err(anyhow::anyhow!(
                    "Embedding {} has {} dimensions, expected {}",
                    i, emb.len(), expected_dim
                ));
            }
        }
    }

    Ok(embeddings)
}

3. Implement Timeouts

async fn embed(&self, chunks: &[String]) -> Result<Vec<Vec<f32>>> {
    tokio::time::timeout(
        std::time::Duration::from_secs(30),
        self.inner_embed(chunks)
    )
    .await
    .map_err(|_| anyhow::anyhow!("Embedding request timed out"))?
}
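
These practices compose naturally in a single wrapper. A sketch of a hypothetical ValidatingEmbedder that applies all three around any inner embedder:

use embedcache::Embedder;
use anyhow::Result;
use async_trait::async_trait;

pub struct ValidatingEmbedder {
    inner: Box<dyn Embedder>,
    timeout: std::time::Duration,
}

#[async_trait]
impl Embedder for ValidatingEmbedder {
    async fn embed(&self, chunks: &[String]) -> Result<Vec<Vec<f32>>> {
        // 1. Handle empty input without calling the inner embedder.
        if chunks.is_empty() {
            return Ok(vec![]);
        }

        // 3. Bound how long a single embedding request may take.
        let embeddings = tokio::time::timeout(self.timeout, self.inner.embed(chunks))
            .await
            .map_err(|_| anyhow::anyhow!("Embedding request timed out"))??;

        // 2. Verify all embeddings share the same dimension.
        if let Some(first) = embeddings.first() {
            let expected_dim = first.len();
            for (i, emb) in embeddings.iter().enumerate() {
                if emb.len() != expected_dim {
                    return Err(anyhow::anyhow!(
                        "Embedding {} has {} dimensions, expected {}",
                        i, emb.len(), expected_dim
                    ));
                }
            }
        }

        Ok(embeddings)
    }
}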

Testing Embedders
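
The tests below reference a MyCustomEmbedder type as a stand-in for your own implementation. A minimal stub such as the following (purely illustrative; it returns fixed-length dummy vectors rather than real embeddings) is enough to make them runnable:

use embedcache::Embedder;
use anyhow::Result;
use async_trait::async_trait;

pub struct MyCustomEmbedder;

impl MyCustomEmbedder {
    pub fn new() -> Self {
        Self
    }
}

#[async_trait]
impl Embedder for MyCustomEmbedder {
    async fn embed(&self, chunks: &[String]) -> Result<Vec<Vec<f32>>> {
        // One fixed-length dummy vector per chunk; a real embedder would call a model or API here.
        Ok(chunks.iter().map(|_| vec![0.0_f32; 8]).collect())
    }
}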

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_custom_embedder() {
        let embedder = MyCustomEmbedder::new();
        let texts = vec!["Hello, world!".to_string()];

        let embeddings = embedder.embed(&texts).await.unwrap();

        assert_eq!(embeddings.len(), 1);
        assert!(!embeddings[0].is_empty());
    }

    #[tokio::test]
    async fn test_empty_input() {
        let embedder = MyCustomEmbedder::new();
        let texts: Vec<String> = vec![];

        let embeddings = embedder.embed(&texts).await.unwrap();

        assert!(embeddings.is_empty());
    }
}