// vtt_rs/transcription.rs

//! Transcription service and event types.
//!
//! This module contains the main [`TranscriptionService`] that coordinates
//! audio capture and transcription, as well as the [`TranscriptionEvent`] enum
//! for receiving results.

use crate::audio::{encode_wav, start_audio_capture};
use crate::config::Config;
use anyhow::{Context, Result};
use cpal::Stream;
use reqwest::Client;
use serde_json::Value;
use std::sync::Arc;
use tokio::fs::OpenOptions;
use tokio::io::AsyncWriteExt;
use tokio::sync::mpsc::{UnboundedReceiver, unbounded_channel};

#[path = "whisper-microphone/mod.rs"]
mod on_device;

/// Events emitted by the transcription service.
///
/// These events are sent through the channel returned by [`TranscriptionService::start`]
/// and represent either successful transcriptions or errors that occurred during processing.
///
/// # Examples
///
/// ```no_run
/// use vtt_rs::{TranscriptionService, TranscriptionEvent, Config};
///
/// # #[tokio::main]
/// # async fn main() -> anyhow::Result<()> {
/// # let api_key = "test".to_string();
/// # let config = Config::default();
/// let mut service = TranscriptionService::new(config, api_key)?;
/// let (mut receiver, _stream) = service.start().await?;
///
/// while let Some(event) = receiver.recv().await {
///     match event {
///         TranscriptionEvent::Transcription { chunk_id, text } => {
///             if !text.is_empty() {
///                 println!("[{}] {}", chunk_id, text);
///             }
///         }
///         TranscriptionEvent::Error { chunk_id, error } => {
///             eprintln!("[{}] Error: {}", chunk_id, error);
///         }
///     }
/// }
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone)]
pub enum TranscriptionEvent {
    /// A successful transcription.
    ///
    /// Contains the chunk ID and the transcribed text. An empty string
    /// indicates that silence was detected in the audio chunk.
    Transcription {
        /// The chunk ID (incremental counter starting from 0)
        chunk_id: usize,
        /// The transcribed text (empty string for silence)
        text: String,
    },
    /// An error occurred during transcription.
    ///
    /// This can happen due to network failures, API errors, or audio
    /// processing issues. The chunk ID helps identify which audio segment failed.
    Error {
        /// The chunk ID that failed
        chunk_id: usize,
        /// The error message describing what went wrong
        error: String,
    },
}

/// The main transcription service.
///
/// This service manages the entire transcription pipeline: capturing audio from
/// the system's default input device, chunking it into segments, sending those
/// segments to an OpenAI-compatible transcription API, and emitting events with
/// the results.
///
/// # Examples
///
/// ## Basic usage
///
/// ```no_run
/// use vtt_rs::{TranscriptionService, Config};
///
/// # #[tokio::main]
/// # async fn main() -> anyhow::Result<()> {
/// let config = Config::default();
/// let api_key = std::env::var("OPENAI_API_KEY")?;
///
/// let mut service = TranscriptionService::new(config, api_key)?;
/// let (mut receiver, _stream) = service.start().await?;
///
/// // Process events...
/// # Ok(())
/// # }
/// ```
///
/// ## With custom configuration
///
/// ```no_run
/// use vtt_rs::{TranscriptionService, Config};
///
/// # #[tokio::main]
/// # async fn main() -> anyhow::Result<()> {
/// let config = Config {
///     chunk_duration_secs: 3,
///     model: "whisper-1".to_string(),
///     endpoint: "https://api.openai.com/v1/audio/transcriptions".to_string(),
///     out_file: None,
/// };
///
/// let api_key = std::env::var("OPENAI_API_KEY")?;
/// let mut service = TranscriptionService::new(config, api_key)?;
/// # Ok(())
/// # }
/// ```
pub struct TranscriptionService {
    config: Config,
    api_key: Option<String>,
}

impl TranscriptionService {
    /// Creates a new transcription service with the specified configuration and API key.
    ///
    /// This doesn't start audio capture yet; call [`start`](Self::start) to begin
    /// transcription.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use vtt_rs::{TranscriptionService, Config};
    ///
    /// # fn main() -> anyhow::Result<()> {
    /// let config = Config::default();
    /// let api_key = std::env::var("OPENAI_API_KEY")?;
    /// let service = TranscriptionService::new(config, api_key)?;
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// # Errors
    ///
    /// Currently always succeeds, but returns [`Result`] for future extensibility.
    pub fn new(config: Config, api_key: String) -> Result<Self> {
        Ok(Self {
            config,
            api_key: Some(api_key),
        })
    }

    /// Creates a new transcription service for remote APIs without an API key.
    ///
    /// This is useful when connecting to a local OpenAI-compatible server that
    /// does not require authentication (e.g. a self-hosted MLX server running a
    /// Parakeet model).
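    ///
    /// # Examples
    ///
    /// A minimal sketch: the endpoint URL and model name below are placeholders
    /// for whatever your local server actually exposes.
    ///
    /// ```no_run
    /// use vtt_rs::{TranscriptionService, Config};
    ///
    /// # #[tokio::main]
    /// # async fn main() -> anyhow::Result<()> {
    /// // Placeholder endpoint and model; substitute your server's values.
    /// let config = Config {
    ///     chunk_duration_secs: 5,
    ///     model: "parakeet".to_string(),
    ///     endpoint: "http://localhost:8000/v1/audio/transcriptions".to_string(),
    ///     out_file: None,
    /// };
    ///
    /// let mut service = TranscriptionService::new_no_api(config)?;
    /// let (mut receiver, _stream) = service.start().await?;
    ///
    /// // Process events...
    /// # Ok(())
    /// # }
    /// ```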
    pub fn new_no_api(config: Config) -> Result<Self> {
        Ok(Self { config, api_key: None })
    }

    /// Creates a transcription service configured for on-device inference.
    ///
    /// No API key is required; when the configuration provides on-device
    /// settings, [`start`](Self::start) runs transcription locally through the
    /// bundled `on_device` backend instead of calling a remote API.
    pub fn new_on_device(config: Config) -> Result<Self> {
        Ok(Self { config, api_key: None })
    }

    /// Starts the transcription service and returns a receiver for events.
    ///
    /// This method begins capturing audio from the default input device and spawns
    /// background tasks to process and transcribe audio chunks. Events are delivered
    /// through the returned [`UnboundedReceiver<TranscriptionEvent>`].
    ///
    /// The returned [`Stream`] must be kept alive for audio capture to continue;
    /// dropping it stops capture and, with it, transcription. Dropping the
    /// receiver only discards subsequent events.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use vtt_rs::{TranscriptionService, Config, TranscriptionEvent};
    ///
    /// # #[tokio::main]
    /// # async fn main() -> anyhow::Result<()> {
    /// let config = Config::default();
    /// let api_key = std::env::var("OPENAI_API_KEY")?;
    ///
    /// let mut service = TranscriptionService::new(config, api_key)?;
    /// let (mut receiver, _stream) = service.start().await?;
    ///
    /// // Process transcription events
    /// while let Some(event) = receiver.recv().await {
    ///     match event {
    ///         TranscriptionEvent::Transcription { chunk_id, text } => {
    ///             println!("Chunk {}: {}", chunk_id, text);
    ///         }
    ///         TranscriptionEvent::Error { chunk_id, error } => {
    ///             eprintln!("Error in chunk {}: {}", chunk_id, error);
    ///         }
    ///     }
    /// }
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - No default audio input device is available
    /// - The audio device cannot be configured
    /// - The configured output file cannot be opened or created
    ///
    /// # Panics
    ///
    /// May panic if the audio system is not properly initialized (rare).
    pub async fn start(&mut self) -> Result<(UnboundedReceiver<TranscriptionEvent>, Stream)> {
        let (event_tx, event_rx) = unbounded_channel::<TranscriptionEvent>();
        let transcript_sink = if let Some(path) = &self.config.out_file {
            let file = OpenOptions::new()
                .create(true)
                .append(true)
                .open(path)
                .await
                .with_context(|| format!("opening output file {}", path.display()))?;
            Some(Arc::new(tokio::sync::Mutex::new(file)))
        } else {
            None
        };

        if let Some(on_device_cfg) = self.config.on_device_config().cloned() {
            let handle = tokio::runtime::Handle::current();
            let stream = on_device::start_on_device_transcription(
                on_device_cfg,
                event_tx.clone(),
                transcript_sink.clone(),
                handle,
            )?;
            return Ok((event_rx, stream));
        }

        let (sample_tx, mut sample_rx) = unbounded_channel::<Vec<f32>>();

        // Start audio capture for remote transcription
        let (_stream, audio_config) = start_audio_capture(sample_tx)?;

        let client = Client::new();
        let chunk_duration_secs = self.config.chunk_duration_secs.max(1);
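        // One chunk covers `chunk_duration_secs` of audio:
        // sample_rate * seconds * channels interleaved f32 samples.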
        let samples_per_chunk = (audio_config.sample_rate as usize)
            .saturating_mul(chunk_duration_secs)
            .saturating_mul(audio_config.channels as usize);

        let model = Arc::new(self.config.model.clone());
        let endpoint = Arc::new(self.config.endpoint.clone());
        // API key is optional to support local OpenAI-compatible servers.
        let api_key = self.api_key.clone();

        // Spawn the processing task
        tokio::spawn(async move {
            let mut buffer = Vec::with_capacity(samples_per_chunk * 2);
            let mut chunk_id = 0usize;

            while let Some(data) = sample_rx.recv().await {
                buffer.extend(data);

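                // Transcribe every complete chunk currently in the buffer;
                // leftover samples carry over to the next iteration.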
                while buffer.len() >= samples_per_chunk {
                    let chunk_samples = buffer.drain(..samples_per_chunk).collect::<Vec<_>>();
                    let client = client.clone();
                    let api_key = api_key.clone();
                    let sample_rate = audio_config.sample_rate;
                    let channels = audio_config.channels;
                    let current_chunk = chunk_id;
                    let chunk_model = model.clone();
                    let chunk_endpoint = endpoint.clone();
                    let chunk_sink = transcript_sink.clone();
                    let event_sender = event_tx.clone();

                    tokio::spawn(async move {
                        match transcribe_chunk(
                            client,
                            api_key,
                            sample_rate,
                            channels,
                            chunk_samples,
                            current_chunk,
                            chunk_model,
                            chunk_endpoint,
                        )
                        .await
                        {
                            Ok(text) => {
                                // Send event
                                let _ = event_sender.send(TranscriptionEvent::Transcription {
                                    chunk_id: current_chunk,
                                    text: text.clone(),
                                });

                                // Write to file if configured
                                if let Some(writer) = chunk_sink {
                                    let record_text = if text.is_empty() {
                                        "<silence>"
                                    } else {
                                        text.as_str()
                                    };
                                    if let Err(err) =
                                        append_transcript(writer, current_chunk, record_text).await
                                    {
                                        let _ = event_sender.send(TranscriptionEvent::Error {
                                            chunk_id: current_chunk,
                                            error: format!("File write failed: {err}"),
                                        });
                                    }
                                }
                            }
                            Err(err) => {
                                let _ = event_sender.send(TranscriptionEvent::Error {
                                    chunk_id: current_chunk,
                                    error: err.to_string(),
                                });
                            }
                        }
                    });

                    chunk_id += 1;
                }
            }
        });

        Ok((event_rx, _stream))
    }
}

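/// Encodes one chunk of interleaved samples as WAV, posts it to the configured
/// OpenAI-compatible endpoint (attaching a bearer token only when an API key is
/// set), and returns the trimmed `text` field of the JSON response. A missing
/// `text` field yields an empty string.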
async fn transcribe_chunk(
    client: Client,
    api_key: Option<String>,
    sample_rate: u32,
    channels: u16,
    samples: Vec<f32>,
    chunk_id: usize,
    model: Arc<String>,
    endpoint: Arc<String>,
) -> Result<String> {
    let wav = encode_wav(&samples, sample_rate, channels)?;
    let part = reqwest::multipart::Part::bytes(wav)
        .file_name(format!("chunk-{chunk_id}.wav"))
        .mime_str("audio/wav")?;
    let form = reqwest::multipart::Form::new()
        .text("model", model.as_ref().clone())
        .part("file", part);

    let mut req = client.post(endpoint.as_str());
    if let Some(key) = api_key.as_ref() {
        req = req.bearer_auth(key);
    }
    let response = req
        .multipart(form)
        .send()
        .await?
        .error_for_status()?;

    let payload: Value = response.json().await?;
    let text = payload
        .get("text")
        .and_then(|v| v.as_str())
        .map(str::trim)
        .unwrap_or_default()
        .to_string();

    Ok(text)
}

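/// Appends a `Chunk <id>: <text>` line to the shared transcript file and
/// flushes the writer after each entry.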
pub(super) async fn append_transcript(
    writer: Arc<tokio::sync::Mutex<tokio::fs::File>>,
    chunk_id: usize,
    text: &str,
) -> Result<()> {
    let mut guard = writer.lock().await;
    let entry = format!("Chunk {chunk_id}: {text}\n");
    guard.write_all(entry.as_bytes()).await?;
    guard.flush().await?;
    Ok(())
}