about summary refs log tree commit diff
path: root/src/state/cache/post-shadow.ts
diff options
context:
space:
mode:
Diffstat (limited to 'src/state/cache/post-shadow.ts')
-rw-r--r--src/state/cache/post-shadow.ts18
1 files changed, 12 insertions, 6 deletions
diff --git a/src/state/cache/post-shadow.ts b/src/state/cache/post-shadow.ts
index 7cf72fae4..6225cbdba 100644
--- a/src/state/cache/post-shadow.ts
+++ b/src/state/cache/post-shadow.ts
@@ -1,13 +1,14 @@
-import {useEffect, useState, useMemo} from 'react'
-import EventEmitter from 'eventemitter3'
+import {useEffect, useMemo, useState} from 'react'
 import {AppBskyFeedDefs} from '@atproto/api'
+import {QueryClient} from '@tanstack/react-query'
+import EventEmitter from 'eventemitter3'
+
 import {batchedUpdates} from '#/lib/batchedUpdates'
-import {Shadow, castAsShadow} from './types'
 import {findAllPostsInQueryData as findAllPostsInNotifsQueryData} from '../queries/notifications/feed'
 import {findAllPostsInQueryData as findAllPostsInFeedQueryData} from '../queries/post-feed'
 import {findAllPostsInQueryData as findAllPostsInThreadQueryData} from '../queries/post-thread'
 import {findAllPostsInQueryData as findAllPostsInSearchQueryData} from '../queries/search-posts'
-import {queryClient} from 'lib/react-query'
+import {castAsShadow, Shadow} from './types'
 export type {Shadow} from './types'
 
 export interface PostShadow {
@@ -93,8 +94,12 @@ function mergeShadow(
   })
 }
 
-export function updatePostShadow(uri: string, value: Partial<PostShadow>) {
-  const cachedPosts = findPostsInCache(uri)
+export function updatePostShadow(
+  queryClient: QueryClient,
+  uri: string,
+  value: Partial<PostShadow>,
+) {
+  const cachedPosts = findPostsInCache(queryClient, uri)
   for (let post of cachedPosts) {
     shadows.set(post, {...shadows.get(post), ...value})
   }
@@ -104,6 +109,7 @@ export function updatePostShadow(uri: string, value: Partial<PostShadow>) {
 }
 
 function* findPostsInCache(
+  queryClient: QueryClient,
   uri: string,
 ): Generator<AppBskyFeedDefs.PostView, void> {
   for (let post of findAllPostsInFeedQueryData(queryClient, uri)) {