Skip to main content

toak_rs/
json_database_generator.rs

1//! Helpers that walk a git repository, chunk the code, and persist embeddings into a JSON database.
2use crate::embeddings_generator::EmbeddingsGenerator;
3use crate::text_chunker::{chunk_text, ChunkerConfig};
4use crate::token_cleaner::clean_and_redact;
5use anyhow::Result;
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use std::collections::HashSet;
9use std::path::{Path, PathBuf};
10use std::process::Command;
11use std::sync::Arc;
12use std::time::Instant;
13use tokio::fs;
14use tokio::sync::{Semaphore};
15use tokio::sync::{mpsc, oneshot};
16use std::sync::atomic::{AtomicUsize, Ordering};
17use std::sync::mpsc as std_mpsc;
18
/// Metadata describing one chunk and the file it was cut from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkMetadata {
    /// Index of this chunk, copied from the chunker's output.
    pub chunk_index: usize,
    /// Number of chunks the chunker produced for the same file.
    pub total_chunks: usize,
    /// Length of the source file as reported by filesystem metadata.
    pub file_size: u64,
    /// RFC 3339 (UTC) modification time of the file, or `None` when the
    /// filesystem could not report one.
    pub last_modified: Option<String>,
    /// Start position of the chunk, copied from the chunker's output.
    /// NOTE(review): unit (bytes vs. chars) is defined by `chunk_text` — confirm there.
    pub start_index: usize,
    /// End position of the chunk, copied from the chunker's output
    /// (same unit caveat as `start_index`).
    pub end_index: usize,
}
29
/// A chunk of file content together with its embedding vector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddedChunk {
    /// Path of the source file, relative to the scanned directory (as listed by git).
    pub file_path: String,
    /// Chunk text after cleaning/redaction; this is the text that was embedded.
    pub content: String,
    /// Embedding vector returned by the embedding backend for `content`.
    pub embedding: Vec<f32>,
    /// Position and file-level metadata for this chunk.
    pub metadata: ChunkMetadata,
}
38
/// A chunk staged for embedding (no vector yet).
///
/// Carries the same fields as [`EmbeddedChunk`] minus the embedding, so the two
/// can be joined once the vectors come back.
#[derive(Debug, Clone)]
struct PendingChunk {
    /// Path of the source file, relative to the scanned directory.
    file_path: String,
    /// Cleaned/redacted chunk text that will be sent to the embedding backend.
    content: String,
    /// Position and file-level metadata for this chunk.
    metadata: ChunkMetadata,
}
46
/// The complete JSON database structure that is serialized to the output file.
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingsDatabase {
    /// Format version string; the generator in this file writes `"1.0"`.
    pub version: String,
    /// RFC 3339 (UTC) timestamp taken when the database was assembled.
    pub generated_at: String,
    /// Name of the embedding model; the generator in this file writes `"EmbeddingGemma300M"`.
    pub model: String,
    /// `chunk_size` from the chunker configuration used for this run.
    pub chunk_size: usize,
    /// `overlap_size` from the chunker configuration used for this run.
    pub overlap_size: usize,
    /// Number of tracked files selected for processing (after exclusions).
    pub total_files: usize,
    /// Number of entries in `chunks`.
    pub total_chunks: usize,
    /// Every embedded chunk, in the order they were staged.
    pub chunks: Vec<EmbeddedChunk>,
}
59
/// Options for JSON database generation
pub struct JsonDatabaseOptions {
    /// Directory in which `git ls-files` is run; listed paths are resolved against it.
    pub dir: PathBuf,
    /// Destination path of the generated JSON database.
    pub output_file_path: PathBuf,
    /// File extensions to skip, written with the leading dot (e.g. `.png`).
    /// A file without an extension is compared as the empty string.
    pub file_type_exclusions: HashSet<String>,
    /// Glob-style patterns; a tracked file whose relative path matches any of them is skipped.
    pub file_exclusions: Vec<String>,
    /// When true, progress and `[perf]` lines go to stdout and per-file errors to stderr.
    pub verbose: bool,
    /// Configuration handed to the text chunker.
    pub chunker_config: ChunkerConfig,
    /// Maximum number of files to process concurrently
    pub max_concurrent_files: usize,
    /// Number of parallel embedding workers (each maintains its own model instance)
    pub embedding_pool_size: usize,
    /// Optional batch size hint passed to the embedding backend
    pub embedding_batch_size: Option<usize>,
}
75
76impl Default for JsonDatabaseOptions {
77    fn default() -> Self {
78        // Choose a conservative default worker pool size based on available CPU cores,
79        // but cap to avoid excessive memory usage from multiple model instances.
80        let cpu_count = std::thread::available_parallelism()
81            .map(|n| n.get())
82            .unwrap_or(4);
83        let default_pool = cpu_count.min(4).max(1);
84
85        Self {
86            dir: PathBuf::from("."),
87            output_file_path: PathBuf::from("embeddings.json"),
88            file_type_exclusions: Default::default(),
89            file_exclusions: Default::default(),
90            verbose: true,
91            chunker_config: ChunkerConfig::default(),
92            max_concurrent_files: 4,
93            embedding_pool_size: default_pool,
94            embedding_batch_size: None,
95        }
96    }
97}
98
/// Generator for creating JSON database with embeddings
pub struct JsonDatabaseGenerator {
    /// Run configuration supplied by the caller.
    options: JsonDatabaseOptions,
    /// Pool of embedding workers created when the generator is constructed.
    embeddings_pool: EmbeddingPool,
}
104
105impl JsonDatabaseGenerator {
106    /// Creates a new JSON database generator
107    pub fn new(options: JsonDatabaseOptions) -> Result<Self> {
108        // Build a pool of embedding workers that each own their model instance.
109        // Workers live on dedicated threads and communicate via channels — no mutex around the model.
110        let embeddings_pool = EmbeddingPool::new(options.embedding_pool_size)?;
111
112        Ok(Self {
113            options,
114            embeddings_pool,
115        })
116    }
117
118    /// Gets tracked files from git
119    async fn get_tracked_files(&self) -> Result<Vec<String>> {
120        self.get_tracked_files_internal().await
121    }
122
123    async fn get_tracked_files_internal(&self) -> Result<Vec<String>> {
124        // Run git ls-files
125        let output = Command::new("git")
126            .arg("ls-files")
127            .current_dir(&self.options.dir)
128            .output()?;
129
130        if !output.status.success() {
131            return Err(anyhow::anyhow!("git ls-files failed"));
132        }
133
134        let output_str = String::from_utf8(output.stdout)?;
135        let tracked_files: Vec<String> = output_str
136            .lines()
137            .filter(|line| !line.trim().is_empty())
138            .map(|s| s.to_string())
139            .collect();
140
141        if self.options.verbose {
142            println!("Total tracked files: {}", tracked_files.len());
143        }
144
145        let total_files = tracked_files.len();
146
147        // Filter by exclusions
148        let filtered_files = tracked_files
149            .into_iter()
150            .filter(|file| {
151                let path = Path::new(file);
152                let ext = path
153                    .extension()
154                    .and_then(|e| e.to_str())
155                    .map(|e| format!(".{}", e))
156                    .unwrap_or_default();
157
158                // Check if file type is excluded
159                if self.options.file_type_exclusions.contains(&ext) {
160                    return false;
161                }
162
163                // Check if file matches exclusion patterns
164                !self.matches_exclusion_patterns(file)
165            })
166            .collect::<Vec<_>>();
167
168        if self.options.verbose {
169            println!("Excluded files: {}", total_files - filtered_files.len());
170            println!(
171                "Files to process for embeddings: {}",
172                filtered_files.len()
173            );
174        }
175
176        Ok(filtered_files)
177    }
178
179    fn matches_exclusion_patterns(&self, file: &str) -> bool {
180        for pattern in &self.options.file_exclusions {
181            if self.glob_match(pattern, file) {
182                return true;
183            }
184        }
185        false
186    }
187
188    fn glob_match(&self, pattern: &str, path: &str) -> bool {
189        use regex::Regex;
190        let pattern = pattern
191            .replace("**", ".*")
192            .replace("*", "[^/]*")
193            .replace("?", "[^/]");
194        let pattern = format!("^{}$", pattern);
195
196        if let Ok(re) = Regex::new(&pattern) {
197            re.is_match(path)
198        } else {
199            false
200        }
201    }
202
203    /// Generates the JSON database with embeddings and writes it to disk.
204    pub async fn generate_database(&self) -> Result<JsonDatabaseResult> {
205        let overall_start = Instant::now();
206        let tracked_files = self.get_tracked_files().await?;
207
208        if self.options.verbose {
209            println!("Generating embeddings for {} files", tracked_files.len());
210            println!("Processing with max {} concurrent files", self.options.max_concurrent_files);
211        }
212
213        // Create a semaphore to limit concurrent file processing
214        let semaphore = Arc::new(Semaphore::new(self.options.max_concurrent_files));
215
216        // Stage chunks from files concurrently (no embedding yet)
217        let stage_start = Instant::now();
218        let mut tasks = Vec::new();
219        for (file_idx, file) in tracked_files.iter().enumerate() {
220            let absolute_path = self.options.dir.join(file);
221            let file = file.clone();
222            let semaphore = semaphore.clone();
223            let chunker_config = self.options.chunker_config.clone();
224            let verbose = self.options.verbose;
225            let total_files = tracked_files.len();
226
227            let task = tokio::spawn(async move {
228                // Acquire semaphore permit
229                let _permit = semaphore.acquire().await.unwrap();
230
231                if verbose {
232                    println!("Processing file {}/{}: {}", file_idx + 1, total_files, file);
233                }
234
235                match Self::process_file_stage_chunks(&absolute_path, &file, &chunker_config, verbose).await {
236                    Ok(chunks) => Ok(chunks),
237                    Err(e) => {
238                        if verbose {
239                            eprintln!("Error processing file {}: {}", file, e);
240                        }
241                        Err(e)
242                    }
243                }
244            });
245
246            tasks.push(task);
247        }
248
249        // Collect all pending chunks in stable order of file tasks finishing; order within file preserved by processing
250        let mut pending_chunks: Vec<PendingChunk> = Vec::new();
251        for task in tasks {
252            match task.await {
253                Ok(Ok(mut chunks)) => {
254                    pending_chunks.append(&mut chunks);
255                }
256                Ok(Err(_)) => {
257                    // Error already logged in task
258                }
259                Err(e) => {
260                    if self.options.verbose {
261                        eprintln!("Task join error: {}", e);
262                    }
263                }
264            }
265        }
266
267        let stage_elapsed = stage_start.elapsed();
268        let total_chunks_count = pending_chunks.len();
269        let staged_bytes: usize = pending_chunks.iter().map(|c| c.content.len()).sum();
270
271        if self.options.verbose {
272            let secs = stage_elapsed.as_secs_f64().max(1e-9);
273            let chunks_per_sec = total_chunks_count as f64 / secs;
274            let mb = staged_bytes as f64 / (1024.0 * 1024.0);
275            println!(
276                "[perf] Staging: files={}, chunks={}, bytes={:.2} MiB, time={:.3}s, throughput={:.1} chunks/s",
277                tracked_files.len(), total_chunks_count, mb, stage_elapsed.as_secs_f64(), chunks_per_sec
278            );
279        }
280
281        if total_chunks_count == 0 {
282            if self.options.verbose {
283                println!("No chunks produced; writing empty database.");
284            }
285            let database = EmbeddingsDatabase {
286                version: "1.0".to_string(),
287                generated_at: Utc::now().to_rfc3339(),
288                model: "EmbeddingGemma300M".to_string(),
289                chunk_size: self.options.chunker_config.chunk_size,
290                overlap_size: self.options.chunker_config.overlap_size,
291                total_files: tracked_files.len(),
292                total_chunks: 0,
293                chunks: vec![],
294            };
295            let json = serde_json::to_string_pretty(&database)?;
296            fs::write(&self.options.output_file_path, json).await?;
297            return Ok(JsonDatabaseResult { success: true, total_files: tracked_files.len(), total_chunks: 0 });
298        }
299
300        if self.options.verbose {
301            println!("Staged {} chunks; generating embeddings in global batches...", total_chunks_count);
302        }
303
304        // Build documents list
305        let documents: Vec<String> = pending_chunks.iter().map(|pc| pc.content.clone()).collect();
306
307        // Perform global batched embedding across the pool
308        let embed_start = Instant::now();
309        let backend_batch_size = self.options.embedding_batch_size;
310        let per_job_batch = 2048usize; // cross-file batch size per worker job
311        if self.options.verbose {
312            println!(
313                "[perf] Embedding config: pool_size={}, per_job_batch={}, backend_batch_size={:?}",
314                self.options.embedding_pool_size, per_job_batch, backend_batch_size
315            );
316        }
317        let embeddings = self
318            .embeddings_pool
319            .embed_many_ordered(documents, Some(per_job_batch), backend_batch_size)
320            .await?;
321        let embed_elapsed = embed_start.elapsed();
322        if self.options.verbose {
323            let secs = embed_elapsed.as_secs_f64().max(1e-9);
324            let chunks_per_sec = total_chunks_count as f64 / secs;
325            println!(
326                "[perf] Embedding: chunks={}, time={:.3}s, throughput={:.1} chunks/s",
327                total_chunks_count, embed_elapsed.as_secs_f64(), chunks_per_sec
328            );
329        }
330
331        // Zip back into embedded chunks
332        let mut all_chunks: Vec<EmbeddedChunk> = Vec::with_capacity(total_chunks_count);
333        for (i, pending) in pending_chunks.into_iter().enumerate() {
334            let embedding = embeddings.get(i)
335                .cloned()
336                .ok_or_else(|| anyhow::anyhow!("missing embedding for chunk {}", i))?;
337            all_chunks.push(EmbeddedChunk {
338                file_path: pending.file_path,
339                content: pending.content,
340                embedding,
341                metadata: pending.metadata,
342            });
343        }
344
345        if self.options.verbose {
346            println!("Total chunks generated: {}", all_chunks.len());
347        }
348
349        let database = EmbeddingsDatabase {
350            version: "1.0".to_string(),
351            generated_at: Utc::now().to_rfc3339(),
352            model: "EmbeddingGemma300M".to_string(),
353            chunk_size: self.options.chunker_config.chunk_size,
354            overlap_size: self.options.chunker_config.overlap_size,
355            total_files: tracked_files.len(),
356            total_chunks: all_chunks.len(),
357            chunks: all_chunks,
358        };
359
360        // Write to JSON file
361        let write_start = Instant::now();
362        let json = serde_json::to_string_pretty(&database)?;
363        fs::write(&self.options.output_file_path, json).await?;
364        let write_elapsed = write_start.elapsed();
365
366        if self.options.verbose {
367            println!(
368                "JSON database created at {}",
369                self.options.output_file_path.display()
370            );
371            let total_elapsed = overall_start.elapsed();
372            let stage = stage_elapsed.as_secs_f64();
373            let embed = embed_elapsed.as_secs_f64();
374            let write = write_elapsed.as_secs_f64();
375            let total = total_elapsed.as_secs_f64();
376            println!(
377                "[perf] Totals: time={:.3}s (stage={:.3}s, embed={:.3}s, write={:.3}s)",
378                total, stage, embed, write
379            );
380            if total > 0.0 {
381                println!(
382                    "[perf] Breakdown: stage={:.0}%, embed={:.0}%, write={:.0}%",
383                    (stage / total * 100.0).round(),
384                    (embed / total * 100.0).round(),
385                    (write / total * 100.0).round()
386                );
387            }
388        }
389
390        Ok(JsonDatabaseResult {
391            success: true,
392            total_files: tracked_files.len(),
393            total_chunks: database.total_chunks,
394        })
395    }
396
397    /// Processes a single file by chunking, cleaning, and generating embeddings.
398    async fn process_file_stage_chunks(
399        file_path: &Path,
400        relative_path: &str,
401        chunker_config: &ChunkerConfig,
402        verbose: bool,
403    ) -> Result<Vec<PendingChunk>> {
404        // Read file content
405        let content = fs::read_to_string(file_path).await?;
406        let content = clean_and_redact(&content);
407
408        if content.trim().is_empty() { return Ok(vec![]); }
409
410        // Get file metadata
411        let metadata = fs::metadata(file_path).await?;
412        let file_size = metadata.len();
413
414        let last_modified = metadata
415            .modified()
416            .ok()
417            .and_then(|time| {
418                let datetime: DateTime<Utc> = time.into();
419                Some(datetime.to_rfc3339())
420            });
421
422        // Chunk the file content
423        let text_chunks = chunk_text(&content, chunker_config);
424        let total_chunks = text_chunks.len();
425
426        if text_chunks.is_empty() { return Ok(vec![]); }
427
428        if verbose { println!("  - Staged {} chunks", total_chunks); }
429
430        // Build pending chunks (no embeddings yet)
431        let pending: Vec<PendingChunk> = text_chunks
432            .into_iter()
433            .map(|text_chunk| PendingChunk {
434                file_path: relative_path.to_string(),
435                content: text_chunk.content,
436                metadata: ChunkMetadata {
437                    chunk_index: text_chunk.chunk_index,
438                    total_chunks,
439                    file_size,
440                    last_modified: last_modified.clone(),
441                    start_index: text_chunk.start_index,
442                    end_index: text_chunk.end_index,
443                },
444            })
445            .collect();
446
447        Ok(pending)
448    }
449}
450
451// ================= Embedding worker pool (no global mutex) =================
452
/// One unit of work sent to an embedding worker.
struct EmbeddingJob {
    /// Texts to embed, in the order the caller expects the vectors back.
    texts: Vec<String>,
    /// Optional batch size forwarded to the backend's `generate_embeddings`.
    batch_size: Option<usize>,
    /// One-shot channel on which the worker sends the result (or error) for this job.
    resp: oneshot::Sender<Result<Vec<Vec<f32>>>>,
}
458
/// Handle to the embedding worker pool; clones share the same
/// `EmbeddingPoolInner` through the `Arc`.
#[derive(Clone)]
struct EmbeddingPool(Arc<EmbeddingPoolInner>);
461
/// State shared by all handles to the pool.
struct EmbeddingPoolInner {
    /// Job queues, one per worker thread; entry `i` feeds worker `i`.
    senders: Vec<mpsc::Sender<EmbeddingJob>>, // per-worker input queues
    /// Counter used to pick the next worker in round-robin fashion.
    next: AtomicUsize,
}
466
467impl EmbeddingPool {
468    fn new(pool_size: usize) -> Result<Self> {
469        let size = pool_size.max(1);
470        let mut senders = Vec::with_capacity(size);
471        let mut readiness_rxs = Vec::with_capacity(size);
472
473        for worker_id in 0..size {
474            // Increase queue capacity to reduce backpressure causing transient send failures.
475            let (tx, mut rx) = mpsc::channel::<EmbeddingJob>(32);
476            // One-shot readiness signal from worker -> pool (std mpsc so we can recv_timeout)
477            let (ready_tx, ready_rx) = std_mpsc::channel::<Result<()>>();
478            // Spawn a dedicated OS thread for the worker so heavy compute doesn't block the async runtime.
479            std::thread::spawn(move || {
480                // Initialize the model inside the worker thread.
481                let mut generator = match EmbeddingsGenerator::new() {
482                    Ok(g) => {
483                        // Signal readiness to the pool
484                        let _ = ready_tx.send(Ok(()));
485                        g
486                    }
487                    Err(e) => {
488                        // Signal initialization failure to the pool and exit
489                        let _ = ready_tx.send(Err(anyhow::anyhow!(format!(
490                            "embedding worker {} init failed: {}",
491                            worker_id, e
492                        ))));
493                        return;
494                    }
495                };
496
497                // Process jobs synchronously on this thread
498                while let Some(job) = rx.blocking_recv() {
499                    // Convert owned strings to &str slice for the backend
500                    let texts_refs: Vec<&str> = job.texts.iter().map(|s| s.as_str()).collect();
501                    // Catch panics inside the worker so callers receive a proper error instead of a dropped channel.
502                    let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
503                        generator
504                            .generate_embeddings(texts_refs, job.batch_size)
505                    }))
506                    .map_err(|_| anyhow::anyhow!("embedding worker {} panicked during generate", worker_id))
507                    .and_then(|res| res.map_err(|e| anyhow::anyhow!(e)));
508
509                    let _ = job.resp.send(result);
510                }
511            });
512
513            senders.push(tx);
514            readiness_rxs.push(ready_rx);
515        }
516
517        // Await readiness for all workers with a timeout so we don't build a broken pool
518        let init_timeout_secs: u64 = std::env::var("TOAK_EMBED_INIT_TIMEOUT_SECS")
519            .ok()
520            .and_then(|s| s.parse().ok())
521            .unwrap_or(20);
522        let start_wait = Instant::now();
523        for (idx, rx) in readiness_rxs.into_iter().enumerate() {
524            match rx.recv_timeout(std::time::Duration::from_secs(init_timeout_secs)) {
525                Ok(Ok(())) => { /* ready */ }
526                Ok(Err(e)) => {
527                    return Err(anyhow::anyhow!(format!(
528                        "embedding pool init failed: worker {} not ready: {}",
529                        idx, e
530                    )));
531                }
532                Err(_) => {
533                    return Err(anyhow::anyhow!(format!(
534                        "embedding pool init timed out after {}s waiting for worker {}",
535                        init_timeout_secs, idx
536                    )));
537                }
538            }
539        }
540        let _elapsed = start_wait.elapsed();
541
542        Ok(Self(Arc::new(EmbeddingPoolInner {
543            senders,
544            next: AtomicUsize::new(0),
545        })))
546    }
547
548    async fn embed(&self, texts: Vec<String>, batch_size: Option<usize>) -> Result<Vec<Vec<f32>>> {
549        let inner = &self.0;
550        let len = inner.senders.len();
551        let idx = inner.next.fetch_add(1, Ordering::Relaxed) % len;
552        let (resp_tx, resp_rx) = oneshot::channel();
553        let job = EmbeddingJob {
554            texts,
555            batch_size,
556            resp: resp_tx,
557        };
558        inner
559            .senders[idx]
560            .send(job)
561            .await
562            .map_err(|e| anyhow::anyhow!(
563                "failed to send embedding job: {}. hint: worker may have failed to initialize; try setting ORT_DISABLE_COREML=1 to force CPU or check startup logs.",
564                e
565            ))?;
566
567        // Optional timeout to avoid hanging forever if a worker wedges.
568        let timeout_secs: u64 = std::env::var("TOAK_EMBED_TIMEOUT_SECS")
569            .ok()
570            .and_then(|s| s.parse().ok())
571            .unwrap_or(120);
572
573        match tokio::time::timeout(std::time::Duration::from_secs(timeout_secs), resp_rx).await {
574            Ok(Ok(res)) => res,
575            Ok(Err(e)) => Err(anyhow::anyhow!("embedding worker dropped: {}", e)),
576            Err(_) => Err(anyhow::anyhow!(
577                "embedding job timed out after {}s; worker may be stalled",
578                timeout_secs
579            )),
580        }
581    }
582
583    /// Embed a large set of texts by slicing into per-job batches and
584    /// dispatching them across workers in parallel. Preserves the global order.
585    async fn embed_many_ordered(
586        &self,
587        texts: Vec<String>,
588        per_job_batch: Option<usize>,
589        batch_size: Option<usize>,
590    ) -> Result<Vec<Vec<f32>>> {
591        let total = texts.len();
592        if total == 0 { return Ok(Vec::new()); }
593
594        let job_batch = per_job_batch.unwrap_or(2048).max(1);
595        let mut starts = Vec::new();
596        let mut futures = Vec::new();
597
598        let inner = &self.0;
599        let workers = inner.senders.len().max(1);
600        let mut rr = inner.next.fetch_add(0, Ordering::Relaxed) % workers; // starting point
601
602        // Build jobs and submit round-robin
603        let mut i = 0;
604        while i < total {
605            let end = (i + job_batch).min(total);
606            let slice: Vec<String> = texts[i..end].to_vec();
607            let worker_idx = rr % workers;
608            rr = rr.wrapping_add(1);
609            // Send job synchronously so we surface send errors immediately.
610            let (resp_tx, resp_rx) = oneshot::channel();
611            let job = EmbeddingJob { texts: slice, batch_size, resp: resp_tx };
612            let sender = inner.senders[worker_idx].clone();
613            sender
614                .send(job)
615                .await
616                .map_err(|e| anyhow::anyhow!(
617                    "failed to send embedding job to worker {}: {}. hint: worker may have failed to initialize; try ORT_DISABLE_COREML=1 or check initialization logs.",
618                    worker_idx, e
619                ))?;
620            let rx = resp_rx;
621            starts.push(i);
622            futures.push(rx);
623            i = end;
624        }
625
626        let mut out: Vec<Vec<f32>> = (0..total).map(|_| Vec::new()).collect();
627
628        // Await all batches and place into the output vector
629        // Await all batches with a timeout to avoid indefinite hangs
630        let timeout_secs: u64 = std::env::var("TOAK_EMBED_TIMEOUT_SECS")
631            .ok()
632            .and_then(|s| s.parse().ok())
633            .unwrap_or(120);
634
635        for (start, rx) in starts.into_iter().zip(futures.into_iter()) {
636            let batch = match tokio::time::timeout(std::time::Duration::from_secs(timeout_secs), rx).await {
637                Ok(Ok(res)) => res?,
638                Ok(Err(e)) => return Err(anyhow::anyhow!("embedding worker dropped: {}", e)),
639                Err(_) => return Err(anyhow::anyhow!(
640                    "embedding batch timed out after {}s; worker may be stalled",
641                    timeout_secs
642                )),
643            };
644            for (offset, emb) in batch.into_iter().enumerate() {
645                out[start + offset] = emb;
646            }
647        }
648
649        Ok(out)
650    }
651}
652
/// Result returned after a generation run.
#[derive(Debug, Clone)]
pub struct JsonDatabaseResult {
    /// Set to `true` on every result the generator in this file returns;
    /// failures are reported through `Err` instead.
    pub success: bool,
    /// Number of tracked files selected for processing (after exclusions).
    pub total_files: usize,
    /// Number of chunks written to the database.
    pub total_chunks: usize,
}