diff --git a/Cargo.toml b/Cargo.toml index 2c191ef..07adc30 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ serde_json = "1.0" toml = "0.8" rmp-serde = "1" +deadpool = "0.10" cached = { version = "0.56.0", features = ["async"] } anyhow = "1.0" diff --git a/src/services/ws/stable/tts.rs b/src/services/ws/stable/tts.rs index f169cf4..579316a 100644 --- a/src/services/ws/stable/tts.rs +++ b/src/services/ws/stable/tts.rs @@ -122,84 +122,55 @@ impl TTSSession { } } -pub struct TTSSessionPool { - pub config: crate::config::TTSConfig, - pub workers: usize, - pub pool: tokio::sync::mpsc::UnboundedReceiver>, - pub tx: tokio::sync::mpsc::UnboundedSender>, +pub struct TTSManager { + config: crate::config::TTSConfig, } -impl TTSSessionPool { - pub fn new(config: crate::config::TTSConfig, workers: usize) -> Self { - let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); - TTSSessionPool { - config, - workers, - pool: rx, - tx, - } - } +impl deadpool::managed::Manager for TTSManager { + type Type = TTSSession; + type Error = anyhow::Error; - pub async fn create_session(&self) -> anyhow::Result { + async fn create(&self) -> Result { TTSSession::new_from_config(&self.config).await } - pub async fn run_session( - id: u128, - mut session: TTSSession, - tx: tokio::sync::mpsc::UnboundedSender>, - ) -> anyhow::Result<()> { - log::info!("{} starting TTS session worker", id); - loop { - let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); - tx.send(resp_tx) - .map_err(|e| anyhow::anyhow!("send session request error: {}", e))?; - - let (text, tts_resp_tx) = resp_rx - .await - .map_err(|e| anyhow::anyhow!("receive session request error: {}", e))?; - - log::info!("{} processing TTS request: {}", id, text); - - if let Err(e) = session.synthesize(&text, &tts_resp_tx).await { - log::error!("{} TTS synthesis error: {}", id, e); - } - } + async fn recycle( + &self, + _obj: &mut TTSSession, + _metrics: &deadpool::managed::Metrics, + ) -> deadpool::managed::RecycleResult { + Ok(()) } +} + +pub struct TTSSessionPool { + pool: deadpool::managed::Pool, +} - async fn get_req_tx(&mut self) -> anyhow::Result> { - let req_tx = self - .pool - .recv() - .await - .ok_or_else(|| anyhow::anyhow!("no available tts session"))?; - Ok(req_tx) +impl TTSSessionPool { + pub fn new(config: crate::config::TTSConfig, workers: usize) -> Self { + let manager = TTSManager { config }; + let pool = deadpool::managed::Pool::builder(manager) + .max_size(workers) + .build() + .expect("Failed to create TTS session pool"); + TTSSessionPool { pool } } pub async fn run_loop(&mut self, mut rx: TTSRequestRx) -> anyhow::Result<()> { - let mut sucess_workers = 0; - for i in 0..self.workers { - match self.create_session().await { - Ok(session) => { - tokio::spawn(Self::run_session(i as u128, session, self.tx.clone())); - sucess_workers += 1; + while let Some((text, tts_resp_tx)) = rx.recv().await { + match self.pool.get().await { + Ok(mut session) => { + tokio::spawn(async move { + log::info!("Processing TTS request: {}", text); + if let Err(e) = session.synthesize(&text, &tts_resp_tx).await { + log::error!("TTS synthesis error: {}", e); + } + }); } Err(e) => { - log::error!("create tts session[{i}] error: {}", e); - continue; + log::error!("Failed to get TTS session from pool: {}", e); } - }; - } - - if sucess_workers == 0 { - return Err(anyhow::anyhow!("no available tts session worker")); - } - - while let Some(tts_req) = rx.recv().await { - let req_tx = self.get_req_tx().await?; - - if let Err(e) = req_tx.send(tts_req) { - log::error!("send tts request to session error: {}", e.0); } } Ok(())