summary refs log tree commit diff
path: root/src/components/smart_summary.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/components/smart_summary.rs')
-rw-r--r--src/components/smart_summary.rs158
1 files changed, 93 insertions, 65 deletions
diff --git a/src/components/smart_summary.rs b/src/components/smart_summary.rs
index de6eb91..e876195 100644
--- a/src/components/smart_summary.rs
+++ b/src/components/smart_summary.rs
@@ -1,10 +1,14 @@
 #![cfg(feature = "smart-summary")]
+use adw::prelude::*;
 use futures::AsyncBufReadExt;
+use gettextrs::*;
 use gio::prelude::SettingsExtManual;
+use relm4::{
+    gtk,
+    prelude::{Component, ComponentParts},
+    ComponentSender,
+};
 use soup::prelude::*;
-use adw::prelude::*;
-use gettextrs::*;
-use relm4::{gtk, prelude::{Component, ComponentParts}, ComponentSender};
 
 // All of this is incredibly minimalist.
 // This should be expanded later.
@@ -23,7 +27,7 @@ pub(crate) struct OllamaChunk {
 
 #[derive(Debug, serde::Deserialize)]
 pub(crate) struct OllamaError {
-    error: String
+    error: String,
 }
 impl std::error::Error for OllamaError {}
 impl std::fmt::Display for OllamaError {
@@ -43,12 +47,11 @@ impl From<OllamaResult> for Result<OllamaChunk, OllamaError> {
     fn from(val: OllamaResult) -> Self {
         match val {
             OllamaResult::Ok(chunk) => Ok(chunk),
-            OllamaResult::Err(err) => Err(err)
+            OllamaResult::Err(err) => Err(err),
         }
     }
 }
 
-
 #[derive(Debug, Default)]
 pub(crate) struct SmartSummaryButton {
     task: Option<relm4::JoinHandle<()>>,
@@ -65,41 +68,48 @@ impl SmartSummaryButton {
     ) {
         let settings = gio::Settings::new(crate::APPLICATION_ID);
         // We shouldn't let the user record a bad setting anyway.
-        let endpoint = glib::Uri::parse(
-            &settings.string("llm-endpoint"),
-            glib::UriFlags::NONE,
-        ).unwrap();
+        let endpoint =
+            glib::Uri::parse(&settings.string("llm-endpoint"), glib::UriFlags::NONE).unwrap();
         let model = settings.get::<String>("smart-summary-model");
         let system_prompt = settings.get::<String>("smart-summary-system-prompt");
         let prompt_prefix = settings.get::<String>("smart-summary-prompt-prefix");
         let mut prompt_suffix = settings.get::<String>("smart-summary-prompt-suffix");
 
-        let endpoint = endpoint.parse_relative("./api/generate", glib::UriFlags::NONE).unwrap();
+        let endpoint = endpoint
+            .parse_relative("./api/generate", glib::UriFlags::NONE)
+            .unwrap();
         log::debug!("endpoint: {}, model: {}", endpoint, model);
         log::debug!("system prompt: {}", system_prompt);
 
-        let msg = soup::Message::from_uri(
-            "POST",
-            &endpoint
-        );
+        let msg = soup::Message::from_uri("POST", &endpoint);
 
         if !prompt_suffix.is_empty() {
             prompt_suffix = String::from("\n\n") + &prompt_suffix;
         }
-        msg.set_request_body_from_bytes(Some("application/json"),
-            Some(&glib::Bytes::from_owned(serde_json::to_vec(&OllamaRequest {
-                model, system: system_prompt, prompt: format!("{}\n\n{}{}", prompt_prefix, text, prompt_suffix),
-            }).unwrap()))
+        msg.set_request_body_from_bytes(
+            Some("application/json"),
+            Some(&glib::Bytes::from_owned(
+                serde_json::to_vec(&OllamaRequest {
+                    model,
+                    system: system_prompt,
+                    prompt: format!("{}\n\n{}{}", prompt_prefix, text, prompt_suffix),
+                })
+                .unwrap(),
+            )),
         );
 
         let mut stream = match http.send_future(&msg, glib::Priority::DEFAULT).await {
             Ok(stream) => stream.into_async_buf_read(128),
             Err(err) => {
                 let _ = sender.send(Err(err.into()));
-                return
+                return;
             }
         };
-        log::debug!("response: {:?} ({})", msg.status(), msg.reason_phrase().unwrap_or_default());
+        log::debug!(
+            "response: {:?} ({})",
+            msg.status(),
+            msg.reason_phrase().unwrap_or_default()
+        );
         let mut buffer = Vec::with_capacity(2048);
         const DELIM: u8 = b'\n';
         loop {
@@ -107,28 +117,36 @@ impl SmartSummaryButton {
                 Ok(len) => len,
                 Err(err) => {
                     let _ = sender.send(Err(err.into()));
-                    return
+                    return;
                 }
             };
-            log::debug!("Got chunk ({} bytes): {}", len, String::from_utf8_lossy(&buffer));
-            let response: Result<OllamaResult, serde_json::Error> = serde_json::from_slice(&buffer[..len]);
+            log::debug!(
+                "Got chunk ({} bytes): {}",
+                len,
+                String::from_utf8_lossy(&buffer)
+            );
+            let response: Result<OllamaResult, serde_json::Error> =
+                serde_json::from_slice(&buffer[..len]);
             match response.map(Result::from) {
-                Ok(Ok(OllamaChunk { response: chunk, done })) => {
+                Ok(Ok(OllamaChunk {
+                    response: chunk,
+                    done,
+                })) => {
                     if !chunk.is_empty() {
                         sender.emit(Ok(chunk));
                     }
                     if done {
                         sender.emit(Ok(String::new()));
-                        return
+                        return;
                     }
-                },
+                }
                 Ok(Err(err)) => {
                     sender.emit(Err(err.into()));
-                    return
+                    return;
                 }
                 Err(err) => {
                     sender.emit(Err(err.into()));
-                    return
+                    return;
                 }
             }
             buffer.truncate(0);
@@ -146,13 +164,15 @@ pub(crate) enum Error {
     #[allow(private_interfaces)]
     Ollama(#[from] OllamaError),
     #[error("i/o error: {0}")]
-    Io(#[from] std::io::Error)
+    Io(#[from] std::io::Error),
 }
 
 #[derive(Debug)]
 pub(crate) enum Input {
-    #[doc(hidden)] ButtonPressed,
-    #[doc(hidden)] WarningAccepted,
+    #[doc(hidden)]
+    ButtonPressed,
+    #[doc(hidden)]
+    WarningAccepted,
     Text(String),
     Cancel,
 }
@@ -163,7 +183,7 @@ pub(crate) enum Output {
     Chunk(String),
     Done,
 
-    Error(Error)
+    Error(Error),
 }
 
 #[relm4::component(pub(crate))]
@@ -198,7 +218,7 @@ impl Component for SmartSummaryButton {
     fn init(
         init: Self::Init,
         root: Self::Root,
-        sender: ComponentSender<Self>
+        sender: ComponentSender<Self>,
     ) -> ComponentParts<Self> {
         let model = Self {
             http: init,
@@ -209,12 +229,7 @@ impl Component for SmartSummaryButton {
         ComponentParts { model, widgets }
     }
 
-    fn update(
-        &mut self,
-        msg: Self::Input,
-        sender: ComponentSender<Self>,
-        _root: &Self::Root
-    ) {
+    fn update(&mut self, msg: Self::Input, sender: ComponentSender<Self>, _root: &Self::Root) {
         match msg {
             Input::Cancel => {
                 self.waiting = false;
@@ -222,23 +237,29 @@ impl Component for SmartSummaryButton {
                     log::debug!("Parent component asked us to cancel.");
                     task.abort();
                 } else {
-                    log::warn!("Parent component asked us to cancel, but we're not running a task.");
+                    log::warn!(
+                        "Parent component asked us to cancel, but we're not running a task."
+                    );
                 }
-            },
+            }
             Input::ButtonPressed => {
                 let settings = gio::Settings::new(crate::APPLICATION_ID);
                 if !settings.get::<bool>("smart-summary-show-warning") {
                     self.update(Input::WarningAccepted, sender, _root)
                 } else {
                     // TODO: show warning dialog
-                    let skip_warning_checkbox = gtk::CheckButton::with_label(
-                        &gettext("Show this warning next time")
-                    );
+                    let skip_warning_checkbox =
+                        gtk::CheckButton::with_label(&gettext("Show this warning next time"));
 
-                    settings.bind(
-                        "smart-summary-show-warning",
-                        &skip_warning_checkbox, "active"
-                    ).get().set().build();
+                    settings
+                        .bind(
+                            "smart-summary-show-warning",
+                            &skip_warning_checkbox,
+                            "active",
+                        )
+                        .get()
+                        .set()
+                        .build();
 
                     let dialog = adw::AlertDialog::builder()
                         .heading(gettext("LLMs can be deceiving"))
@@ -251,44 +272,51 @@ impl Component for SmartSummaryButton {
                         .build();
                     dialog.add_responses(&[
                         ("close", &gettext("Cancel")),
-                        ("continue", &gettext("Proceed"))
+                        ("continue", &gettext("Proceed")),
                     ]);
                     dialog.choose(
                         &_root.root().unwrap(),
                         None::<&gio::Cancellable>,
                         glib::clone!(
-                            #[strong] sender,
+                            #[strong]
+                            sender,
                             move |res| if res.as_str() == "continue" {
                                 sender.input(Input::WarningAccepted);
                             }
-                        ))
+                        ),
+                    )
+                }
+            }
+            Input::WarningAccepted => {
+                if let Ok(()) = sender.output(Output::Start) {
+                    self.waiting = true;
+                    log::debug!("Requesting text to summarize from parent component...");
+                    // TODO: set timeout in case parent component never replies
+                    // This shouldn't happen, but I feel like we should handle this case.
                 }
-            },
-            Input::WarningAccepted => if let Ok(()) = sender.output(Output::Start) {
-                self.waiting = true;
-                log::debug!("Requesting text to summarize from parent component...");
-                // TODO: set timeout in case parent component never replies
-                // This shouldn't happen, but I feel like we should handle this case.
-            },
+            }
             Input::Text(text) => {
                 log::debug!("Would generate summary for the following text:\n{}", text);
 
                 log::debug!("XDG_DATA_DIRS={:?}", std::env::var("XDG_DATA_DIRS"));
                 let sender = sender.command_sender().clone();
-                relm4::spawn_local(Self::summarize(
-                    sender, self.http.clone(), text
-                ));
+                relm4::spawn_local(Self::summarize(sender, self.http.clone(), text));
             }
         }
     }
 
-    fn update_cmd(&mut self, msg: Self::CommandOutput, sender: ComponentSender<Self>, _root: &Self::Root) {
+    fn update_cmd(
+        &mut self,
+        msg: Self::CommandOutput,
+        sender: ComponentSender<Self>,
+        _root: &Self::Root,
+    ) {
         match msg {
             Ok(chunk) if chunk.is_empty() => {
                 self.task = None;
                 self.waiting = false;
                 let _ = sender.output(Output::Done);
-            },
+            }
             Err(err) => {
                 self.task = None;
                 self.waiting = false;
@@ -296,7 +324,7 @@ impl Component for SmartSummaryButton {
             }
             Ok(chunk) => {
                 let _ = sender.output(Output::Chunk(chunk));
-            },
+            }
         }
     }
 }