Skip to main content

toak_rs/
semantic_search.rs

1//! Semantic search functionality for querying embeddings databases.
2//!
3//! This module provides tools for performing semantic similarity searches
4//! against embeddings stored in JSON format.
5
6use anyhow::{Context, Result};
7use serde::{Deserialize, Serialize};
8use std::path::Path;
9
10use crate::embeddings_generator::EmbeddingsGenerator;
11
12/// Represents a chunk with its embedding from the embeddings database
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct EmbeddingChunk {
15    pub file_path: String,
16    pub content: String,
17    pub embedding: Vec<f32>,
18}
19
20/// Metadata about the embeddings database
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct EmbeddingsDatabaseMetadata {
23    pub version: String,
24    pub generated_at: String,
25    pub model: String,
26    pub chunk_size: usize,
27    pub overlap_size: usize,
28    pub total_files: usize,
29    pub total_chunks: usize,
30}
31
32/// The complete embeddings database structure
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct EmbeddingsDatabase {
35    pub version: String,
36    pub generated_at: String,
37    pub model: String,
38    pub chunk_size: usize,
39    pub overlap_size: usize,
40    pub total_files: usize,
41    pub total_chunks: usize,
42    pub chunks: Vec<EmbeddingChunk>,
43}
44
/// A single hit produced by a semantic search.
///
/// Holds owned copies of the matched chunk's path and text, so a result does
/// not borrow from the database it came from.
#[derive(Clone, Debug)]
pub struct SearchResult {
    /// `file_path` of the matched chunk.
    pub file_path: String,
    /// `content` of the matched chunk.
    pub content: String,
    /// Cosine similarity between the query embedding and the chunk embedding.
    pub similarity: f32,
}
52
53/// Semantic search engine for querying embeddings databases
54pub struct SemanticSearch {
55    database: EmbeddingsDatabase,
56    generator: EmbeddingsGenerator,
57}
58
59impl SemanticSearch {
60    /// Create a new semantic search instance by loading an embeddings database
61    pub fn new<P: AsRef<Path>>(embeddings_path: P) -> Result<Self> {
62        let contents = std::fs::read_to_string(embeddings_path.as_ref())
63            .context("Failed to read embeddings file")?;
64
65        let database: EmbeddingsDatabase = serde_json::from_str(&contents)
66            .context("Failed to parse embeddings JSON")?;
67
68        let generator = EmbeddingsGenerator::new()
69            .context("Failed to initialize embeddings generator")?;
70
71        Ok(Self {
72            database,
73            generator,
74        })
75    }
76
77    /// Get metadata about the loaded database
78    pub fn metadata(&self) -> EmbeddingsDatabaseMetadata {
79        EmbeddingsDatabaseMetadata {
80            version: self.database.version.clone(),
81            generated_at: self.database.generated_at.clone(),
82            model: self.database.model.clone(),
83            chunk_size: self.database.chunk_size,
84            overlap_size: self.database.overlap_size,
85            total_files: self.database.total_files,
86            total_chunks: self.database.total_chunks,
87        }
88    }
89
90    /// Perform a semantic search with the given query
91    ///
92    /// Returns the top N results ranked by cosine similarity
93    pub fn search(&mut self, query: &str, top_n: usize) -> Result<Vec<SearchResult>> {
94        // Generate embedding for the query
95        let query_embedding = self.generator.generate_embedding(query)
96            .context("Failed to generate query embedding")?;
97
98        // Calculate similarity scores for all chunks
99        let mut results: Vec<SearchResult> = self.database.chunks
100            .iter()
101            .map(|chunk| {
102                let similarity = cosine_similarity(&query_embedding, &chunk.embedding);
103                SearchResult {
104                    file_path: chunk.file_path.clone(),
105                    content: chunk.content.clone(),
106                    similarity,
107                }
108            })
109            .collect();
110
111        // Sort by similarity (descending)
112        results.sort_by(|a, b| b.similarity.partial_cmp(&a.similarity).unwrap_or(std::cmp::Ordering::Equal));
113
114        // Return top N results
115        results.truncate(top_n);
116
117        Ok(results)
118    }
119
120    /// Get the total number of chunks in the database
121    pub fn chunk_count(&self) -> usize {
122        self.database.chunks.len()
123    }
124}
125
/// Computes the cosine similarity between two vectors.
///
/// The result lies in `[-1, 1]`: `1` for vectors pointing the same way, `0`
/// for orthogonal vectors, and `-1` for opposite vectors.
///
/// Returns `0.0` when the slices differ in length or when either vector has
/// zero magnitude (which includes two empty slices). If an input contains
/// NaN or an infinity the result may be NaN.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Accumulate in f64: squaring an f32 can overflow to infinity or
    // underflow to zero, whereas the square of any finite f32 is
    // representable as an f64. One pass yields all three sums.
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, norm_a_sq, norm_b_sq), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, norm_a_sq + x * x, norm_b_sq + y * y)
        },
    );

    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        return 0.0;
    }

    // Take the roots separately so the product of the squared norms is never formed.
    let cosine = dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt());

    // Rounding can push the quotient marginally outside [-1, 1]; clamp so the
    // documented range holds. Narrowing to f32 is intentional: the value is
    // already within [-1, 1].
    cosine.clamp(-1.0, 1.0) as f32
}
145
#[cfg(test)]
mod tests {
    use super::*;

    /// Largest absolute error accepted when comparing similarity scores.
    const TOLERANCE: f32 = 0.0001;

    /// Asserts that `actual` lies strictly within `TOLERANCE` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < TOLERANCE);
    }

    /// A vector compared with itself has similarity 1.
    #[test]
    fn test_cosine_similarity_identical() {
        let v = [1.0, 2.0, 3.0];
        assert_close(cosine_similarity(&v, &v.clone()), 1.0);
    }

    /// Perpendicular unit vectors have similarity 0.
    #[test]
    fn test_cosine_similarity_orthogonal() {
        let x_axis = [1.0, 0.0];
        let y_axis = [0.0, 1.0];
        assert_close(cosine_similarity(&x_axis, &y_axis), 0.0);
    }

    /// Vectors pointing in opposite directions have similarity -1.
    #[test]
    fn test_cosine_similarity_opposite() {
        let forward = [1.0, 0.0];
        let backward = [-1.0, 0.0];
        assert_close(cosine_similarity(&forward, &backward), -1.0);
    }

    /// Slices of different lengths yield exactly 0.
    #[test]
    fn test_cosine_similarity_different_lengths() {
        let short = [1.0, 2.0];
        let long = [1.0, 2.0, 3.0];
        assert_eq!(cosine_similarity(&short, &long), 0.0);
    }
}
181}