about summary refs log tree commit diff
path: root/src/state/queries/post-thread.ts
diff options
context:
space:
mode:
Diffstat (limited to 'src/state/queries/post-thread.ts')
-rw-r--r--src/state/queries/post-thread.ts25
1 files changed, 14 insertions, 11 deletions
diff --git a/src/state/queries/post-thread.ts b/src/state/queries/post-thread.ts
index b1bff1493..f7d21a427 100644
--- a/src/state/queries/post-thread.ts
+++ b/src/state/queries/post-thread.ts
@@ -4,6 +4,7 @@ import {
   AppBskyFeedDefs,
   AppBskyFeedGetPostThread,
   AppBskyFeedPost,
+  AtUri,
   ModerationDecision,
   ModerationOpts,
 } from '@atproto/api'
@@ -24,7 +25,11 @@ import {
   findAllPostsInQueryData as findAllPostsInFeedQueryData,
   findAllProfilesInQueryData as findAllProfilesInFeedQueryData,
 } from './post-feed'
-import {embedViewRecordToPostView, getEmbeddedPost} from './util'
+import {
+  didOrHandleUriMatches,
+  embedViewRecordToPostView,
+  getEmbeddedPost,
+} from './util'
 
 const RQKEY_ROOT = 'post-thread'
 export const RQKEY = (uri: string) => [RQKEY_ROOT, uri]
@@ -91,14 +96,10 @@ export function usePostThreadQuery(uri: string | undefined) {
     },
     enabled: !!uri,
     placeholderData: () => {
-      if (!uri) {
-        return undefined
-      }
-      {
-        const post = findPostInQueryData(queryClient, uri)
-        if (post) {
-          return post
-        }
+      if (!uri) return
+      const post = findPostInQueryData(queryClient, uri)
+      if (post) {
+        return post
       }
       return undefined
     },
@@ -271,6 +272,8 @@ export function* findAllPostsInQueryData(
   queryClient: QueryClient,
   uri: string,
 ): Generator<ThreadNode, void> {
+  const atUri = new AtUri(uri)
+
   const queryDatas = queryClient.getQueriesData<ThreadNode>({
     queryKey: [RQKEY_ROOT],
   })
@@ -279,7 +282,7 @@ export function* findAllPostsInQueryData(
       continue
     }
     for (const item of traverseThread(queryData)) {
-      if (item.uri === uri) {
+      if (item.type === 'post' && didOrHandleUriMatches(atUri, item.post)) {
         const placeholder = threadNodeToPlaceholderThread(item)
         if (placeholder) {
           yield placeholder
@@ -287,7 +290,7 @@ export function* findAllPostsInQueryData(
       }
       const quotedPost =
         item.type === 'post' ? getEmbeddedPost(item.post.embed) : undefined
-      if (quotedPost?.uri === uri) {
+      if (quotedPost && didOrHandleUriMatches(atUri, quotedPost)) {
         yield embedViewRecordToPlaceholderThread(quotedPost)
       }
     }