From 201a9d53e85841e337dd9059bf6be91cf9818061 Mon Sep 17 00:00:00 2001 From: detachhead Date: Thu, 6 Feb 2025 22:59:21 +1000 Subject: [PATCH] rework lsp notebook support to support baseline --- .../pyright-internal/src/analyzer/program.ts | 39 ++- .../src/analyzer/programTypes.ts | 3 +- .../src/analyzer/sourceFile.ts | 20 +- .../src/analyzer/sourceFileInfo.ts | 1 - .../src/commands/quickActionCommand.ts | 6 +- .../src/common/languageServerInterface.ts | 3 +- .../src/common/serviceProviderExtensions.ts | 4 +- .../src/common/uri/uriUtils.ts | 7 +- .../src/common/workspaceEditUtils.ts | 51 ++- .../src/languageServerBase.ts | 318 ++++++++++++------ .../languageService/callHierarchyProvider.ts | 33 +- .../src/languageService/codeActionProvider.ts | 7 +- .../languageService/documentSymbolProvider.ts | 17 +- .../src/languageService/navigationUtils.ts | 19 +- .../src/languageService/referencesProvider.ts | 22 +- .../src/languageService/renameProvider.ts | 18 +- .../workspaceSymbolProvider.ts | 7 +- .../src/realLanguageServer.ts | 9 +- .../harness/fourslash/testLanguageService.ts | 7 +- .../src/tests/harness/fourslash/testState.ts | 38 ++- .../src/tests/sourceFile.test.ts | 3 +- .../src/tests/workspaceEditUtils.test.ts | 8 +- 22 files changed, 422 insertions(+), 218 deletions(-) diff --git a/packages/pyright-internal/src/analyzer/program.ts b/packages/pyright-internal/src/analyzer/program.ts index 98bfde3f04..223d35e50d 100644 --- a/packages/pyright-internal/src/analyzer/program.ts +++ b/packages/pyright-internal/src/analyzer/program.ts @@ -348,26 +348,25 @@ export class Program { importName, isThirdPartyImport, isInPyTypedPackage, - this.baselineHandler, this._editModeTracker, + this.baselineHandler, + () => sourceFileInfo.cellIndex(), this._console, this._logTracker, IPythonMode.CellDocs ); - sourceFile.setCellIndex(index); - sourceFileInfos.push( - new SourceFileInfo( - sourceFile, - isTypeshedFile, - isThirdPartyImport, - isInPyTypedPackage, - this._editModeTracker, - { - isTracked: true, - chainedSourceFile: sourceFileInfos[index - 1], - } - ) + const sourceFileInfo = new SourceFileInfo( + sourceFile, + isTypeshedFile, + isThirdPartyImport, + isInPyTypedPackage, + this._editModeTracker, + { + isTracked: true, + chainedSourceFile: sourceFileInfos[index - 1], + } ); + sourceFileInfos.push(sourceFileInfo); }); } else { const importName = this._getImportNameForNewSourceFile(fileUri); @@ -380,8 +379,9 @@ export class Program { importName, isThirdPartyImport, isInPyTypedPackage, - this.baselineHandler, this._editModeTracker, + this.baselineHandler, + () => undefined, this._console, this._logTracker ); @@ -412,8 +412,9 @@ export class Program { moduleImportInfo.moduleName, /* isThirdPartyImport */ false, moduleImportInfo.isThirdPartyPyTypedPresent, - this.baselineHandler, this._editModeTracker, + this.baselineHandler, + () => sourceFileInfo?.cellIndex(), this._console, this._logTracker, options?.ipythonMode ?? IPythonMode.None @@ -1536,8 +1537,9 @@ export class Program { moduleImportInfo.moduleName, importInfo.isThirdPartyImport, importInfo.isPyTypedPresent, - this.baselineHandler, this._editModeTracker, + this.baselineHandler, + () => importedFileInfo?.cellIndex(), this._console, this._logTracker ); @@ -1679,8 +1681,9 @@ export class Program { moduleImportInfo.moduleName, /* isThirdPartyImport */ false, /* isInPyTypedPackage */ false, - this.baselineHandler, this._editModeTracker, + this.baselineHandler, + () => sourceFileInfo.cellIndex(), this._console, this._logTracker ); diff --git a/packages/pyright-internal/src/analyzer/programTypes.ts b/packages/pyright-internal/src/analyzer/programTypes.ts index f6613f194e..7de68cb62b 100644 --- a/packages/pyright-internal/src/analyzer/programTypes.ts +++ b/packages/pyright-internal/src/analyzer/programTypes.ts @@ -19,8 +19,9 @@ export interface ISourceFileFactory { moduleName: string, isThirdPartyImport: boolean, isThirdPartyPyTypedPresent: boolean, - baselineHandler: BaselineHandler, editMode: SourceFileEditMode, + baselineHandler: BaselineHandler, + getCellIndex: () => number | undefined, console?: ConsoleInterface, logTracker?: LogTracker, ipythonMode?: IPythonMode diff --git a/packages/pyright-internal/src/analyzer/sourceFile.ts b/packages/pyright-internal/src/analyzer/sourceFile.ts index 412c817166..bc323cf61c 100644 --- a/packages/pyright-internal/src/analyzer/sourceFile.ts +++ b/packages/pyright-internal/src/analyzer/sourceFile.ts @@ -203,11 +203,6 @@ class WriteableData { parserOutput: ParserOutput | undefined; - /** - * this is only writable in the language server because when the user moves cells around, the index changes - */ - cellIndex: number | undefined; - constructor() { // Empty } @@ -317,6 +312,9 @@ export class SourceFile { isThirdPartyPyTypedPresent: boolean, editMode: SourceFileEditMode, private _baselineHandler: BaselineHandler, + // this is kinda weird but it's necessary because the chained file is stored on SourceFileInfo and + // accessing it from here would cause a circular dependency + private _getCellIndex: () => number | undefined, console?: ConsoleInterface, logTracker?: LogTracker, ipythonMode?: IPythonMode @@ -539,10 +537,6 @@ export class SourceFile { return this._writableData.clientDocumentContents; } - // TODO: this sucks. ideally SourceFile would just have access to its chained source files, but there's a bunch of machinery - // around writable data that i'm too scared to touch - setCellIndex = (value: number) => (this._writableData.cellIndex = value); - /** * gets the content of the source file. if it's a notebook, the content of this source file's {@link _ipythonCellIndex} is returned */ @@ -556,7 +550,7 @@ export class SourceFile { // Otherwise, get content from file system. return getFileContent(this.fileSystem, this._uri, this._console); } - const cellIndex = this._writableData.cellIndex; + const cellIndex = this._getCellIndex(); if (cellIndex === undefined) { throw new Error(`something went wrong, failed to get cell index for ${this._uri}`); } @@ -1314,11 +1308,7 @@ export class SourceFile { // Now add in the "unnecessary type ignore" diagnostics. diagList = diagList.concat(unnecessaryTypeIgnoreDiags); - diagList = this._baselineHandler.sortDiagnosticsAndMatchBaseline( - this._uri, - this._writableData.cellIndex, - diagList - ); + diagList = this._baselineHandler.sortDiagnosticsAndMatchBaseline(this._uri, this._getCellIndex(), diagList); // If we're not returning any diagnostics, filter out all of // the errors and warnings, leaving only the unreachable code diff --git a/packages/pyright-internal/src/analyzer/sourceFileInfo.ts b/packages/pyright-internal/src/analyzer/sourceFileInfo.ts index a44fec6fbb..027ff950fb 100644 --- a/packages/pyright-internal/src/analyzer/sourceFileInfo.ts +++ b/packages/pyright-internal/src/analyzer/sourceFileInfo.ts @@ -123,7 +123,6 @@ export class SourceFileInfo { result++; chainedFile = chainedFile.chainedSourceFile; } - this.sourceFile.setCellIndex(result); return result; }; diff --git a/packages/pyright-internal/src/commands/quickActionCommand.ts b/packages/pyright-internal/src/commands/quickActionCommand.ts index c4086cd30c..4f4860d39e 100644 --- a/packages/pyright-internal/src/commands/quickActionCommand.ts +++ b/packages/pyright-internal/src/commands/quickActionCommand.ts @@ -32,7 +32,11 @@ export class QuickActionCommand implements ServerCommand { return performQuickAction(p, docUri, params.command, otherArgs, token); }, token); - return convertToWorkspaceEdit(workspace.service.fs, convertToFileTextEdits(docUri, editActions ?? [])); + return convertToWorkspaceEdit( + this._ls, + workspace.service.fs, + convertToFileTextEdits(docUri, editActions ?? []) + ); } } } diff --git a/packages/pyright-internal/src/common/languageServerInterface.ts b/packages/pyright-internal/src/common/languageServerInterface.ts index a2b459e7cd..38bac27309 100644 --- a/packages/pyright-internal/src/common/languageServerInterface.ts +++ b/packages/pyright-internal/src/common/languageServerInterface.ts @@ -15,7 +15,7 @@ import { DiagnosticBooleanOverridesMap, DiagnosticSeverityOverridesMap } from '. import { SignatureDisplayType } from './configOptions'; import { ConsoleInterface, LogLevel } from './console'; import { TaskListToken } from './diagnostic'; -import { FileSystem } from './fileSystem'; +import { FileSystem, ReadOnlyFileSystem } from './fileSystem'; import { FileWatcherHandler } from './fileWatcher'; import { ServiceProvider } from './serviceProvider'; import { Uri } from './uri/uri'; @@ -138,6 +138,7 @@ export interface LanguageServerBaseInterface { export interface LanguageServerInterface extends LanguageServerBaseInterface { getWorkspaceForFile(fileUri: Uri, pythonPath?: Uri): Promise; + convertUriToLspUriString: (fs: ReadOnlyFileSystem, uri: Uri) => string; readonly documentsWithDiagnostics: Record; } diff --git a/packages/pyright-internal/src/common/serviceProviderExtensions.ts b/packages/pyright-internal/src/common/serviceProviderExtensions.ts index 2a9d8bcae5..0637b02701 100644 --- a/packages/pyright-internal/src/common/serviceProviderExtensions.ts +++ b/packages/pyright-internal/src/common/serviceProviderExtensions.ts @@ -105,8 +105,9 @@ const DefaultSourceFileFactory: ISourceFileFactory = { moduleName: string, isThirdPartyImport: boolean, isThirdPartyPyTypedPresent: boolean, - baselineHandler: BaselineHandler, editMode: SourceFileEditMode, + baselineHandler: BaselineHandler, + getCellIndex: () => number | undefined, console?: ConsoleInterface, logTracker?: LogTracker, ipythonMode?: IPythonMode @@ -119,6 +120,7 @@ const DefaultSourceFileFactory: ISourceFileFactory = { isThirdPartyPyTypedPresent, editMode, baselineHandler, + getCellIndex, console, logTracker, ipythonMode diff --git a/packages/pyright-internal/src/common/uri/uriUtils.ts b/packages/pyright-internal/src/common/uri/uriUtils.ts index 09a33ab795..ad1a1c2bec 100644 --- a/packages/pyright-internal/src/common/uri/uriUtils.ts +++ b/packages/pyright-internal/src/common/uri/uriUtils.ts @@ -276,7 +276,7 @@ export function getWildcardRoot(root: Uri, fileSpec: string): Uri { } export function hasPythonExtension(uri: Uri) { - return uri.hasExtension('.py') || uri.hasExtension('.pyi'); + return uri.hasExtension('.py') || uri.hasExtension('.pyi') || uri.hasExtension('.ipynb'); } export function getFileSpec(root: Uri, fileSpec: string): FileSpec { @@ -386,11 +386,6 @@ export function getRootUri(csdOrSp: CaseSensitivityDetector | ServiceProvider): return undefined; } -export function convertUriToLspUriString(fs: ReadOnlyFileSystem, uri: Uri): string { - // Convert to a URI string that the LSP client understands (mapped files are only local to the server). - return fs.getOriginalUri(uri).toString(); -} - export namespace UriEx { export function file(path: string): Uri; export function file(path: string, isCaseSensitive: boolean, checkRelative?: boolean): Uri; diff --git a/packages/pyright-internal/src/common/workspaceEditUtils.ts b/packages/pyright-internal/src/common/workspaceEditUtils.ts index 3558a70557..37e94e666b 100644 --- a/packages/pyright-internal/src/common/workspaceEditUtils.ts +++ b/packages/pyright-internal/src/common/workspaceEditUtils.ts @@ -28,7 +28,7 @@ import { convertRangeToTextRange, convertTextRangeToRange } from './positionUtil import { TextRange } from './textRange'; import { TextRangeCollection } from './textRangeCollection'; import { Uri } from './uri/uri'; -import { convertUriToLspUriString } from './uri/uriUtils'; +import { LanguageServerInterface } from './languageServerInterface'; export function convertToTextEdits(editActions: TextEditAction[]): TextEdit[] { return editActions.map((editAction) => ({ @@ -47,9 +47,18 @@ export function convertToFileTextEdits(fileUri: Uri, editActions: TextEditAction return editActions.map((a) => ({ fileUri, ...a })); } -export function convertToWorkspaceEdit(fs: ReadOnlyFileSystem, edits: FileEditAction[]): WorkspaceEdit; -export function convertToWorkspaceEdit(fs: ReadOnlyFileSystem, edits: FileEditActions): WorkspaceEdit; export function convertToWorkspaceEdit( + ls: LanguageServerInterface, + fs: ReadOnlyFileSystem, + edits: FileEditAction[] +): WorkspaceEdit; +export function convertToWorkspaceEdit( + ls: LanguageServerInterface, + fs: ReadOnlyFileSystem, + edits: FileEditActions +): WorkspaceEdit; +export function convertToWorkspaceEdit( + ls: LanguageServerInterface, fs: ReadOnlyFileSystem, edits: FileEditActions, changeAnnotations: { @@ -58,6 +67,7 @@ export function convertToWorkspaceEdit( defaultAnnotationId: string ): WorkspaceEdit; export function convertToWorkspaceEdit( + ls: LanguageServerInterface, fs: ReadOnlyFileSystem, edits: FileEditActions | FileEditAction[], changeAnnotations?: { @@ -66,15 +76,20 @@ export function convertToWorkspaceEdit( defaultAnnotationId = 'default' ): WorkspaceEdit { if (isArray(edits)) { - return _convertToWorkspaceEditWithChanges(fs, edits); + return _convertToWorkspaceEditWithChanges(ls, fs, edits); } - return _convertToWorkspaceEditWithDocumentChanges(fs, edits, changeAnnotations, defaultAnnotationId); + return _convertToWorkspaceEditWithDocumentChanges(ls, fs, edits, changeAnnotations, defaultAnnotationId); } -export function appendToWorkspaceEdit(fs: ReadOnlyFileSystem, edits: FileEditAction[], workspaceEdit: WorkspaceEdit) { +export function appendToWorkspaceEdit( + ls: LanguageServerInterface, + fs: ReadOnlyFileSystem, + edits: FileEditAction[], + workspaceEdit: WorkspaceEdit +) { edits.forEach((edit) => { - const uri = convertUriToLspUriString(fs, edit.fileUri); + const uri = ls.convertUriToLspUriString(fs, edit.fileUri); workspaceEdit.changes![uri] = workspaceEdit.changes![uri] || []; workspaceEdit.changes![uri].push({ range: edit.range, newText: edit.replacementText }); }); @@ -167,6 +182,7 @@ export function applyDocumentChanges(program: EditableProgram, fileInfo: SourceF } export function generateWorkspaceEdit( + ls: LanguageServerInterface, fs: ReadOnlyFileSystem, originalService: AnalyzerService, clonedService: AnalyzerService, @@ -191,7 +207,7 @@ export function generateWorkspaceEdit( continue; } - edits.changes![convertUriToLspUriString(fs, uri)] = [ + edits.changes![ls.convertUriToLspUriString(fs, uri)] = [ { range: convertTextRangeToRange(parseResults.parserOutput.parseTree, parseResults.tokenizerOutput.lines), newText: final.getFileContent() ?? '', @@ -202,16 +218,21 @@ export function generateWorkspaceEdit( return edits; } -function _convertToWorkspaceEditWithChanges(fs: ReadOnlyFileSystem, edits: FileEditAction[]) { +function _convertToWorkspaceEditWithChanges( + ls: LanguageServerInterface, + fs: ReadOnlyFileSystem, + edits: FileEditAction[] +) { const workspaceEdit: WorkspaceEdit = { changes: {}, }; - appendToWorkspaceEdit(fs, edits, workspaceEdit); + appendToWorkspaceEdit(ls, fs, edits, workspaceEdit); return workspaceEdit; } function _convertToWorkspaceEditWithDocumentChanges( + ls: LanguageServerInterface, fs: ReadOnlyFileSystem, editActions: FileEditActions, changeAnnotations?: { @@ -231,7 +252,7 @@ function _convertToWorkspaceEditWithDocumentChanges( case 'create': workspaceEdit.documentChanges!.push( CreateFile.create( - convertUriToLspUriString(fs, operation.fileUri), + ls.convertUriToLspUriString(fs, operation.fileUri), /* options */ undefined, defaultAnnotationId ) @@ -246,7 +267,7 @@ function _convertToWorkspaceEditWithDocumentChanges( } // Text edit's file path must refer to original file paths unless it is a new file just created. - const mapPerFile = createMapFromItems(editActions.edits, (e) => convertUriToLspUriString(fs, e.fileUri)); + const mapPerFile = createMapFromItems(editActions.edits, (e) => ls.convertUriToLspUriString(fs, e.fileUri)); for (const [uri, value] of mapPerFile) { workspaceEdit.documentChanges!.push( TextDocumentEdit.create( @@ -269,8 +290,8 @@ function _convertToWorkspaceEditWithDocumentChanges( case 'rename': workspaceEdit.documentChanges!.push( RenameFile.create( - convertUriToLspUriString(fs, operation.oldFileUri), - convertUriToLspUriString(fs, operation.newFileUri), + ls.convertUriToLspUriString(fs, operation.oldFileUri), + ls.convertUriToLspUriString(fs, operation.newFileUri), /* options */ undefined, defaultAnnotationId ) @@ -279,7 +300,7 @@ function _convertToWorkspaceEditWithDocumentChanges( case 'delete': workspaceEdit.documentChanges!.push( DeleteFile.create( - convertUriToLspUriString(fs, operation.fileUri), + ls.convertUriToLspUriString(fs, operation.fileUri), /* options */ undefined, defaultAnnotationId ) diff --git a/packages/pyright-internal/src/languageServerBase.ts b/packages/pyright-internal/src/languageServerBase.ts index beca3f1484..5b8dd146d7 100644 --- a/packages/pyright-internal/src/languageServerBase.ts +++ b/packages/pyright-internal/src/languageServerBase.ts @@ -75,11 +75,11 @@ import { DidChangeNotebookDocumentParams, DidCloseNotebookDocumentParams, DidOpenNotebookDocumentParams, + DidSaveNotebookDocumentParams, InlayHint, InlayHintParams, SemanticTokens, SemanticTokensParams, - TextDocumentItem, WillSaveTextDocumentParams, } from 'vscode-languageserver-protocol'; import { ResultProgressReporter } from 'vscode-languageserver'; @@ -125,7 +125,6 @@ import { ServiceKeys } from './common/serviceKeys'; import { ServiceProvider } from './common/serviceProvider'; import { DocumentRange, Position, Range } from './common/textRange'; import { Uri } from './common/uri/uri'; -import { convertUriToLspUriString } from './common/uri/uriUtils'; import { AnalyzerServiceExecutor } from './languageService/analyzerServiceExecutor'; import { CallHierarchyProvider } from './languageService/callHierarchyProvider'; import { InlayHintsProvider } from './languageService/inlayHintsProvider'; @@ -151,6 +150,7 @@ import { RenameUsageFinder } from './analyzer/renameUsageFinder'; import { BaselineHandler } from './baseline'; import { assert } from './common/debug'; import { AutoImporter, buildModuleSymbolsMap } from './languageService/autoImporter'; +import { zip } from 'lodash'; export abstract class LanguageServerBase implements LanguageServerInterface, Disposable @@ -196,6 +196,7 @@ export abstract class LanguageServerBase(); + private readonly _openCells = new Map(); protected readonly fs: FileSystem; protected readonly caseSensitiveDetector: CaseSensitivityDetector; @@ -267,6 +268,7 @@ export abstract class LanguageServerBase { + // Convert to a URI string that the LSP client understands (mapped files are only local to the server). + if (uri.fragment && uri.scheme !== 'vscode-notebook-cell') { + // if it's a notebook cell we need to figure out the open uri matching the index, because it changes + // when cells are rearranged + const result = this._openCells.get(uri.withFragment('').key)?.[Number(uri.fragment)]; + if (!result) { + throw new Error(`failed to get lsp uri for cell at index ${uri.fragment} or ${uri}`); + } + return result.uri; + } + return fs.getOriginalUri(uri).toString(); + }; + protected abstract executeCommand(params: ExecuteCommandParams, token: CancellationToken): Promise; protected abstract isLongRunningCommand(command: string): boolean; @@ -529,15 +545,13 @@ export abstract class LanguageServerBase this.onDidOpenTextDocument(params)); this.connection.onDidChangeTextDocument(async (params) => this.onDidChangeTextDocument(params)); this.connection.onDidCloseTextDocument(async (params) => this.onDidCloseTextDocument(params)); - this.connection.notebooks.synchronization.onDidOpenNotebookDocument((params) => - this.onDidOpenNotebookDocument(params) - ); - this.connection.notebooks.synchronization.onDidChangeNotebookDocument((params) => - this.onDidChangeNotebookDocument(params) - ); - this.connection.notebooks.synchronization.onDidCloseNotebookDocument((params) => - this.onDidCloseNotebookDocument(params) - ); + this.connection.notebooks.synchronization.onDidOpenNotebookDocument(this.onDidOpenNotebookDocument); + this.connection.notebooks.synchronization.onDidChangeNotebookDocument(this.onDidChangeNotebookDocument); + this.connection.notebooks.synchronization.onDidCloseNotebookDocument(this.onDidCloseNotebookDocument); + // this is incosnsitent because non-notebook files use onWillSaveTextDocument instead of onDidSaveTextDocument. + // see https://github.com/microsoft/language-server-protocol/issues/2095 i don't think it will casue any issues, + // but it just takes slightly longer to determine that it needs to update the baseline file i think + this.connection.notebooks.synchronization.onDidSaveNotebookDocument(this.onSaveNotebookDocument); this.connection.onDidChangeWatchedFiles((params) => this.onDidChangeWatchedFiles(params)); this.connection.workspace.onWillRenameFiles(this.onRenameFiles); this.connection.onWillSaveTextDocument(this.onSaveTextDocument); @@ -624,6 +638,7 @@ export abstract class LanguageServerBase this.canNavigateToFile(loc.uri, workspace.service.fs)) - .map((loc) => Location.create(convertUriToLspUriString(workspace.service.fs, loc.uri), loc.range)); + .map((loc) => Location.create(this.convertUriToLspUriString(workspace.service.fs, loc.uri), loc.range)); } protected async onReferences( @@ -784,7 +799,11 @@ export abstract class LanguageServerBase | undefined, createDocumentRange?: (uri: Uri, result: CollectionResult, parseResults: ParseFileResults) => DocumentRange, - convertToLocation?: (fs: ReadOnlyFileSystem, ranges: DocumentRange) => Location | undefined + convertToLocation?: ( + ls: LanguageServerInterface, + fs: ReadOnlyFileSystem, + ranges: DocumentRange + ) => Location | undefined ): Promise { if (this._pendingFindAllRefsCancellationSource) { this._pendingFindAllRefsCancellationSource.cancel(); @@ -813,6 +832,7 @@ export abstract class LanguageServerBase { return new ReferencesProvider( + this, program, source.token, createDocumentRange, @@ -843,7 +863,8 @@ export abstract class LanguageServerBase { - return new RenameProvider(program, uri, params.position, token).canRenameSymbol( + return new RenameProvider(program, uri, params.position, token, this).canRenameSymbol( workspace.kinds.includes(WellKnownWorkspaceKinds.Default), isUntitled ); @@ -1022,7 +1044,7 @@ export abstract class LanguageServerBase { - return new RenameProvider(program, uri, params.position, token).renameSymbol( + return new RenameProvider(program, uri, params.position, token, this).renameSymbol( params.newName, workspace.kinds.includes(WellKnownWorkspaceKinds.Default), isUntitled @@ -1031,7 +1053,7 @@ export abstract class LanguageServerBase { - const uri = Uri.parse(params.textDocument.uri, this.serviceProvider); + const uri = this.convertLspUriStringToUri(params.textDocument.uri); const workspace = await this.getWorkspaceForFile(uri); if ( workspace.disableLanguageServices || @@ -1068,7 +1090,7 @@ export abstract class LanguageServerBase { - const uri = Uri.parse(params.textDocument.uri, this.serviceProvider); + const uri = this.convertLspUriStringToUri(params.textDocument.uri); const workspace = await this.getWorkspaceForFile(uri); if (workspace.disableLanguageServices) { return { @@ -1093,7 +1115,7 @@ export abstract class LanguageServerBase { - return new CallHierarchyProvider(program, uri, params.position, token).onPrepare(); + return new CallHierarchyProvider(program, uri, params.position, token, this).onPrepare(); }, token); } @@ -1106,7 +1128,7 @@ export abstract class LanguageServerBase { - return new CallHierarchyProvider(program, uri, params.item.range.start, token).getIncomingCalls(); + return new CallHierarchyProvider(program, uri, params.item.range.start, token, this).getIncomingCalls(); }, token); } @@ -1122,24 +1144,12 @@ export abstract class LanguageServerBase { - return new CallHierarchyProvider(program, uri, params.item.range.start, token).getOutgoingCalls(); + return new CallHierarchyProvider(program, uri, params.item.range.start, token, this).getOutgoingCalls(); }, token); } - // these overloads ban specifying chainedFile unless iPythonMode is CellDocs - protected async onDidOpenTextDocument( - params: DidOpenTextDocumentParams, - iPythonMode: IPythonMode.CellDocs, - chainedFile?: TextDocumentItem - ): Promise; - protected async onDidOpenTextDocument(params: DidOpenTextDocumentParams, iPythonMode?: IPythonMode): Promise; - protected async onDidOpenTextDocument( - params: DidOpenTextDocumentParams, - iPythonMode = IPythonMode.None, - chainedFile?: TextDocumentItem - ) { + protected async onDidOpenTextDocument(params: DidOpenTextDocumentParams) { const uri = this.convertLspUriStringToUri(params.textDocument.uri); - let doc = this.openFileMap.get(uri.key); if (doc) { // We shouldn't get an open text document request for an already-opened doc. @@ -1154,31 +1164,39 @@ export abstract class LanguageServerBase { - w.service.setFileOpened( - uri, - params.textDocument.version, - params.textDocument.text, - iPythonMode, - chainedFileUri - ); + w.service.setFileOpened(uri, params.textDocument.version, params.textDocument.text); }); } protected onDidOpenNotebookDocument = async (params: DidOpenNotebookDocumentParams) => { + const uri = this.convertLspUriStringToUri(params.notebookDocument.uri); + const openCells: TextDocument[] = []; + this._openCells.set(uri.key, openCells); await Promise.all( - params.cellTextDocuments.map((textDocument, index) => - // the previous cell is the chained document - this.onDidOpenTextDocument({ textDocument }, IPythonMode.CellDocs, params.cellTextDocuments[index - 1]) - ) + params.cellTextDocuments.map(async (textDocument, index) => { + const cellUri = this.convertLspUriStringToUri(textDocument.uri, index); + const doc = TextDocument.create(textDocument.uri, 'python', textDocument.version, textDocument.text); + openCells.push(doc); + // Send this open to all the workspaces that might contain this file. + const workspaces = await this.getContainingWorkspacesForFile(cellUri); + workspaces.forEach((w) => { + w.service.setFileOpened( + cellUri, + textDocument.version, + textDocument.text, + IPythonMode.CellDocs, + this._getChainedFileUri(textDocument, index) + ); + }); + }) ); }; - protected async onDidChangeTextDocument(params: DidChangeTextDocumentParams, ipythonMode = IPythonMode.None) { + protected async onDidChangeTextDocument(params: DidChangeTextDocumentParams) { this.recordUserInteractionTime(); const uri = this.convertLspUriStringToUri(params.textDocument.uri); @@ -1195,68 +1213,123 @@ export abstract class LanguageServerBase { - w.service.updateOpenFileContents(uri, params.textDocument.version, newContents, ipythonMode); + w.service.updateOpenFileContents(uri, params.textDocument.version, newContents); }); } protected onDidChangeNotebookDocument = async (params: DidChangeNotebookDocumentParams) => { + this.recordUserInteractionTime(); + const uri = this.convertLspUriStringToUri(params.notebookDocument.uri); + const openCells = this._openCells.get(uri.key); + if (!openCells) { + this.console.error(`onDidChangeNotebookDocument failed to find open cells for ${uri}`); + return; + } const changeStructure = params.change.cells?.structure; - // open any new documents first before we action the cell changes, because that's where we attach - // chained documents and if we try to do that before the document is opened it won't work - if (changeStructure?.didOpen) { - await Promise.all( - changeStructure.didOpen.map((textDocument) => - this.onDidOpenTextDocument({ textDocument }, IPythonMode.CellDocs) - ) + if (changeStructure) { + const previousCells = [...openCells]; + openCells.splice( + changeStructure.array.start, + changeStructure.array.deleteCount, + ...(changeStructure.array.cells?.map((changedTextDocumentItem) => { + // if there isn't a cell at this index already, we need to open it + const newDocumentItem = changeStructure.didOpen?.find( + (newDocument) => newDocument.uri === changedTextDocumentItem.document + ); + if (newDocumentItem) { + return TextDocument.create( + newDocumentItem.uri, + 'python', + newDocumentItem.version, + newDocumentItem.text + ); + } else { + const result = openCells.find((openCell) => openCell.uri === changedTextDocumentItem.document); + if (!result) { + throw new Error(`failed to find existing cell ${changedTextDocumentItem.document}`); + } + return result; + } + }) ?? []) ); - } - // the rest of these methods can be executed in any order so we collect all their promises and await - // them at the end - const promises: Promise[] = []; - if (changeStructure?.didClose) { - promises.push( - ...changeStructure.didClose.map((textDocument) => this.onDidCloseTextDocument({ textDocument })) + await Promise.all( + zip(previousCells, openCells).map(async ([previousCell, newCell], index) => { + if (previousCell?.uri === newCell?.uri) { + return; + } + if (previousCell === undefined) { + // a new cell was added and we didn't already have one at this index so we need to open it + if (newCell === undefined) { + // this should never happen + throw new Error('new cell was undefined when new cell was added'); + } + const cellUri = this.convertLspUriStringToUri(newCell.uri, index); + // Send this open to all the workspaces that might contain this file. + const workspaces = await this.getContainingWorkspacesForFile(cellUri); + workspaces.forEach((w) => { + w.service.setFileOpened( + cellUri, + newCell.version, + newCell.getText(), + IPythonMode.CellDocs, + this._getChainedFileUri(newCell, index) + ); + }); + } else { + const cellUri = this.convertLspUriStringToUri(params.notebookDocument.uri, index); + const workspaces = await this.getContainingWorkspacesForFile(cellUri); + if (newCell === undefined) { + // a cell was deleted and there's no longer a cell at this index so we need to close it + // Send this close to all the workspaces that might contain this file. + workspaces.forEach((w) => w.service.setFileClosed(cellUri)); + } else { + // this index now has a different cell than it did before (ie. order was changed) so we have to update the already opened cell + // with the new text content + const newContents = newCell.getText(); + // Send this change to all the workspaces that might contain this file. + workspaces.forEach((w) => { + w.service.updateOpenFileContents( + cellUri, + previousCell.version + 1, + newContents, + IPythonMode.CellDocs + ); + w.service.updateChainedUri(cellUri, this._getChainedFileUri(newCell, index)); + }); + } + } + }) ); } - const cellChanges = changeStructure?.array.cells; - if (cellChanges) { - promises.push( - ...cellChanges.map(async (cell, index) => { - const uri = this.convertLspUriStringToUri(cell.document); - const doc = this.openFileMap.get(uri.key); + if (params.change.cells?.textContent) { + await Promise.all( + params.change.cells.textContent.map(async (textContent) => { + const cellUri = this.convertLspUriStringToUri(textContent.document.uri); + const doc = this._openCells.get(uri.key)?.find((cell) => cell.uri === textContent.document.uri); if (!doc) { // We shouldn't get a change text request for a closed doc. - this.console.error(`Received change notebook document command for closed cell ${uri}`); + this.console.error(`failed to find document for changed cell ${cellUri}`); return; } + TextDocument.update(doc, textContent.changes, textContent.document.version); + const newContents = doc.getText(); // Send this change to all the workspaces that might contain this file. - const workspaces = await this.getContainingWorkspacesForFile(uri); - const previousCell = cellChanges[index - 1]?.document; - const previousCellUri = previousCell ? Uri.parse(previousCell, this.serviceProvider) : undefined; - workspaces.forEach((w) => { - w.service.updateChainedUri(uri, previousCellUri); - }); + const workspaces = await this.getContainingWorkspacesForFile(cellUri); + workspaces.forEach((w) => + w.service.updateOpenFileContents( + cellUri, + textContent.document.version, + newContents, + IPythonMode.CellDocs + ) + ); }) ); } - if (params.change.cells?.textContent) { - promises.push( - ...params.change.cells.textContent.map((textContentChange) => - this.onDidChangeTextDocument( - { - textDocument: textContentChange.document, - contentChanges: textContentChange.changes, - }, - IPythonMode.CellDocs - ) - ) - ); - } - await Promise.all(promises); }; - protected async onDidCloseTextDocument(params: DidCloseTextDocumentParams) { - const uri = this.convertLspUriStringToUri(params.textDocument.uri); + protected async onDidCloseTextDocument(params: DidCloseTextDocumentParams, cellIndex?: number) { + const uri = this.convertLspUriStringToUri(params.textDocument.uri, cellIndex); // Send this close to all the workspaces that might contain this file. const workspaces = await this.getContainingWorkspacesForFile(uri); @@ -1268,8 +1341,21 @@ export abstract class LanguageServerBase { + const uri = this.convertLspUriStringToUri(params.notebookDocument.uri); + const openCells = this._openCells.get(uri.key); + if (!openCells) { + this.console.error(`onDidCloseNotebookDocument failed to find open cells for ${uri}`); + return; + } + openCells.length = 0; + this._openCells.delete(uri.key); await Promise.all( - params.cellTextDocuments.map((textDocument) => this.onDidCloseTextDocument({ textDocument })) + params.cellTextDocuments.map(async (textDocument) => { + const cellUri = this.convertLspUriStringToUri(textDocument.uri); + // Send this close to all the workspaces that might contain this file. + const workspaces = await this.getContainingWorkspacesForFile(cellUri); + workspaces.forEach((w) => w.service.setFileClosed(cellUri)); + }) ); }; @@ -1327,6 +1413,11 @@ export abstract class LanguageServerBase { + const uri = this.convertLspUriStringToUri(params.notebookDocument.uri); + this._openCells.get(uri.key)?.forEach((cell) => this.savedFilesForBaselineUpdate.add(cell.uri)); + }; + protected async onExecuteCommand( params: ExecuteCommandParams, token: CancellationToken, @@ -1389,6 +1480,7 @@ export abstract class LanguageServerBase { return { - uri: convertUriToLspUriString(fs, fileDiagnostics.fileUri), + uri: this.convertUriToLspUriString(fs, fileDiagnostics.fileUri), version: fileDiagnostics.version, diagnostics: this._convertDiagnostics(fs, fileDiagnostics.diagnostics), }; @@ -1429,7 +1521,7 @@ export abstract class LanguageServerBase(); for (const textDocumentUri of this.savedFilesForBaselineUpdate) { - const fileUri = Uri.file(textDocumentUri, this.serviceProvider); + const fileUri = this.convertLspUriStringToUri(textDocumentUri); // can't use result.diagnostics because we need the diagnostics from the previous analysis since // saves don't trigger checking (i think) const fileDiagnostics = this.documentsWithDiagnostics[fileUri.toString()]; @@ -1589,14 +1681,40 @@ export abstract class LanguageServerBase cell.uri === uri); + if (cellIndex === undefined) { + throw new Error(`failed to find cell index when converting uri ${uri}`); + } + } + // remove the vscode-notebook-cell:// scheme + result = result = Uri.file(parsedUri.getPath(), this.serviceProvider); + } else { + result = parsedUri; + } + if (cellIndex === undefined) { + return result; + } + return result.withFragment(cellIndex.toString()); + } + + /** + * if cellIndex > 0 then it has a chained file, which should always just be the previous index because we update them + * on every change + */ + private _getChainedFileUri = (cell: { uri: string }, index: number) => + index ? this.convertLspUriStringToUri(cell.uri, index - 1) : undefined; + private _getCompatibleMarkupKind(clientSupportedFormats: MarkupKind[] | undefined) { const serverSupportedFormats = [MarkupKind.PlainText, MarkupKind.Markdown]; @@ -1660,7 +1778,7 @@ export abstract class LanguageServerBase this.canNavigateToFile(info.uri, fs)) .map((info) => DiagnosticRelatedInformation.create( - Location.create(convertUriToLspUriString(fs, info.uri), info.range), + Location.create(this.convertUriToLspUriString(fs, info.uri), info.range), info.message ) ); diff --git a/packages/pyright-internal/src/languageService/callHierarchyProvider.ts b/packages/pyright-internal/src/languageService/callHierarchyProvider.ts index d67e7e4384..fffd959599 100644 --- a/packages/pyright-internal/src/languageService/callHierarchyProvider.ts +++ b/packages/pyright-internal/src/languageService/callHierarchyProvider.ts @@ -34,12 +34,12 @@ import { convertOffsetsToRange } from '../common/positionUtils'; import { ServiceKeys } from '../common/serviceKeys'; import { Position, rangesAreEqual } from '../common/textRange'; import { Uri } from '../common/uri/uri'; -import { convertUriToLspUriString } from '../common/uri/uriUtils'; import { ReferencesProvider, ReferencesResult } from '../languageService/referencesProvider'; import { CallNode, MemberAccessNode, NameNode, ParseNode, ParseNodeType } from '../parser/parseNodes'; import { ParseFileResults } from '../parser/parser'; import { DocumentSymbolCollector } from './documentSymbolCollector'; import { canNavigateToFile } from './navigationUtils'; +import { LanguageServerInterface } from '../common/languageServerInterface'; export class CallHierarchyProvider { private readonly _parseResults: ParseFileResults | undefined; @@ -48,7 +48,8 @@ export class CallHierarchyProvider { private _program: ProgramView, private _fileUri: Uri, private _position: Position, - private _token: CancellationToken + private _token: CancellationToken, + private _ls: LanguageServerInterface ) { this._parseResults = this._program.getParseResults(this._fileUri); } @@ -88,7 +89,7 @@ export class CallHierarchyProvider { const callItem: CallHierarchyItem = { name: symbolName, kind: getSymbolKind(targetDecl, this._evaluator, symbolName) ?? SymbolKind.Module, - uri: convertUriToLspUriString(this._program.fileSystem, callItemUri), + uri: this._ls.convertUriToLspUriString(this._program.fileSystem, callItemUri), range: targetDecl.range, selectionRange: targetDecl.range, }; @@ -202,7 +203,8 @@ export class CallHierarchyProvider { parseRoot, this._parseResults, this._evaluator, - this._token + this._token, + this._ls ); const outgoingCalls = callFinder.findCalls(); if (outgoingCalls.length === 0) { @@ -267,7 +269,14 @@ export class CallHierarchyProvider { ): CallHierarchyIncomingCall[] | undefined { throwIfCancellationRequested(this._token); - const callFinder = new FindIncomingCallTreeWalker(this._program, fileUri, symbolName, declaration, this._token); + const callFinder = new FindIncomingCallTreeWalker( + this._program, + fileUri, + symbolName, + declaration, + this._token, + this._ls + ); const incomingCalls = callFinder.findCalls(); return incomingCalls.length > 0 ? incomingCalls : undefined; @@ -293,7 +302,8 @@ class FindOutgoingCallTreeWalker extends ParseTreeWalker { private _parseRoot: ParseNode, private _parseResults: ParseFileResults, private _evaluator: TypeEvaluator, - private _cancellationToken: CancellationToken + private _cancellationToken: CancellationToken, + private _ls: LanguageServerInterface ) { super(); } @@ -384,7 +394,7 @@ class FindOutgoingCallTreeWalker extends ParseTreeWalker { const callDest: CallHierarchyItem = { name: nameNode.d.value, kind: getSymbolKind(resolvedDecl, this._evaluator, nameNode.d.value) ?? SymbolKind.Module, - uri: convertUriToLspUriString(this._fs, resolvedDecl.uri), + uri: this._ls.convertUriToLspUriString(this._fs, resolvedDecl.uri), range: resolvedDecl.range, selectionRange: resolvedDecl.range, }; @@ -430,7 +440,8 @@ class FindIncomingCallTreeWalker extends ParseTreeWalker { private readonly _fileUri: Uri, private readonly _symbolName: string, private readonly _targetDeclaration: Declaration, - private readonly _cancellationToken: CancellationToken + private readonly _cancellationToken: CancellationToken, + private _ls: LanguageServerInterface ) { super(); @@ -569,7 +580,7 @@ class FindIncomingCallTreeWalker extends ParseTreeWalker { callSource = { name: `(module) ${fileName}`, kind: SymbolKind.Module, - uri: convertUriToLspUriString(this._program.fileSystem, this._fileUri), + uri: this._ls.convertUriToLspUriString(this._program.fileSystem, this._fileUri), range: moduleRange, selectionRange: moduleRange, }; @@ -583,7 +594,7 @@ class FindIncomingCallTreeWalker extends ParseTreeWalker { callSource = { name: '(lambda)', kind: SymbolKind.Function, - uri: convertUriToLspUriString(this._program.fileSystem, this._fileUri), + uri: this._ls.convertUriToLspUriString(this._program.fileSystem, this._fileUri), range: lambdaRange, selectionRange: lambdaRange, }; @@ -597,7 +608,7 @@ class FindIncomingCallTreeWalker extends ParseTreeWalker { callSource = { name: executionNode.d.name.d.value, kind: SymbolKind.Function, - uri: convertUriToLspUriString(this._program.fileSystem, this._fileUri), + uri: this._ls.convertUriToLspUriString(this._program.fileSystem, this._fileUri), range: functionRange, selectionRange: functionRange, }; diff --git a/packages/pyright-internal/src/languageService/codeActionProvider.ts b/packages/pyright-internal/src/languageService/codeActionProvider.ts index 4672d60320..298ffa2e86 100644 --- a/packages/pyright-internal/src/languageService/codeActionProvider.ts +++ b/packages/pyright-internal/src/languageService/codeActionProvider.ts @@ -23,6 +23,7 @@ import { convertOffsetToPosition, convertPositionToOffset } from '../common/posi import { findNodeByOffset } from '../analyzer/parseTreeUtils'; import { ParseNodeType } from '../parser/parseNodes'; import { sorter } from '../common/collectionUtils'; +import { LanguageServerInterface } from '../common/languageServerInterface'; export class CodeActionProvider { static mightSupport(kinds: CodeActionKind[] | undefined): boolean { @@ -39,7 +40,8 @@ export class CodeActionProvider { fileUri: Uri, range: Range, kinds: CodeActionKind[] | undefined, - token: CancellationToken + token: CancellationToken, + ls: LanguageServerInterface ) { throwIfCancellationRequested(token); @@ -113,6 +115,7 @@ export class CodeActionProvider { continue; } const workspaceEdit = convertToWorkspaceEdit( + ls, completer.importResolver.fileSystem, convertToFileTextEdits(fileUri, convertToTextEditActions(textEdits)) ); @@ -178,7 +181,7 @@ export class CodeActionProvider { }, ], }; - const workspaceEdit = convertToWorkspaceEdit(workspace.service.fs, editActions); + const workspaceEdit = convertToWorkspaceEdit(ls, workspace.service.fs, editActions); const renameAction = CodeAction.create(title, workspaceEdit, CodeActionKind.QuickFix); codeActions.push(renameAction); } diff --git a/packages/pyright-internal/src/languageService/documentSymbolProvider.ts b/packages/pyright-internal/src/languageService/documentSymbolProvider.ts index 203252e4a2..f172de5a58 100644 --- a/packages/pyright-internal/src/languageService/documentSymbolProvider.ts +++ b/packages/pyright-internal/src/languageService/documentSymbolProvider.ts @@ -15,19 +15,20 @@ import { throwIfCancellationRequested } from '../common/cancellationUtils'; import { ProgramView } from '../common/extensibility'; import { ReadOnlyFileSystem } from '../common/fileSystem'; import { Uri } from '../common/uri/uri'; -import { convertUriToLspUriString } from '../common/uri/uriUtils'; import { ParseFileResults } from '../parser/parser'; import { IndexOptions, IndexSymbolData, SymbolIndexer } from './symbolIndexer'; +import { LanguageServerInterface } from '../common/languageServerInterface'; export function convertToFlatSymbols( program: ProgramView, uri: Uri, - symbolList: DocumentSymbol[] + symbolList: DocumentSymbol[], + ls: LanguageServerInterface ): SymbolInformation[] { const flatSymbols: SymbolInformation[] = []; for (const symbol of symbolList) { - _appendToFlatSymbolsRecursive(program.fileSystem, flatSymbols, uri, symbol); + _appendToFlatSymbolsRecursive(ls, program.fileSystem, flatSymbols, uri, symbol); } return flatSymbols; @@ -41,7 +42,8 @@ export class DocumentSymbolProvider { protected readonly uri: Uri, private readonly _supportHierarchicalDocumentSymbol: boolean, private readonly _indexOptions: IndexOptions, - private readonly _token: CancellationToken + private readonly _token: CancellationToken, + private readonly _ls: LanguageServerInterface ) { this._parseResults = this.program.getParseResults(this.uri); } @@ -56,7 +58,7 @@ export class DocumentSymbolProvider { return symbolList; } - return convertToFlatSymbols(this.program, this.uri, symbolList); + return convertToFlatSymbols(this.program, this.uri, symbolList, this._ls); } protected getHierarchicalSymbols() { @@ -116,6 +118,7 @@ export class DocumentSymbolProvider { } function _appendToFlatSymbolsRecursive( + ls: LanguageServerInterface, fs: ReadOnlyFileSystem, flatSymbols: SymbolInformation[], documentUri: Uri, @@ -125,7 +128,7 @@ function _appendToFlatSymbolsRecursive( const flatSymbol: SymbolInformation = { name: symbol.name, kind: symbol.kind, - location: Location.create(convertUriToLspUriString(fs, documentUri), symbol.range), + location: Location.create(ls.convertUriToLspUriString(fs, documentUri), symbol.range), }; if (symbol.tags) { @@ -140,7 +143,7 @@ function _appendToFlatSymbolsRecursive( if (symbol.children) { for (const child of symbol.children) { - _appendToFlatSymbolsRecursive(fs, flatSymbols, documentUri, child, symbol); + _appendToFlatSymbolsRecursive(ls, fs, flatSymbols, documentUri, child, symbol); } } } diff --git a/packages/pyright-internal/src/languageService/navigationUtils.ts b/packages/pyright-internal/src/languageService/navigationUtils.ts index 3081721a25..01fe45f502 100644 --- a/packages/pyright-internal/src/languageService/navigationUtils.ts +++ b/packages/pyright-internal/src/languageService/navigationUtils.ts @@ -9,24 +9,33 @@ import { Location } from 'vscode-languageserver-types'; import { ReadOnlyFileSystem } from '../common/fileSystem'; import { DocumentRange } from '../common/textRange'; import { Uri } from '../common/uri/uri'; -import { convertUriToLspUriString } from '../common/uri/uriUtils'; +import { LanguageServerInterface } from '../common/languageServerInterface'; export function canNavigateToFile(fs: ReadOnlyFileSystem, path: Uri): boolean { return !fs.isInZip(path); } export function convertDocumentRangesToLocation( + ls: LanguageServerInterface, fs: ReadOnlyFileSystem, ranges: DocumentRange[], - converter: (fs: ReadOnlyFileSystem, range: DocumentRange) => Location | undefined = convertDocumentRangeToLocation + converter: ( + ls: LanguageServerInterface, + fs: ReadOnlyFileSystem, + range: DocumentRange + ) => Location | undefined = convertDocumentRangeToLocation ): Location[] { - return ranges.map((range) => converter(fs, range)).filter((loc) => !!loc) as Location[]; + return ranges.map((range) => converter(ls, fs, range)).filter((loc) => !!loc) as Location[]; } -export function convertDocumentRangeToLocation(fs: ReadOnlyFileSystem, range: DocumentRange): Location | undefined { +export function convertDocumentRangeToLocation( + ls: LanguageServerInterface, + fs: ReadOnlyFileSystem, + range: DocumentRange +): Location | undefined { if (!canNavigateToFile(fs, range.uri)) { return undefined; } - return Location.create(convertUriToLspUriString(fs, range.uri), range.range); + return Location.create(ls.convertUriToLspUriString(fs, range.uri), range.range); } diff --git a/packages/pyright-internal/src/languageService/referencesProvider.ts b/packages/pyright-internal/src/languageService/referencesProvider.ts index f4796c3c85..734c0f6b85 100644 --- a/packages/pyright-internal/src/languageService/referencesProvider.ts +++ b/packages/pyright-internal/src/languageService/referencesProvider.ts @@ -32,6 +32,7 @@ import { NameNode, ParseNode, ParseNodeType } from '../parser/parseNodes'; import { ParseFileResults } from '../parser/parser'; import { CollectionResult, DocumentSymbolCollector } from './documentSymbolCollector'; import { convertDocumentRangesToLocation } from './navigationUtils'; +import { LanguageServerInterface } from '../common/languageServerInterface'; export type ReferenceCallback = (locations: DocumentRange[]) => void; @@ -185,6 +186,7 @@ export class FindReferencesTreeWalker { export class ReferencesProvider { constructor( + private _ls: LanguageServerInterface, private _program: ProgramView, private _token: CancellationToken, private readonly _createDocumentRange?: ( @@ -192,7 +194,11 @@ export class ReferencesProvider { result: CollectionResult, parseResults: ParseFileResults ) => DocumentRange, - private readonly _convertToLocation?: (fs: ReadOnlyFileSystem, ranges: DocumentRange) => Location | undefined + private readonly _convertToLocation?: ( + ls: LanguageServerInterface, + fs: ReadOnlyFileSystem, + ranges: DocumentRange + ) => Location | undefined ) { // empty } @@ -217,12 +223,22 @@ export class ReferencesProvider { const reporter: ReferenceCallback = resultReporter ? (range) => resultReporter.report( - convertDocumentRangesToLocation(this._program.fileSystem, range, this._convertToLocation) + convertDocumentRangesToLocation( + this._ls, + this._program.fileSystem, + range, + this._convertToLocation + ) ) : (range) => appendArray( locations, - convertDocumentRangesToLocation(this._program.fileSystem, range, this._convertToLocation) + convertDocumentRangesToLocation( + this._ls, + this._program.fileSystem, + range, + this._convertToLocation + ) ); const invokedFromUserFile = isUserCode(sourceFileInfo); diff --git a/packages/pyright-internal/src/languageService/renameProvider.ts b/packages/pyright-internal/src/languageService/renameProvider.ts index db96a13772..cd7de1c1a9 100644 --- a/packages/pyright-internal/src/languageService/renameProvider.ts +++ b/packages/pyright-internal/src/languageService/renameProvider.ts @@ -22,6 +22,7 @@ import { ReferencesProvider, ReferencesResult } from '../languageService/referen import { ParseNodeType } from '../parser/parseNodes'; import { ParseFileResults } from '../parser/parser'; import { IPythonMode } from '../analyzer/sourceFile'; +import { LanguageServerInterface } from '../common/languageServerInterface'; export class RenameProvider { private readonly _parseResults: ParseFileResults | undefined; @@ -30,7 +31,8 @@ export class RenameProvider { private _program: ProgramView, private _fileUri: Uri, private _position: Position, - private _token: CancellationToken + private _token: CancellationToken, + private _ls: LanguageServerInterface ) { this._parseResults = this._program.getParseResults(this._fileUri); } @@ -72,7 +74,7 @@ export class RenameProvider { return null; } - const referenceProvider = new ReferencesProvider(this._program, this._token); + const referenceProvider = new ReferencesProvider(this._ls, this._program, this._token); const renameMode = RenameProvider.getRenameSymbolMode( this._program, this._fileUri, @@ -92,13 +94,9 @@ export class RenameProvider { // from accidentally changing third party library or type stub. if (isUserCode(curSourceFileInfo)) { // Make sure searching symbol name exists in the file. - // TODO: why is this here? source files shouldnt be read from disk directly when using the language server. - // for now we just disable this check in notebooks because they use a different file uri in the lsp - if (curSourceFileInfo.sourceFile.getIPythonMode() !== IPythonMode.CellDocs) { - const content = curSourceFileInfo.sourceFile.getFileContent() ?? ''; - if (!referencesResult.symbolNames.some((s) => content.search(s) >= 0)) { - continue; - } + const content = curSourceFileInfo.sourceFile.getFileContent() ?? ''; + if (!referencesResult.symbolNames.some((s) => content.search(s) >= 0)) { + continue; } referenceProvider.addReferencesToResult( @@ -152,7 +150,7 @@ export class RenameProvider { }); }); - return convertToWorkspaceEdit(this._program.fileSystem, { edits, fileOperations: [] }); + return convertToWorkspaceEdit(this._ls, this._program.fileSystem, { edits, fileOperations: [] }); } static getRenameSymbolMode( diff --git a/packages/pyright-internal/src/languageService/workspaceSymbolProvider.ts b/packages/pyright-internal/src/languageService/workspaceSymbolProvider.ts index 558cececb4..c392f618d3 100644 --- a/packages/pyright-internal/src/languageService/workspaceSymbolProvider.ts +++ b/packages/pyright-internal/src/languageService/workspaceSymbolProvider.ts @@ -14,9 +14,9 @@ import { appendArray } from '../common/collectionUtils'; import { ProgramView } from '../common/extensibility'; import * as StringUtils from '../common/stringUtils'; import { Uri } from '../common/uri/uri'; -import { convertUriToLspUriString } from '../common/uri/uriUtils'; import { Workspace } from '../workspaceFactory'; import { IndexSymbolData, SymbolIndexer } from './symbolIndexer'; +import { LanguageServerInterface } from '../common/languageServerInterface'; type WorkspaceSymbolCallback = (symbols: SymbolInformation[]) => void; @@ -28,7 +28,8 @@ export class WorkspaceSymbolProvider { private readonly _workspaces: Workspace[], resultReporter: ResultProgressReporter | undefined, private readonly _query: string, - private readonly _token: CancellationToken + private readonly _token: CancellationToken, + private _ls: LanguageServerInterface ) { this._reporter = resultReporter ? (symbols) => resultReporter.report(symbols) @@ -100,7 +101,7 @@ export class WorkspaceSymbolProvider { if (StringUtils.isPatternInSymbol(this._query, symbolData.name)) { const location: Location = { - uri: convertUriToLspUriString(program.fileSystem, fileUri), + uri: this._ls.convertUriToLspUriString(program.fileSystem, fileUri), range: symbolData.selectionRange!, }; diff --git a/packages/pyright-internal/src/realLanguageServer.ts b/packages/pyright-internal/src/realLanguageServer.ts index 2b175c9256..af3af5a851 100644 --- a/packages/pyright-internal/src/realLanguageServer.ts +++ b/packages/pyright-internal/src/realLanguageServer.ts @@ -280,7 +280,14 @@ export abstract class RealLanguageServer< const uri = Uri.parse(params.textDocument.uri, this.serverOptions.serviceProvider); const workspace = await this.getWorkspaceForFile(uri); - return CodeActionProvider.getCodeActionsForPosition(workspace, uri, params.range, params.context.only, token); + return CodeActionProvider.getCodeActionsForPosition( + workspace, + uri, + params.range, + params.context.only, + token, + this + ); } protected createProgressReporter(): ProgressReporter { diff --git a/packages/pyright-internal/src/tests/harness/fourslash/testLanguageService.ts b/packages/pyright-internal/src/tests/harness/fourslash/testLanguageService.ts index 6dd4308a01..92a66c9cab 100644 --- a/packages/pyright-internal/src/tests/harness/fourslash/testLanguageService.ts +++ b/packages/pyright-internal/src/tests/harness/fourslash/testLanguageService.ts @@ -20,7 +20,7 @@ import { CommandController } from '../../../commands/commandController'; import { ConfigOptions } from '../../../common/configOptions'; import { ConsoleInterface } from '../../../common/console'; import * as debug from '../../../common/debug'; -import { FileSystem } from '../../../common/fileSystem'; +import { FileSystem, ReadOnlyFileSystem } from '../../../common/fileSystem'; import { ServiceProvider } from '../../../common/serviceProvider'; import { Range } from '../../../common/textRange'; import { Uri } from '../../../common/uri/uri'; @@ -57,12 +57,13 @@ export class TestFeatures implements HostSpecificFeatures { ); getCodeActionsForPosition( + ls: LanguageServerInterface, workspace: Workspace, fileUri: Uri, range: Range, token: CancellationToken ): Promise { - return CodeActionProvider.getCodeActionsForPosition(workspace, fileUri, range, undefined, token); + return CodeActionProvider.getCodeActionsForPosition(workspace, fileUri, range, undefined, token, ls); } execute(ls: LanguageServerInterface, params: ExecuteCommandParams, token: CancellationToken): Promise { const controller = new CommandController(ls); @@ -111,6 +112,8 @@ export class TestLanguageService implements LanguageServerInterface { searchPathsToWatch: [], }; } + /** unlike the real one, this test implementation doesn't support notebook cells. TODO: language server tests for notebook cells */ + convertUriToLspUriString = (fs: ReadOnlyFileSystem, uri: Uri) => fs.getOriginalUri(uri).toString(); getWorkspaces(): Promise { return Promise.resolve([this._workspace, this._defaultWorkspace]); diff --git a/packages/pyright-internal/src/tests/harness/fourslash/testState.ts b/packages/pyright-internal/src/tests/harness/fourslash/testState.ts index 8a6d9555a0..887615a9c6 100644 --- a/packages/pyright-internal/src/tests/harness/fourslash/testState.ts +++ b/packages/pyright-internal/src/tests/harness/fourslash/testState.ts @@ -110,6 +110,7 @@ export interface HostSpecificFeatures { backgroundAnalysisProgramFactory: BackgroundAnalysisProgramFactory; getCodeActionsForPosition( + ls: LanguageServerInterface, workspace: Workspace, fileUri: Uri, range: PositionRange, @@ -699,6 +700,8 @@ export class TestState { } } + const ls = new TestLanguageService(this.workspace, this.console, this.fs); + // Local copy to use in capture. const serviceProvider = this.serviceProvider; for (const range of this.getRanges()) { @@ -714,7 +717,7 @@ export class TestState { } const diagnostics = sourceFile.getDiagnostics(this.configOptions) || []; - const codeActions = await this._getCodeActions(range); + const codeActions = await this._getCodeActions(range, ls); if (verifyMode === 'exact') { if (codeActions.length !== map[name].codeActions.length) { this.raiseError( @@ -856,7 +859,7 @@ export class TestState { const ls = new TestLanguageService(this.workspace, this.console, this.fs); - const codeActions = await this._getCodeActions(range); + const codeActions = await this._getCodeActions(range, ls); if (verifyCodeActionCount) { if (codeActions.length !== Object.keys(map).length) { this.raiseError( @@ -1200,10 +1203,14 @@ export class TestState { }; }, createDocumentRange?: (fileUri: Uri, result: CollectionResult, parseResults: ParseFileResults) => DocumentRange, - convertToLocation?: (fs: ReadOnlyFileSystem, ranges: DocumentRange) => Location | undefined + convertToLocation?: ( + ls: LanguageServerInterface, + fs: ReadOnlyFileSystem, + ranges: DocumentRange + ) => Location | undefined ) { this.analyze(); - + const ls = new TestLanguageService(this.workspace, this.console, this.fs); for (const name of this.getMarkerNames()) { const marker = this.getMarkerByName(name); const fileName = marker.fileName; @@ -1223,6 +1230,7 @@ export class TestState { const position = this.convertOffsetToPosition(fileName, marker.position); const actual = new ReferencesProvider( + ls, this.program, CancellationToken.None, createDocumentRange, @@ -1230,7 +1238,7 @@ export class TestState { ).reportReferences(Uri.file(fileName, this.serviceProvider), position, /* includeDeclaration */ true); assert.strictEqual(actual?.length ?? 0, expected.length, `${name} has failed`); - for (const r of convertDocumentRangesToLocation(this.program.fileSystem, expected, convertToLocation)) { + for (const r of convertDocumentRangesToLocation(ls, this.program.fileSystem, expected, convertToLocation)) { assert.equal(actual?.filter((d) => this._deepEqual(d, r)).length, 1); } } @@ -1242,7 +1250,7 @@ export class TestState { }; }) { this.analyze(); - + const ls = new TestLanguageService(this.workspace, this.console, this.fs); for (const marker of this.getMarkers()) { const fileName = marker.fileName; const name = this.getMarkerName(marker); @@ -1260,7 +1268,8 @@ export class TestState { this.program, Uri.file(fileName, this.serviceProvider), position, - CancellationToken.None + CancellationToken.None, + ls ).getIncomingCalls(); assert.strictEqual(actual?.length ?? 0, expectedFilePath.length, `${name} has failed`); @@ -1287,7 +1296,7 @@ export class TestState { }; }) { this.analyze(); - + const ls = new TestLanguageService(this.workspace, this.console, this.fs); for (const marker of this.getMarkers()) { const fileName = marker.fileName; const name = this.getMarkerName(marker); @@ -1305,7 +1314,8 @@ export class TestState { this.program, Uri.file(fileName, this.serviceProvider), position, - CancellationToken.None + CancellationToken.None, + ls ).getOutgoingCalls(); assert.strictEqual(actual?.length ?? 0, expectedFilePath.length, `${name} has failed`); @@ -1477,7 +1487,7 @@ export class TestState { isUntitled = false ) { this.analyze(); - + const ls = new TestLanguageService(this.workspace, this.console, this.fs); for (const marker of this.getMarkers()) { const fileName = marker.fileName; const name = this.getMarkerName(marker); @@ -1501,11 +1511,12 @@ export class TestState { ? Uri.parse(`untitled:${fileName.replace(/\\/g, '/')}`, this.serviceProvider) : Uri.file(fileName, this.serviceProvider), position, - CancellationToken.None + CancellationToken.None, + ls ).renameSymbol(expected.newName, /* isDefaultWorkspace */ false, isUntitled); verifyWorkspaceEdit( - convertToWorkspaceEdit(this.program.fileSystem, { edits: expected.changes, fileOperations: [] }), + convertToWorkspaceEdit(ls, this.program.fileSystem, { edits: expected.changes, fileOperations: [] }), actual ?? { documentChanges: [] } ); } @@ -2016,7 +2027,7 @@ export class TestState { } } - private _getCodeActions(range: Range) { + private _getCodeActions(range: Range, ls: LanguageServerInterface) { const file = range.fileName; const textRange = { start: this.convertOffsetToPosition(file, range.pos), @@ -2024,6 +2035,7 @@ export class TestState { }; return this._hostSpecificFeatures.getCodeActionsForPosition( + ls, this.workspace, range.fileUri, textRange, diff --git a/packages/pyright-internal/src/tests/sourceFile.test.ts b/packages/pyright-internal/src/tests/sourceFile.test.ts index c7f38773a1..e4cfadebd3 100644 --- a/packages/pyright-internal/src/tests/sourceFile.test.ts +++ b/packages/pyright-internal/src/tests/sourceFile.test.ts @@ -34,7 +34,8 @@ test('Empty', () => { { isEditMode: false, }, - new BaselineHandler(fs, configOptions, undefined) + new BaselineHandler(fs, configOptions, undefined), + () => undefined ); const sp = createServiceProvider(fs); const importResolver = new ImportResolver(sp, configOptions, new FullAccessHost(sp)); diff --git a/packages/pyright-internal/src/tests/workspaceEditUtils.test.ts b/packages/pyright-internal/src/tests/workspaceEditUtils.test.ts index 56f59c5d99..3563360925 100644 --- a/packages/pyright-internal/src/tests/workspaceEditUtils.test.ts +++ b/packages/pyright-internal/src/tests/workspaceEditUtils.test.ts @@ -295,7 +295,13 @@ test('test generateWorkspaceEdits', async () => { assert.strictEqual(fileChanged.size, 2); - const actualEdits = generateWorkspaceEdit(state.workspace.service.fs, state.workspace.service, cloned, fileChanged); + const actualEdits = generateWorkspaceEdit( + new TestLanguageService(state.workspace, state.console, state.workspace.service.fs), + state.workspace.service.fs, + state.workspace.service, + cloned, + fileChanged + ); verifyWorkspaceEdit( { changes: {