summary refs log tree commit diff
path: root/src/components/smart_summary.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/components/smart_summary.rs')
-rw-r--r--src/components/smart_summary.rs197
1 files changed, 171 insertions, 26 deletions
diff --git a/src/components/smart_summary.rs b/src/components/smart_summary.rs
index 9da67af..050a52c 100644
--- a/src/components/smart_summary.rs
+++ b/src/components/smart_summary.rs
@@ -1,15 +1,144 @@
+use futures::AsyncBufReadExt;
+use gio::prelude::SettingsExtManual;
+use soup::prelude::*;
 use adw::prelude::*;
 use gettextrs::*;
 use relm4::{gtk, prelude::{Component, ComponentParts}, ComponentSender};
 
+// All of this is incredibly minimalist.
+// This should be expanded later.
+#[derive(Debug, serde::Serialize)]
+struct OllamaRequest {
+    model: String,
+    prompt: String,
+    system: String,
+}
+
+#[derive(Debug, serde::Deserialize)]
+struct OllamaChunk {
+    response: String,
+    done: bool,
+}
+
+#[derive(Debug, serde::Deserialize)]
+struct OllamaError {
+    error: String
+}
+impl std::error::Error for OllamaError {}
+impl std::fmt::Display for OllamaError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        self.error.fmt(f)
+    }
+}
+
+#[derive(serde::Deserialize)]
+#[serde(untagged)]
+enum OllamaResult {
+    Ok(OllamaChunk),
+    Err(OllamaError),
+}
+
+impl From<OllamaResult> for Result<OllamaChunk, OllamaError> {
+    fn from(val: OllamaResult) -> Self {
+        match val {
+            OllamaResult::Ok(chunk) => Ok(chunk),
+            OllamaResult::Err(err) => Err(err)
+        }
+    }
+}
+
+
 #[derive(Debug, Default)]
 pub(crate) struct SmartSummaryButton {
-    busy: bool,
+    task: Option<relm4::JoinHandle<()>>,
+    waiting: bool,
+
+    http: soup::Session,
+}
+
+impl SmartSummaryButton {
+    async fn prompt_llm(
+        sender: relm4::Sender<Result<String, Error>>,
+        http: soup::Session,
+        endpoint: glib::Uri,
+        model: String,
+        system_prompt: String,
+        prompt_prefix: String,
+        mut prompt_suffix: String,
+        text: String,
+    ) {
+        let endpoint = endpoint.parse_relative("./api/generate", glib::UriFlags::NONE).unwrap();
+        log::debug!("endpoint: {}, model: {}", endpoint, model);
+        log::debug!("system prompt: {}", system_prompt);
+
+        let msg = soup::Message::from_uri(
+            "POST",
+            &endpoint
+        );
+
+        if !prompt_suffix.is_empty() {
+            prompt_suffix = String::from("\n\n") + &prompt_suffix;
+        }
+        msg.set_request_body_from_bytes(Some("application/json"),
+            Some(&glib::Bytes::from_owned(serde_json::to_vec(&OllamaRequest {
+                model, system: system_prompt, prompt: format!("{}\n\n{}{}", prompt_prefix, text, prompt_suffix),
+            }).unwrap()))
+        );
+
+        let mut stream = match http.send_future(&msg, glib::Priority::DEFAULT).await {
+            Ok(stream) => stream.into_async_buf_read(128),
+            Err(err) => {
+                let _ = sender.send(Err(err.into()));
+                return
+            }
+        };
+        log::debug!("response: {:?} ({})", msg.status(), msg.reason_phrase().unwrap_or_default());
+        let mut buffer = Vec::new();
+        const DELIM: u8 = b'\n';
+        loop {
+            let len = match stream.read_until(DELIM, &mut buffer).await {
+                Ok(len) => len,
+                Err(err) => {
+                    let _ = sender.send(Err(err.into()));
+                    return
+                }
+            };
+            log::debug!("Got chunk ({} bytes): {}", len, String::from_utf8_lossy(&buffer));
+            let response: Result<OllamaResult, serde_json::Error> = serde_json::from_slice(&buffer[..len]);
+            match response.map(Result::from) {
+                Ok(Ok(OllamaChunk { response: chunk, done })) => {
+                    if !chunk.is_empty() {
+                        sender.emit(Ok(chunk));
+                    }
+                    if done {
+                        sender.emit(Ok(String::new()));
+                        return
+                    }
+                },
+                Ok(Err(err)) => {
+                    sender.emit(Err(err.into()));
+                    return
+                }
+                Err(err) => {
+                    sender.emit(Err(err.into()));
+                    return
+                }
+            }
+            buffer.truncate(0);
+        }
+    }
 }
 
 #[derive(Debug, thiserror::Error)]
 pub(crate) enum Error {
-
+    #[error("glib error: {0}")]
+    Glib(#[from] glib::Error),
+    #[error("json error: {0}")]
+    Json(#[from] serde_json::Error),
+    #[error("ollama error: {0}")]
+    Ollama(#[from] OllamaError),
+    #[error("i/o error: {0}")]
+    Io(#[from] std::io::Error)
 }
 
 #[derive(Debug)]
@@ -33,7 +162,7 @@ impl Component for SmartSummaryButton {
     type Input = Input;
     type Output = Output;
 
-    type Init = ();
+    type Init = soup::Session;
     type CommandOutput = Result<String, Error>;
 
     view! {
@@ -42,11 +171,11 @@ impl Component for SmartSummaryButton {
         gtk::Button {
             connect_clicked => Input::ButtonPressed,
             #[watch]
-            set_sensitive: !model.busy,
+            set_sensitive: !(model.task.is_some() || model.waiting),
             // TRANSLATORS: please keep the newline and `<b>` tags
             set_tooltip_markup: Some(gettext("<b>Smart Summary</b>\nAsk a language model for a single-sentence summary.")).as_deref(),
 
-            if model.busy {
+            if model.task.is_some() || model.waiting {
                 gtk::Spinner { set_spinning: true }
             } else {
                 gtk::Label { set_markup: "✨" }
@@ -56,11 +185,14 @@ impl Component for SmartSummaryButton {
     }
 
     fn init(
-        _init: Self::Init,
+        init: Self::Init,
         root: Self::Root,
         sender: ComponentSender<Self>
     ) -> ComponentParts<Self> {
-        let model = Self::default();
+        let model = Self {
+            http: init,
+            ..Self::default()
+        };
         let widgets = view_output!();
 
         ComponentParts { model, widgets }
@@ -74,30 +206,41 @@ impl Component for SmartSummaryButton {
     ) {
         match msg {
             Input::Cancel => {
-                self.busy = false;
-                log::debug!("Parent component asked us to cancel.");
+                self.waiting = false;
+                if let Some(task) = self.task.take() {
+                    log::debug!("Parent component asked us to cancel.");
+                    task.abort();
+                } else {
+                    log::warn!("Parent component asked us to cancel, but we're not running a task.");
+                }
             },
             Input::ButtonPressed => if let Ok(()) = sender.output(Output::Start) {
-                self.busy = true;
+                self.waiting = true;
                 log::debug!("Requesting text to summarize from parent component...");
+                // TODO: set timeout in case parent component never replies
+                // This shouldn't happen, but I feel like we should handle this case.
             },
             Input::Text(text) => {
                 log::debug!("Would generate summary for the following text:\n{}", text);
 
-                sender.command(|sender, shutdown| shutdown.register(async move {
-                    tokio::time::sleep(std::time::Duration::from_millis(450)).await;
-
-                    for i in ["I'", "m ", "sorry,", " I", " am ", "unable", " to", " ", "write", " you ", "a summary.", " I", " am", " not ", "really ", "an ", "LLM."] {
-                        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
-
-                        if sender.send(Ok(i.to_string())).is_err() {
-                            return
-                        };
-                    }
-
-                    log::debug!("Done with the summary.");
-                    let _ = sender.send(Ok(String::new()));
-                }).drop_on_shutdown());
+                log::debug!("XDG_DATA_DIRS={:?}", std::env::var("XDG_DATA_DIRS"));
+                let settings = gio::Settings::new(crate::APPLICATION_ID);
+                // We shouldn't let the user record a bad setting anyway.
+                let endpoint = glib::Uri::parse(
+                    &settings.get::<String>("llm-endpoint"),
+                    glib::UriFlags::NONE,
+                ).unwrap();
+                let model = settings.get::<String>("smart-summary-model");
+                let system_prompt = settings.get::<String>("smart-summary-system-prompt");
+                let prompt_prefix = settings.get::<String>("smart-summary-prompt-prefix");
+                let prompt_suffix = settings.get::<String>("smart-summary-prompt-suffix");
+                let sender = sender.command_sender().clone();
+                relm4::spawn_local(Self::prompt_llm(
+                    sender, self.http.clone(),
+                    endpoint, model, system_prompt,
+                    prompt_prefix, prompt_suffix,
+                    text
+                ));
             }
         }
     }
@@ -105,11 +248,13 @@ impl Component for SmartSummaryButton {
     fn update_cmd(&mut self, msg: Self::CommandOutput, sender: ComponentSender<Self>, _root: &Self::Root) {
         match msg {
             Ok(chunk) if chunk.is_empty() => {
-                self.busy = false;
+                self.task = None;
+                self.waiting = false;
                 let _ = sender.output(Output::Done);
             },
             Err(err) => {
-                self.busy = false;
+                self.task = None;
+                self.waiting = false;
                 let _ = sender.output(Output::Error(err));
             }
             Ok(chunk) => {