toak_rs/
embeddings_generator.rs1use fastembed::{TextEmbedding, InitOptions, EmbeddingModel};
5use anyhow::Result;
6use std::path::PathBuf;
7
/// Returns the directory fastembed should use to cache downloaded model files.
///
/// Resolution order:
/// 1. `FASTEMBED_CACHE_DIR`, if set to a non-empty value.
/// 2. `$HOME/.cache/fastembed`, or `%USERPROFILE%/.cache/fastembed` when `HOME` is unset or
///    empty (Windows normally provides `USERPROFILE` rather than `HOME`).
/// 3. `./.cache/fastembed`, relative to the current working directory.
///
/// Empty variables are treated as unset so that `FASTEMBED_CACHE_DIR= cmd` does not resolve to
/// the empty path. Values are read as `OsString`, so non-UTF-8 paths are honoured instead of
/// being silently ignored.
fn resolve_cache_dir() -> PathBuf {
    // Value of an environment variable, or `None` when it is unset or empty.
    let non_empty_var = |key: &str| std::env::var_os(key).filter(|value| !value.is_empty());

    if let Some(dir) = non_empty_var("FASTEMBED_CACHE_DIR") {
        return PathBuf::from(dir);
    }
    let home = non_empty_var("HOME")
        .or_else(|| non_empty_var("USERPROFILE"))
        .map_or_else(|| PathBuf::from("."), PathBuf::from);
    home.join(".cache").join("fastembed")
}

20pub struct EmbeddingsGenerator {
23 model: TextEmbedding,
24}
25
26impl EmbeddingsGenerator {
27 pub fn new() -> Result<Self> {
29 Self::with_model(EmbeddingModel::EmbeddingGemma300M)
30 }
31
32 pub fn with_model(model: EmbeddingModel) -> Result<Self> {
34 #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
36 {
37 eprintln!("[perf] macOS aarch64 build detected; ONNX Runtime CoreML/Metal acceleration is enabled if available.");
40 if let Ok(val) = std::env::var("TOAK_EMBED_DEVICE") {
41 eprintln!("[perf] TOAK_EMBED_DEVICE={} (informational)", val);
42 }
43 }
44
45 #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
47 let text_embedding = {
48 let try_init = |m: EmbeddingModel| {
49 TextEmbedding::try_new(
50 InitOptions::new(m)
51 .with_cache_dir(resolve_cache_dir())
52 .with_show_download_progress(true),
53 )
54 };
55 match try_init(model.clone()) {
56 Ok(ok) => {
57 let coreml_disabled = std::env::var("ORT_DISABLE_COREML").ok().unwrap_or_default();
58 if coreml_disabled == "1" {
59 eprintln!("[perf] ONNX Runtime CoreML disabled by ORT_DISABLE_COREML=1; using CPU backend.");
60 } else {
61 eprintln!("[perf] Attempting CoreML/Metal acceleration (CPU fallback if unavailable)...");
62 }
63 ok
64 }
65 Err(e) => {
66 eprintln!("[warn] fastembed initialization failed (CoreML path?): {}", e);
67 eprintln!("[warn] Retrying embeddings initialization with CPU backend (disabling CoreML).");
68 std::env::set_var("ORT_DISABLE_COREML", "1");
69 let retried = try_init(model)?;
70 eprintln!("[perf] Fallback successful: using CPU backend for embeddings.");
71 retried
72 }
73 }
74 };
75
76 #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
77 let text_embedding = TextEmbedding::try_new(
78 InitOptions::new(model)
79 .with_cache_dir(resolve_cache_dir())
80 .with_show_download_progress(true),
81 )?;
82
83 Ok(Self {
84 model: text_embedding,
85 })
86 }
87
88 pub fn generate_embeddings(&mut self, texts: Vec<&str>, batch_size: Option<usize>) -> Result<Vec<Vec<f32>>> {
91 let embeddings = self.model.embed(texts, batch_size)?;
92 Ok(embeddings)
93 }
94
95 pub fn generate_embedding(&mut self, text: &str) -> Result<Vec<f32>> {
97 let embeddings = self.generate_embeddings(vec![text], None)?;
98 embeddings.into_iter().next()
99 .ok_or_else(|| anyhow::anyhow!("Failed to generate embedding"))
100 }
101}