about summary refs log tree commit diff
path: root/src/state/queries/post-feed.ts
diff options
context:
space:
mode:
Diffstat (limited to 'src/state/queries/post-feed.ts')
-rw-r--r--src/state/queries/post-feed.ts122
1 files changed, 77 insertions, 45 deletions
diff --git a/src/state/queries/post-feed.ts b/src/state/queries/post-feed.ts
index 7cf315ef6..7589aa346 100644
--- a/src/state/queries/post-feed.ts
+++ b/src/state/queries/post-feed.ts
@@ -1,3 +1,4 @@
+import {useCallback} from 'react'
 import {AppBskyFeedDefs, AppBskyFeedPost, moderatePost} from '@atproto/api'
 import {
   useInfiniteQuery,
@@ -126,50 +127,57 @@ export function usePostFeedQuery(
       }
     },
     initialPageParam: undefined,
-    getNextPageParam: lastPage => ({
-      api: lastPage.api,
-      cursor: lastPage.cursor,
-    }),
-    select(data) {
-      const tuner = params?.disableTuner
-        ? new NoopFeedTuner()
-        : new FeedTuner(feedTuners)
-      return {
-        pageParams: data.pageParams,
-        pages: data.pages.map(page => ({
-          api: page.api,
-          tuner,
-          cursor: page.cursor,
-          slices: tuner.tune(page.feed).map(slice => ({
-            _reactKey: slice._reactKey,
-            rootUri: slice.rootItem.post.uri,
-            isThread:
-              slice.items.length > 1 &&
-              slice.items.every(
-                item => item.post.author.did === slice.items[0].post.author.did,
-              ),
-            items: slice.items
-              .map((item, i) => {
-                if (
-                  AppBskyFeedPost.isRecord(item.post.record) &&
-                  AppBskyFeedPost.validateRecord(item.post.record).success
-                ) {
-                  return {
-                    _reactKey: `${slice._reactKey}-${i}`,
-                    uri: item.post.uri,
-                    post: item.post,
-                    record: item.post.record,
-                    reason:
-                      i === 0 && slice.source ? slice.source : item.reason,
+    getNextPageParam: lastPage =>
+      lastPage.cursor
+        ? {
+            api: lastPage.api,
+            cursor: lastPage.cursor,
+          }
+        : undefined,
+    select: useCallback(
+      (data: InfiniteData<FeedPageUnselected, RQPageParam>) => {
+        const tuner = params?.disableTuner
+          ? new NoopFeedTuner()
+          : new FeedTuner(feedTuners)
+        return {
+          pageParams: data.pageParams,
+          pages: data.pages.map(page => ({
+            api: page.api,
+            tuner,
+            cursor: page.cursor,
+            slices: tuner.tune(page.feed).map(slice => ({
+              _reactKey: slice._reactKey,
+              rootUri: slice.rootItem.post.uri,
+              isThread:
+                slice.items.length > 1 &&
+                slice.items.every(
+                  item =>
+                    item.post.author.did === slice.items[0].post.author.did,
+                ),
+              items: slice.items
+                .map((item, i) => {
+                  if (
+                    AppBskyFeedPost.isRecord(item.post.record) &&
+                    AppBskyFeedPost.validateRecord(item.post.record).success
+                  ) {
+                    return {
+                      _reactKey: `${slice._reactKey}-${i}`,
+                      uri: item.post.uri,
+                      post: item.post,
+                      record: item.post.record,
+                      reason:
+                        i === 0 && slice.source ? slice.source : item.reason,
+                    }
                   }
-                }
-                return undefined
-              })
-              .filter(Boolean) as FeedPostSliceItem[],
+                  return undefined
+                })
+                .filter(Boolean) as FeedPostSliceItem[],
+            })),
           })),
-        })),
-      }
-    },
+        }
+      },
+      [feedTuners, params?.disableTuner],
+    ),
   })
 }
 
@@ -227,7 +235,20 @@ function createApi(
 export function findPostInQueryData(
   queryClient: QueryClient,
   uri: string,
-): AppBskyFeedDefs.FeedViewPost | undefined {
+): AppBskyFeedDefs.PostView | undefined {
+  const generator = findAllPostsInQueryData(queryClient, uri)
+  const result = generator.next()
+  if (result.done) {
+    return undefined
+  } else {
+    return result.value
+  }
+}
+
+export function* findAllPostsInQueryData(
+  queryClient: QueryClient,
+  uri: string,
+): Generator<AppBskyFeedDefs.PostView, void> {
   const queryDatas = queryClient.getQueriesData<
     InfiniteData<FeedPageUnselected>
   >({
@@ -240,12 +261,23 @@ export function findPostInQueryData(
     for (const page of queryData?.pages) {
       for (const item of page.feed) {
         if (item.post.uri === uri) {
-          return item
+          yield item.post
+        }
+        if (
+          AppBskyFeedDefs.isPostView(item.reply?.parent) &&
+          item.reply?.parent?.uri === uri
+        ) {
+          yield item.reply.parent
+        }
+        if (
+          AppBskyFeedDefs.isPostView(item.reply?.root) &&
+          item.reply?.root?.uri === uri
+        ) {
+          yield item.reply.root
         }
       }
     }
   }
-  return undefined
 }
 
 function assertSomePostsPassModeration(feed: AppBskyFeedDefs.FeedViewPost[]) {