diff --git a/src/components/Grid/useDataGridSource.tsx b/src/components/Grid/useDataGridSource.tsx index cb6486f033..712761f555 100644 --- a/src/components/Grid/useDataGridSource.tsx +++ b/src/components/Grid/useDataGridSource.tsx @@ -18,6 +18,7 @@ import { useGridApiContext, useGridApiRef, } from '@mui/x-data-grid-pro'; +import { GridInitialStatePro } from '@mui/x-data-grid-pro/models/gridStatePro'; import { groupBy, Nil, setOf } from '@seedcompany/common'; import { useDebounceFn, @@ -39,6 +40,7 @@ import type { Get, Paths, SetNonNullable } from 'type-fest'; import { type PaginatedListInput, type SortableListInput } from '~/api'; import type { Order } from '~/api/schema/schema.graphql'; import { lowerCase, upperCase } from '~/common'; +import { usePersistedGridState } from '~/hooks/usePersistedGridState'; import { convertMuiFiltersToApi, FilterShape } from './convertMuiFiltersToApi'; type ListInput = SetNonNullable< @@ -52,6 +54,11 @@ interface PaginatedListOutput { total: number; } +interface SessionStorageProps { + key: string; + defaultValue: GridInitialStatePro; +} + type PathsMatching = { [K in Paths]: K extends string ? Get extends List @@ -99,6 +106,7 @@ export const useDataGridSource = < initialInput, keyArgs = defaultKeyArgs, apiRef: apiRefInput, + sessionStorageProps: sessionStateProps, }: { query: DocumentNode; variables: NoInfer; @@ -106,6 +114,7 @@ export const useDataGridSource = < initialInput?: Partial, 'page'>>; keyArgs?: string[]; apiRef?: MutableRefObject; + sessionStorageProps: SessionStorageProps; }) => { const initialInputRef = useLatest(initialInput); // eslint-disable-next-line react-hooks/rules-of-hooks -- we'll assume this doesn't change between renders @@ -198,6 +207,13 @@ export const useDataGridSource = < ); }; + const [savedGridState = {}, onStateChange, persistedFilterModel] = + usePersistedGridState({ + key: sessionStateProps.key, + apiRef: apiRef, + defaultValue: sessionStateProps.defaultValue, + }); + // State for current sorting & filtering const [initialSort] = useState((): ViewState['sortModel'] => [ { @@ -215,6 +231,7 @@ export const useDataGridSource = < }), } ); + const persist = useDebounceFn((next: ViewState) => { const filterModel = { // Strip out filters for columns that shouldn't be persisted @@ -234,9 +251,14 @@ export const useDataGridSource = < }); }); const [view, reallySetView] = useState((): ViewState => { - const { apiFilterModel: _, ...rest } = storedView!; // not null because we give a default value + const { apiFilterModel, ...rest } = storedView!; // not null because we give a default value return { ...rest, + filterModel: merge( + {}, + rest.filterModel, + savedGridState.filter?.filterModel + ), apiSortModel: rest.sortModel, }; }); @@ -253,23 +275,29 @@ export const useDataGridSource = < // Convert the view state to the input for the GQL query const input = useMemo( - () => ({ - ...defaultInitialInput, - ...initialInputRef.current, - count: initialInputRef.current?.count ?? defaultInitialInput.count, - ...(view.apiSortModel?.[0] && { - sort: view.apiSortModel[0].field, - order: upperCase(view.apiSortModel[0].sort!), - }), - // eslint-disable-next-line no-extra-boolean-cast - filter: Boolean(apiRef.current.instanceId) - ? convertMuiFiltersToApi( - apiRef.current, - view.filterModel, - initialInputRef.current?.filter - ) - : storedView?.apiFilterModel, - }), + () => { + const initialFilterModel = { + ...storedView?.apiFilterModel, + ...persistedFilterModel, + }; + return { + ...defaultInitialInput, + ...initialInputRef.current, + count: initialInputRef.current?.count ?? defaultInitialInput.count, + ...(view.apiSortModel?.[0] && { + sort: view.apiSortModel[0].field, + order: upperCase(view.apiSortModel[0].sort!), + }), + // eslint-disable-next-line no-extra-boolean-cast + filter: Boolean(apiRef.current.instanceId) + ? convertMuiFiltersToApi( + apiRef.current, + view.filterModel, + initialInputRef.current?.filter + ) + : initialFilterModel, + }; + }, // eslint-disable-next-line react-hooks/exhaustive-deps [apiRef, initialInputRef, view.apiSortModel, view.filterModel] ); @@ -427,10 +455,12 @@ export const useDataGridSource = < rows, loading, rowCount: total, + initialState: savedGridState, sortModel: view.sortModel, filterModel: view.filterModel, hideFooterPagination: true, onFetchRows: onFetchRows.run, + onStateChange, onSortModelChange, onFilterModelChange, paginationMode: total != null ? 'server' : 'client', // Not used, but prevents row count warning. diff --git a/src/hooks/usePersistedGridState.tsx b/src/hooks/usePersistedGridState.tsx new file mode 100644 index 0000000000..7680458e1a --- /dev/null +++ b/src/hooks/usePersistedGridState.tsx @@ -0,0 +1,65 @@ +import { GridApiPro, GridEventListener } from '@mui/x-data-grid-pro'; +import { GridInitialStatePro } from '@mui/x-data-grid-pro/models/gridStatePro'; +import { useDebounceFn, usePrevious, useSessionStorageState } from 'ahooks'; +import { isEqual } from 'lodash'; +import { MutableRefObject, useEffect, useRef } from 'react'; +import { convertMuiFiltersToApi } from '~/components/Grid/convertMuiFiltersToApi'; + +interface UsePersistedGridStateOptions { + key: string; + apiRef: MutableRefObject; + defaultValue: GridInitialStatePro; +} + +export const usePersistedGridState = ({ + key, + apiRef, + defaultValue, +}: UsePersistedGridStateOptions) => { + const isRestoringState = useRef(true); + + const [savedGridState, setSavedGridState] = useSessionStorageState(key, { + defaultValue, + }); + + const [persistedFilterModel, setPersistedFilterModel] = + useSessionStorageState>(`${key}-api-filter`, {}); + + const prevGridState = usePrevious( + savedGridState, + (prev, next) => !isEqual(prev, next) + ); + + const { run: handleStateChange } = useDebounceFn( + () => { + const gridState = apiRef.current.exportState(); + + setPersistedFilterModel((prev) => + isEqual(prev, gridState) + ? prev + : convertMuiFiltersToApi( + apiRef.current, + gridState.filter?.filterModel + ) + ); + setSavedGridState((prev) => + isEqual(prev, gridState) ? prev || defaultValue : gridState + ); + }, + { wait: 500, maxWait: 500 } + ); + + const onStateChange: GridEventListener<'stateChange'> = () => { + handleStateChange(); + }; + + useEffect(() => { + if (isRestoringState.current) { + isRestoringState.current = false; + } else if (savedGridState && !isEqual(savedGridState, prevGridState)) { + apiRef.current.restoreState(savedGridState); + } + }, [savedGridState, apiRef, prevGridState]); + + return [savedGridState, onStateChange, persistedFilterModel] as const; +}; diff --git a/src/scenes/Dashboard/ProgressReportsWidget/ProgressReportsExpandedGrid.tsx b/src/scenes/Dashboard/ProgressReportsWidget/ProgressReportsExpandedGrid.tsx index df3739029c..816a8a97d0 100644 --- a/src/scenes/Dashboard/ProgressReportsWidget/ProgressReportsExpandedGrid.tsx +++ b/src/scenes/Dashboard/ProgressReportsWidget/ProgressReportsExpandedGrid.tsx @@ -127,7 +127,6 @@ export const ProgressReportsExpandedGrid = ( }), [onMouseDown] ); - return ( { const source = useMemo(() => { @@ -210,9 +211,14 @@ export const ProgressReportsGrid = ({ }, } as const; }, [quarter]); + const [dataGridProps] = useDataGridSource({ ...source, apiRef: props.apiRef, + sessionStorageProps: { + key: 'progress-reports-grid', + defaultValue: initialState, + }, }); const slots = useMemo( diff --git a/src/scenes/Partners/Detail/Tabs/Engagements/PartnerDetailEngagements.tsx b/src/scenes/Partners/Detail/Tabs/Engagements/PartnerDetailEngagements.tsx index 1bdc114041..5fa88dcd81 100644 --- a/src/scenes/Partners/Detail/Tabs/Engagements/PartnerDetailEngagements.tsx +++ b/src/scenes/Partners/Detail/Tabs/Engagements/PartnerDetailEngagements.tsx @@ -31,6 +31,10 @@ export const PartnerDetailEngagements = () => { initialInput: { sort: EngagementColumns[0]!.field, }, + sessionStorageProps: { + key: `partners-engagements-grid-state-${partnerId}`, + defaultValue: EngagementInitialState, + }, }); const slots = useMemo( @@ -53,7 +57,6 @@ export const PartnerDetailEngagements = () => { slots={slots} slotProps={slotProps} columns={EngagementColumns} - initialState={EngagementInitialState} headerFilters hideFooter sx={[flexLayout, noHeaderFilterButtons, noFooter]} diff --git a/src/scenes/Partners/Detail/Tabs/Projects/PartnerDetailProjects.tsx b/src/scenes/Partners/Detail/Tabs/Projects/PartnerDetailProjects.tsx index 36dc00b6ef..475249b668 100644 --- a/src/scenes/Partners/Detail/Tabs/Projects/PartnerDetailProjects.tsx +++ b/src/scenes/Partners/Detail/Tabs/Projects/PartnerDetailProjects.tsx @@ -30,7 +30,6 @@ import { export const PartnerDetailProjects = () => { const { partnerId = '' } = useParams(); - const [props] = useDataGridSource({ query: PartnerProjectsDocument, variables: { partnerId }, @@ -38,6 +37,10 @@ export const PartnerDetailProjects = () => { initialInput: { sort: 'name', }, + sessionStorageProps: { + key: `partners-projects-grid-state-${partnerId}`, + defaultValue: ProjectInitialState, + }, }); const slots = useMemo( @@ -60,7 +63,6 @@ export const PartnerDetailProjects = () => { slots={slots} slotProps={slotProps} columns={PartnerProjectColumns} - initialState={ProjectInitialState} headerFilters hideFooter sx={[flexLayout, noHeaderFilterButtons, noFooter]} diff --git a/src/scenes/Projects/List/EngagementsPanel.tsx b/src/scenes/Projects/List/EngagementsPanel.tsx index 2070c39671..6644e28e35 100644 --- a/src/scenes/Projects/List/EngagementsPanel.tsx +++ b/src/scenes/Projects/List/EngagementsPanel.tsx @@ -27,6 +27,10 @@ export const EngagementsPanel = () => { initialInput: { sort: EngagementColumns[0]!.field, }, + sessionStorageProps: { + key: 'engagements-grid', + defaultValue: EngagementInitialState, + }, }); const slots = useMemo( @@ -49,7 +53,6 @@ export const EngagementsPanel = () => { slots={slots} slotProps={slotProps} columns={EngagementColumns} - initialState={EngagementInitialState} headerFilters hideFooter sx={[flexLayout, noHeaderFilterButtons, noFooter]} diff --git a/src/scenes/Projects/List/ProjectsPanel.tsx b/src/scenes/Projects/List/ProjectsPanel.tsx index e107b8b553..ab9a111e0f 100644 --- a/src/scenes/Projects/List/ProjectsPanel.tsx +++ b/src/scenes/Projects/List/ProjectsPanel.tsx @@ -27,6 +27,10 @@ export const ProjectsPanel = () => { initialInput: { sort: 'name', }, + sessionStorageProps: { + key: 'projects-grid', + defaultValue: ProjectInitialState, + }, }); const slots = useMemo( @@ -49,7 +53,6 @@ export const ProjectsPanel = () => { slots={slots} slotProps={slotProps} columns={ProjectColumns} - initialState={ProjectInitialState} headerFilters hideFooter sx={[flexLayout, noHeaderFilterButtons, noFooter]}