Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions crates/forge_api/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ pub trait API: Sync + Send {
/// Executes a chat request and returns a stream of responses
async fn chat(&self, chat: ChatRequest) -> Result<MpscStream<Result<ChatResponse>>>;

/// Retries the last non-droppable user message in the conversation.
async fn retry(
&self,
conversation_id: ConversationId,
) -> Result<MpscStream<Result<ChatResponse>>>;

/// Commits changes with an AI-generated commit message
async fn commit(
&self,
Expand Down
23 changes: 23 additions & 0 deletions crates/forge_api/src/forge_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,29 @@ impl<
self.app().chat(agent_id, chat).await
}

async fn retry(
&self,
conversation_id: ConversationId,
) -> anyhow::Result<MpscStream<Result<ChatResponse, anyhow::Error>>> {
let conversation = self
.services
.find_conversation(&conversation_id)
.await?
.ok_or_else(|| forge_domain::Error::ConversationNotFound(conversation_id))?;

let user_message = conversation.last_user_message()
.ok_or_else(|| anyhow::anyhow!("No user message found to retry"))?;

let event = user_message
.raw_content
.clone()
.map(Event::new)
.unwrap_or_else(|| Event::new(user_message.content.clone()));

let chat = ChatRequest::new(event, conversation_id);
self.chat(chat).await
}

async fn upsert_conversation(&self, conversation: Conversation) -> anyhow::Result<()> {
self.services.upsert_conversation(conversation).await
}
Expand Down
29 changes: 29 additions & 0 deletions crates/forge_domain/src/conversation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,20 @@ impl Conversation {
.unwrap_or_default()
}

/// Returns the last non-droppable user message in the conversation.
pub fn last_user_message(&self) -> Option<&crate::TextMessage> {
self.context.as_ref()?.messages.iter().rev().find_map(|message| {
match &**message {
crate::ContextMessage::Text(text)
if text.role == crate::Role::User && !text.droppable =>
{
Some(text)
}
_ => None,
}
})
}

/// Returns the total token usage across all messages in the conversation.
///
/// This is a convenience method that aggregates usage from the context,
Expand Down Expand Up @@ -281,6 +295,21 @@ mod tests {
assert_eq!(actual, vec![agent_conv_id]);
}

#[test]
fn test_last_user_message_returns_latest_non_droppable_message() {
let context = Context::default()
.add_message(ContextMessage::user("First", None))
.add_message(ContextMessage::assistant("Reply", None, None, None))
.add_message(ContextMessage::user("Retry me", None))
.add_message(crate::TextMessage::new(crate::Role::User, "Ignored").droppable(true));

let conversation = Conversation::generate().context(context);
let actual = conversation.last_user_message().map(|message| message.content.as_str());
let expected = Some("Retry me");

assert_eq!(actual, expected);
}

#[test]
fn test_total_cost() {
use crate::{MessageEntry, Usage};
Expand Down
31 changes: 27 additions & 4 deletions crates/forge_main/src/ui.rs
Original file line number Diff line number Diff line change
Expand Up @@ -905,8 +905,7 @@ impl<A: API + ConsoleWriter + 'static, F: Fn(ForgeConfig) -> A + Send + Sync> UI
let original_id = self.state.conversation_id;
self.state.conversation_id = Some(id);

self.spinner.start(None)?;
self.on_message(None).await?;
self.on_retry(id).await?;

self.state.conversation_id = original_id;
}
Expand Down Expand Up @@ -2191,8 +2190,8 @@ impl<A: API + ConsoleWriter + 'static, F: Fn(ForgeConfig) -> A + Send + Sync> UI
return self.handle_provider_logout(None).await;
}
AppCommand::Retry => {
self.spinner.start(None)?;
self.on_message(None).await?;
let conversation_id = self.init_conversation().await?;
self.on_retry(conversation_id).await?;
}
AppCommand::Index => {
let working_dir = self.state.cwd.clone();
Expand Down Expand Up @@ -3923,6 +3922,30 @@ impl<A: API + ConsoleWriter + 'static, F: Fn(ForgeConfig) -> A + Send + Sync> UI
Ok(())
}

async fn on_retry(&mut self, conversation_id: ConversationId) -> Result<()> {
self.spinner.start(None)?;
self.writeln_title(TitleFormat::action("Retrying last prompt"))?;
let mut stream = self.api.retry(conversation_id).await?;
let mut writer = StreamingWriter::new(self.spinner.clone(), self.api.clone());

while let Some(message) = stream.next().await {
match message {
Ok(message) => self.handle_chat_response(message, &mut writer).await?,
Err(err) => {
writer.finish()?;
self.spinner.stop(None)?;
self.spinner.reset();
return Err(err);
}
}
}

writer.finish()?;
self.spinner.stop(None)?;
self.spinner.reset();
Ok(())
}

/// Fetches related conversations for a given conversation in parallel.
///
/// Returns a vector of related conversations that could be successfully
Expand Down