toak_rs/
semantic_search.rs1use anyhow::{Context, Result};
7use serde::{Deserialize, Serialize};
8use std::path::Path;
9
10use crate::embeddings_generator::EmbeddingsGenerator;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct EmbeddingChunk {
15 pub file_path: String,
16 pub content: String,
17 pub embedding: Vec<f32>,
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct EmbeddingsDatabaseMetadata {
23 pub version: String,
24 pub generated_at: String,
25 pub model: String,
26 pub chunk_size: usize,
27 pub overlap_size: usize,
28 pub total_files: usize,
29 pub total_chunks: usize,
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct EmbeddingsDatabase {
35 pub version: String,
36 pub generated_at: String,
37 pub model: String,
38 pub chunk_size: usize,
39 pub overlap_size: usize,
40 pub total_files: usize,
41 pub total_chunks: usize,
42 pub chunks: Vec<EmbeddingChunk>,
43}
44
/// One ranked hit produced by [`SemanticSearch::search`].
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Path of the file the matching chunk came from.
    pub file_path: String,
    /// Text of the matching chunk.
    pub content: String,
    /// Cosine similarity between the query and the chunk; higher means closer.
    pub similarity: f32,
}
52
53pub struct SemanticSearch {
55 database: EmbeddingsDatabase,
56 generator: EmbeddingsGenerator,
57}
58
59impl SemanticSearch {
60 pub fn new<P: AsRef<Path>>(embeddings_path: P) -> Result<Self> {
62 let contents = std::fs::read_to_string(embeddings_path.as_ref())
63 .context("Failed to read embeddings file")?;
64
65 let database: EmbeddingsDatabase = serde_json::from_str(&contents)
66 .context("Failed to parse embeddings JSON")?;
67
68 let generator = EmbeddingsGenerator::new()
69 .context("Failed to initialize embeddings generator")?;
70
71 Ok(Self {
72 database,
73 generator,
74 })
75 }
76
77 pub fn metadata(&self) -> EmbeddingsDatabaseMetadata {
79 EmbeddingsDatabaseMetadata {
80 version: self.database.version.clone(),
81 generated_at: self.database.generated_at.clone(),
82 model: self.database.model.clone(),
83 chunk_size: self.database.chunk_size,
84 overlap_size: self.database.overlap_size,
85 total_files: self.database.total_files,
86 total_chunks: self.database.total_chunks,
87 }
88 }
89
90 pub fn search(&mut self, query: &str, top_n: usize) -> Result<Vec<SearchResult>> {
94 let query_embedding = self.generator.generate_embedding(query)
96 .context("Failed to generate query embedding")?;
97
98 let mut results: Vec<SearchResult> = self.database.chunks
100 .iter()
101 .map(|chunk| {
102 let similarity = cosine_similarity(&query_embedding, &chunk.embedding);
103 SearchResult {
104 file_path: chunk.file_path.clone(),
105 content: chunk.content.clone(),
106 similarity,
107 }
108 })
109 .collect();
110
111 results.sort_by(|a, b| b.similarity.partial_cmp(&a.similarity).unwrap_or(std::cmp::Ordering::Equal));
113
114 results.truncate(top_n);
116
117 Ok(results)
118 }
119
120 pub fn chunk_count(&self) -> usize {
122 self.database.chunks.len()
123 }
124}
125
/// Computes the cosine similarity of two vectors.
///
/// The result lies in `[-1.0, 1.0]`. Degenerate inputs yield `0.0` instead of
/// an error:
/// - vectors of different lengths;
/// - a zero-magnitude vector, which includes empty ones;
/// - inputs whose similarity is not finite, such as vectors containing NaN or
///   infinity.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Single pass, accumulated in f64: squaring large f32 components can
    // overflow to infinity (inf / inf = NaN), and long vectors lose precision
    // when summed in f32.
    let (dot, norm_a_sq, norm_b_sq) =
        a.iter()
            .zip(b)
            .fold((0.0_f64, 0.0_f64, 0.0_f64), |(dot, na, nb), (&x, &y)| {
                let (x, y) = (f64::from(x), f64::from(y));
                (dot + x * y, na + x * x, nb + y * y)
            });

    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        return 0.0;
    }

    let similarity = dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt());
    if !similarity.is_finite() {
        return 0.0;
    }

    // Rounding can push the ratio marginally outside its mathematical range.
    similarity.clamp(-1.0, 1.0) as f32
}
145
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons.
    const EPS: f32 = 1e-4;

    #[test]
    fn test_cosine_similarity_identical() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];
        let similarity = cosine_similarity(&a, &b);
        assert!((similarity - 1.0).abs() < EPS);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let similarity = cosine_similarity(&a, &b);
        assert!((similarity - 0.0).abs() < EPS);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        let similarity = cosine_similarity(&a, &b);
        assert!((similarity - (-1.0)).abs() < EPS);
    }

    #[test]
    fn test_cosine_similarity_different_lengths() {
        let a = vec![1.0, 2.0];
        let b = vec![1.0, 2.0, 3.0];
        let similarity = cosine_similarity(&a, &b);
        assert_eq!(similarity, 0.0);
    }

    /// Cosine similarity ignores magnitude: parallel vectors score 1.0.
    #[test]
    fn test_cosine_similarity_scale_invariant() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![2.0, 4.0, 6.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < EPS);
    }

    /// A zero vector has no direction, so the similarity is defined as 0.0.
    #[test]
    fn test_cosine_similarity_zero_vector() {
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 1.0]), 0.0);
        assert_eq!(cosine_similarity(&[1.0, 1.0], &[0.0, 0.0]), 0.0);
    }

    /// Empty vectors have zero magnitude and must not divide by zero.
    #[test]
    fn test_cosine_similarity_empty() {
        assert_eq!(cosine_similarity(&[], &[]), 0.0);
    }
}