Skip to main content

toak_rs/
embeddings_generator.rs

//! Utilities for creating semantic embeddings via the `fastembed` crate.
//! This module powers the embedding generation features that back the JSON database
//! exporter and any higher-level tooling.
4use fastembed::{TextEmbedding, InitOptions, EmbeddingModel};
5use anyhow::Result;
6use std::path::PathBuf;
7
/// Resolves the directory used as the fastembed model cache, avoiding a local
/// `.fastembed_cache` in the working directory.
///
/// Resolution order:
/// 1. `FASTEMBED_CACHE_DIR`, when set to a non-empty value.
/// 2. `$HOME/.cache/fastembed`, when `HOME` is set to a non-empty value.
/// 3. `./.cache/fastembed` as a last resort.
///
/// Empty values are treated as unset so that `FASTEMBED_CACHE_DIR=""` cannot resolve to
/// an empty path. Variables are read with `var_os`, so a path that is not valid Unicode
/// is passed through unchanged instead of being treated as missing.
///
/// Note (carried over from the original documentation): `HF_HOME` is handled by
/// fastembed's `pull_from_hf` and will override this value when set.
fn resolve_cache_dir() -> PathBuf {
    // Reads an environment variable, mapping "set but empty" to `None`.
    let non_empty = |name: &str| std::env::var_os(name).filter(|value| !value.is_empty());

    if let Some(dir) = non_empty("FASTEMBED_CACHE_DIR") {
        return PathBuf::from(dir);
    }

    let home = non_empty("HOME").map_or_else(|| PathBuf::from("."), PathBuf::from);
    home.join(".cache").join("fastembed")
}
19
20/// A builder around `fastembed::TextEmbedding` that exposes simple helpers
21/// for generating per-text or batch embeddings.
22pub struct EmbeddingsGenerator {
23    model: TextEmbedding,
24}
25
26impl EmbeddingsGenerator {
27    /// Creates a new embeddings generator with the default model
28    pub fn new() -> Result<Self> {
29        Self::with_model(EmbeddingModel::EmbeddingGemma300M)
30    }
31
32    /// Creates a new embeddings generator with a specific model
33    pub fn with_model(model: EmbeddingModel) -> Result<Self> {
34        // Log the platform/backend hints to help validate acceleration on Apple Silicon.
35        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
36        {
37            // If built with `ort` CoreML feature, ONNX Runtime should select the CoreML EP
38            // when available, falling back to CPU otherwise. We log a hint here.
39            eprintln!("[perf] macOS aarch64 build detected; ONNX Runtime CoreML/Metal acceleration is enabled if available.");
40            if let Ok(val) = std::env::var("TOAK_EMBED_DEVICE") {
41                eprintln!("[perf] TOAK_EMBED_DEVICE={} (informational)", val);
42            }
43        }
44
45        // Try to initialize the model. On Apple Silicon, if CoreML fails, retry once with CPU.
46        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
47        let text_embedding = {
48            let try_init = |m: EmbeddingModel| {
49                TextEmbedding::try_new(
50                    InitOptions::new(m)
51                        .with_cache_dir(resolve_cache_dir())
52                        .with_show_download_progress(true),
53                )
54            };
55            match try_init(model.clone()) {
56                Ok(ok) => {
57                    let coreml_disabled = std::env::var("ORT_DISABLE_COREML").ok().unwrap_or_default();
58                    if coreml_disabled == "1" {
59                        eprintln!("[perf] ONNX Runtime CoreML disabled by ORT_DISABLE_COREML=1; using CPU backend.");
60                    } else {
61                        eprintln!("[perf] Attempting CoreML/Metal acceleration (CPU fallback if unavailable)...");
62                    }
63                    ok
64                }
65                Err(e) => {
66                    eprintln!("[warn] fastembed initialization failed (CoreML path?): {}", e);
67                    eprintln!("[warn] Retrying embeddings initialization with CPU backend (disabling CoreML).");
68                    std::env::set_var("ORT_DISABLE_COREML", "1");
69                    let retried = try_init(model)?;
70                    eprintln!("[perf] Fallback successful: using CPU backend for embeddings.");
71                    retried
72                }
73            }
74        };
75
76        #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
77        let text_embedding = TextEmbedding::try_new(
78            InitOptions::new(model)
79                .with_cache_dir(resolve_cache_dir())
80                .with_show_download_progress(true),
81        )?;
82
83        Ok(Self {
84            model: text_embedding,
85        })
86    }
87
88    /// Generates embeddings for a batch of texts
89    /// The `batch_size` parameter can be used to control memory usage and throughput.
90    pub fn generate_embeddings(&mut self, texts: Vec<&str>, batch_size: Option<usize>) -> Result<Vec<Vec<f32>>> {
91        let embeddings = self.model.embed(texts, batch_size)?;
92        Ok(embeddings)
93    }
94
95    /// Generates embedding for a single text
96    pub fn generate_embedding(&mut self, text: &str) -> Result<Vec<f32>> {
97        let embeddings = self.generate_embeddings(vec![text], None)?;
98        embeddings.into_iter().next()
99            .ok_or_else(|| anyhow::anyhow!("Failed to generate embedding"))
100    }
101}