about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/state/messages/convo/agent.ts42
-rw-r--r--src/state/messages/events/agent.ts103
-rw-r--r--src/state/messages/events/types.ts27
3 files changed, 85 insertions, 87 deletions
diff --git a/src/state/messages/convo/agent.ts b/src/state/messages/convo/agent.ts
index 65470baa5..79406d155 100644
--- a/src/state/messages/convo/agent.ts
+++ b/src/state/messages/convo/agent.ts
@@ -107,12 +107,6 @@ export class Convo {
     } else {
       DEBUG_ACTIVE_CHAT = this.convoId
     }
-
-    this.events.trailConvo(this.convoId, events => {
-      this.ingestFirehose(events)
-    })
-    this.events.onConnect(this.onFirehoseConnect)
-    this.events.onError(this.onFirehoseError)
   }
 
   private commit() {
@@ -211,6 +205,7 @@ export class Convo {
           case ConvoDispatchEvent.Init: {
             this.status = ConvoStatus.Initializing
             this.setup()
+            this.setupFirehose()
             this.requestPollInterval(ACTIVE_POLL_INTERVAL)
             break
           }
@@ -232,12 +227,14 @@ export class Convo {
           }
           case ConvoDispatchEvent.Suspend: {
             this.status = ConvoStatus.Suspended
+            this.cleanupFirehoseConnection?.()
             this.withdrawRequestedPollInterval()
             break
           }
           case ConvoDispatchEvent.Error: {
             this.status = ConvoStatus.Error
             this.error = action.payload
+            this.cleanupFirehoseConnection?.()
             this.withdrawRequestedPollInterval()
             break
           }
@@ -258,12 +255,14 @@ export class Convo {
           }
           case ConvoDispatchEvent.Suspend: {
             this.status = ConvoStatus.Suspended
+            this.cleanupFirehoseConnection?.()
             this.withdrawRequestedPollInterval()
             break
           }
           case ConvoDispatchEvent.Error: {
             this.status = ConvoStatus.Error
             this.error = action.payload
+            this.cleanupFirehoseConnection?.()
             this.withdrawRequestedPollInterval()
             break
           }
@@ -286,12 +285,14 @@ export class Convo {
           }
           case ConvoDispatchEvent.Suspend: {
             this.status = ConvoStatus.Suspended
+            this.cleanupFirehoseConnection?.()
             this.withdrawRequestedPollInterval()
             break
           }
           case ConvoDispatchEvent.Error: {
             this.status = ConvoStatus.Error
             this.error = action.payload
+            this.cleanupFirehoseConnection?.()
             this.withdrawRequestedPollInterval()
             break
           }
@@ -601,6 +602,33 @@ export class Convo {
     }
   }
 
+  private cleanupFirehoseConnection: (() => void) | undefined
+  private setupFirehose() {
+    // remove old listeners, if exist
+    this.cleanupFirehoseConnection?.()
+
+    // reconnect
+    this.cleanupFirehoseConnection = this.events.on(
+      event => {
+        switch (event.type) {
+          case 'connect': {
+            this.onFirehoseConnect()
+            break
+          }
+          case 'error': {
+            this.onFirehoseError(event.error)
+            break
+          }
+          case 'logs': {
+            this.ingestFirehose(event.logs)
+            break
+          }
+        }
+      },
+      {convoId: this.convoId},
+    )
+  }
+
   onFirehoseConnect() {
     this.footerItems.delete(ConvoItemError.FirehoseFailed)
     this.commit()
@@ -709,6 +737,8 @@ export class Convo {
       id: tempId,
       message,
     })
+    // remove on each send, it might go through now without user having to click
+    this.footerItems.delete(ConvoItemError.PendingFailed)
     this.commit()
 
     if (!this.isProcessingPendingMessages) {
diff --git a/src/state/messages/events/agent.ts b/src/state/messages/events/agent.ts
index 061337d3b..68225e595 100644
--- a/src/state/messages/events/agent.ts
+++ b/src/state/messages/events/agent.ts
@@ -8,9 +8,8 @@ import {DEFAULT_POLL_INTERVAL} from '#/state/messages/events/const'
 import {
   MessagesEventBusDispatch,
   MessagesEventBusDispatchEvent,
-  MessagesEventBusError,
   MessagesEventBusErrorCode,
-  MessagesEventBusEvents,
+  MessagesEventBusEvent,
   MessagesEventBusParams,
   MessagesEventBusStatus,
 } from '#/state/messages/events/types'
@@ -22,10 +21,9 @@ export class MessagesEventBus {
 
   private agent: BskyAgent
   private __tempFromUserDid: string
-  private emitter = new EventEmitter<MessagesEventBusEvents>()
+  private emitter = new EventEmitter<{event: [MessagesEventBusEvent]}>()
 
   private status: MessagesEventBusStatus = MessagesEventBusStatus.Initializing
-  private error: MessagesEventBusError | undefined
   private latestRev: string | undefined = undefined
   private pollInterval = DEFAULT_POLL_INTERVAL
   private requestedPollIntervals: Map<string, number> = new Map()
@@ -52,65 +50,43 @@ export class MessagesEventBus {
     }
   }
 
-  trail(handler: (events: ChatBskyConvoGetLog.OutputSchema['logs']) => void) {
-    this.emitter.on('events', handler)
-    return () => {
-      this.emitter.off('events', handler)
-    }
-  }
-
-  trailConvo(
-    convoId: string,
-    handler: (events: ChatBskyConvoGetLog.OutputSchema['logs']) => void,
-  ) {
-    const handle = (events: ChatBskyConvoGetLog.OutputSchema['logs']) => {
-      const convoEvents = events.filter(ev => {
-        if (typeof ev.convoId === 'string' && ev.convoId === convoId) {
-          return ev.convoId === convoId
-        }
-        return false
-      })
-
-      if (convoEvents.length > 0) {
-        handler(convoEvents)
-      }
-    }
-
-    this.emitter.on('events', handle)
-    return () => {
-      this.emitter.off('events', handle)
-    }
-  }
-
   getLatestRev() {
     return this.latestRev
   }
 
-  onConnect(handler: () => void) {
-    this.emitter.on('connect', handler)
-
-    if (
-      this.status === MessagesEventBusStatus.Ready ||
-      this.status === MessagesEventBusStatus.Backgrounded ||
-      this.status === MessagesEventBusStatus.Suspended
-    ) {
-      handler()
-    }
+  on(
+    handler: (event: MessagesEventBusEvent) => void,
+    options: {
+      convoId?: string
+    },
+  ) {
+    const handle = (event: MessagesEventBusEvent) => {
+      if (event.type === 'logs' && options.convoId) {
+        const filteredLogs = event.logs.filter(log => {
+          if (
+            typeof log.convoId === 'string' &&
+            log.convoId === options.convoId
+          ) {
+            return log.convoId === options.convoId
+          }
+          return false
+        })
 
-    return () => {
-      this.emitter.off('connect', handler)
+        if (filteredLogs.length > 0) {
+          handler({
+            ...event,
+            logs: filteredLogs,
+          })
+        }
+      } else {
+        handler(event)
+      }
     }
-  }
-
-  onError(handler: (payload?: MessagesEventBusError) => void) {
-    this.emitter.on('error', handler)
 
-    if (this.status === MessagesEventBusStatus.Error) {
-      handler(this.error)
-    }
+    this.emitter.on('event', handle)
 
     return () => {
-      this.emitter.off('error', handler)
+      this.emitter.off('event', handle)
     }
   }
 
@@ -138,13 +114,13 @@ export class MessagesEventBus {
           case MessagesEventBusDispatchEvent.Ready: {
             this.status = MessagesEventBusStatus.Ready
             this.resetPoll()
-            this.emitter.emit('connect')
+            this.emitter.emit('event', {type: 'connect'})
             break
           }
           case MessagesEventBusDispatchEvent.Background: {
             this.status = MessagesEventBusStatus.Backgrounded
             this.resetPoll()
-            this.emitter.emit('connect')
+            this.emitter.emit('event', {type: 'connect'})
             break
           }
           case MessagesEventBusDispatchEvent.Suspend: {
@@ -153,8 +129,7 @@ export class MessagesEventBus {
           }
           case MessagesEventBusDispatchEvent.Error: {
             this.status = MessagesEventBusStatus.Error
-            this.error = action.payload
-            this.emitter.emit('error', action.payload)
+            this.emitter.emit('event', {type: 'error', error: action.payload})
             break
           }
         }
@@ -174,9 +149,8 @@ export class MessagesEventBus {
           }
           case MessagesEventBusDispatchEvent.Error: {
             this.status = MessagesEventBusStatus.Error
-            this.error = action.payload
             this.stopPoll()
-            this.emitter.emit('error', action.payload)
+            this.emitter.emit('event', {type: 'error', error: action.payload})
             break
           }
           case MessagesEventBusDispatchEvent.UpdatePoll: {
@@ -200,9 +174,8 @@ export class MessagesEventBus {
           }
           case MessagesEventBusDispatchEvent.Error: {
             this.status = MessagesEventBusStatus.Error
-            this.error = action.payload
             this.stopPoll()
-            this.emitter.emit('error', action.payload)
+            this.emitter.emit('event', {type: 'error', error: action.payload})
             break
           }
           case MessagesEventBusDispatchEvent.UpdatePoll: {
@@ -226,9 +199,8 @@ export class MessagesEventBus {
           }
           case MessagesEventBusDispatchEvent.Error: {
             this.status = MessagesEventBusStatus.Error
-            this.error = action.payload
             this.stopPoll()
-            this.emitter.emit('error', action.payload)
+            this.emitter.emit('event', {type: 'error', error: action.payload})
             break
           }
         }
@@ -239,7 +211,6 @@ export class MessagesEventBus {
           case MessagesEventBusDispatchEvent.Resume: {
             // basically reset
             this.status = MessagesEventBusStatus.Initializing
-            this.error = undefined
             this.latestRev = undefined
             this.init()
             break
@@ -403,7 +374,7 @@ export class MessagesEventBus {
 
       if (needsEmit) {
         try {
-          this.emitter.emit('events', batch)
+          this.emitter.emit('event', {type: 'logs', logs: batch})
         } catch (e: any) {
           logger.error(e, {
             context: `${LOGGER_CONTEXT}: process latest events`,
diff --git a/src/state/messages/events/types.ts b/src/state/messages/events/types.ts
index c6be522ae..e65136e4b 100644
--- a/src/state/messages/events/types.ts
+++ b/src/state/messages/events/types.ts
@@ -55,18 +55,15 @@ export type MessagesEventBusDispatch =
       event: MessagesEventBusDispatchEvent.UpdatePoll
     }
 
-export type TrailHandler = (
-  events: ChatBskyConvoGetLog.OutputSchema['logs'],
-) => void
-
-export type RequestPollIntervalHandler = (interval: number) => () => void
-export type OnConnectHandler = (handler: () => void) => () => void
-export type OnDisconnectHandler = (
-  handler: (error?: MessagesEventBusError) => void,
-) => () => void
-
-export type MessagesEventBusEvents = {
-  events: [ChatBskyConvoGetLog.OutputSchema['logs']]
-  connect: undefined
-  error: [MessagesEventBusError] | undefined
-}
+export type MessagesEventBusEvent =
+  | {
+      type: 'connect'
+    }
+  | {
+      type: 'error'
+      error: MessagesEventBusError
+    }
+  | {
+      type: 'logs'
+      logs: ChatBskyConvoGetLog.OutputSchema['logs']
+    }