Skip to content

Commit

Permalink
Update tree-sitter-parser for tree-sitter 0.23 and latest grammar
Browse files Browse the repository at this point in the history
  • Loading branch information
ZedThree committed Sep 18, 2024
1 parent 567f6ad commit 073d93a
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 38 deletions.
53 changes: 23 additions & 30 deletions ford/tree_sitter_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,15 @@
from typing import List, Optional
from contextlib import contextmanager

try:
import tree_sitter_languages
from tree_sitter import Node, Language, TreeCursor
from tree_sitter import Node, Language, TreeCursor, Parser
import tree_sitter_fortran

FREE_FORM_LANGUAGE = tree_sitter_languages.get_language("fortran")
FIXED_FORM_LANGUAGE = tree_sitter_languages.get_language("fixed_form_fortran")
FREE_FORM_LANGUAGE = Language(tree_sitter_fortran.language())
# FIXME: use correct language when fixed form is on PyPI
FIXED_FORM_LANGUAGE = Language(tree_sitter_fortran.language())

FREE_FORM_PARSER = tree_sitter_languages.get_parser("fortran")
FIXED_FORM_PARSER = tree_sitter_languages.get_parser("fixed_form_fortran")

tree_sitter_parser_available = True
except ImportError:
tree_sitter_parser_available = False
FREE_FORM_PARSER = Parser(FREE_FORM_LANGUAGE)
FIXED_FORM_PARSER = Parser(FREE_FORM_LANGUAGE)


def is_predoc_comment(comment: str, predocmark: str, predocmar_alt: str) -> bool:
Expand All @@ -51,7 +47,11 @@ def __init__(self, language: Language, query: str):
self.query = language.query(query)

def __call__(self, node: Node) -> List[Node]:
return [capture for capture, _ in self.query.captures(node)]
# This is pretty crude, just collecting all the captures
captures = []
for capture in self.query.captures(node).values():
captures.extend(capture)
return captures

def maybe_first(self, node: Node) -> Optional[Node]:
if capture := self(node):
Expand Down Expand Up @@ -254,7 +254,8 @@ def parse_function(self, parent: FortranContainer, node: Node) -> FortranFunctio
attributes = self.function_attributes_query(node)
name = self._get_name(node)
arguments = self.procedure_parameters_query.first(node)
result = self.function_result_query.maybe_first(node)
result_name = self.function_result_query.maybe_first(node)
result_type = self.function_result_type_query.maybe_first(node)
bind = self.language_binding_query.maybe_first(node)

function = FortranFunction(
Expand All @@ -263,16 +264,12 @@ def parse_function(self, parent: FortranContainer, node: Node) -> FortranFunctio
parent=parent,
name=name,
attributes="\n".join(decode(attr) for attr in attributes),
result=maybe_decode(result),
arguments=decode(arguments),
result_name=maybe_decode(result_name),
result_type=maybe_decode(result_type),
arguments=[decode(argument) for argument in arguments.named_children],
bindC=maybe_decode(bind),
)

if result_type := self.function_result_type_query.maybe_first(node):
function.retvar = FortranVariable(
name=function.retvar, parent=function, vartype=decode(result_type)
)

function._cleanup()
return function

Expand Down Expand Up @@ -502,9 +499,9 @@ def parse_function(
self, parent: FortranContainer, cursor: TreeCursor
) -> FortranFunction:
attributes = []
result = None
result_name = None
result_type = None
arguments = ""
arguments = []
bindC = None

with descend_one_node(cursor):
Expand All @@ -514,9 +511,9 @@ def parse_function(
elif cursor.node.type in ("intrinsic_type", "derived_type"):
result_type = decode(cursor.node)
elif cursor.node.type == "parameters":
arguments = decode(cursor.node)
arguments = [decode(arg) for arg in cursor.node.named_children]
elif cursor.node.type == "function_result":
result = decode(cursor.node.named_children[0])
result_name = decode(cursor.node.named_children[0])
elif cursor.node.type == "language_binding":
bindC = decode(cursor.node)

Expand All @@ -529,16 +526,12 @@ def parse_function(
parent=parent,
name=self._get_name(cursor),
attributes=",".join(attributes),
result=result,
result_name=result_name,
result_type=result_type,
arguments=arguments,
bindC=bindC,
)

if result_type is not None:
function.retvar = FortranVariable(
name=function.retvar, parent=function, vartype=result_type
)

function._cleanup()
return function

Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ dependencies = [
"graphviz ~= 0.20.0",
"tqdm ~= 4.64.0",
"tomli >= 1.1.0 ; python_version < '3.11'",
"tree-sitter >= 0.23.0",
"tree-sitter-fortran",
"rich >= 12.0.0",
"pcpp >= 1.30",
]
Expand Down
9 changes: 1 addition & 8 deletions test/test_tree_sitter_parser.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
from ford.tree_sitter_parser import (
tree_sitter_parser_available,
TreeSitterParser,
TreeSitterCursorParser,
)
from ford.tree_sitter_parser import TreeSitterParser, TreeSitterCursorParser
from ford.sourceform import FortranSourceFile
from ford.settings import ProjectSettings

from textwrap import dedent

import time
import pytest


data = dedent(
Expand Down Expand Up @@ -42,7 +37,6 @@
).encode()


@pytest.mark.skipif(not tree_sitter_parser_available, reason="Requires tree-sitter")
def test_tree_sitter_parser():
parser = TreeSitterParser()
tree = parser.parser.parse(data)
Expand Down Expand Up @@ -75,7 +69,6 @@ def test_tree_sitter_parser():
assert len(function.functions) == 1


@pytest.mark.skipif(not tree_sitter_parser_available, reason="Requires tree-sitter")
def test_tree_sitter_cursor_parser():
parser = TreeSitterCursorParser()
tree = parser.parser.parse(data)
Expand Down

0 comments on commit 073d93a

Please sign in to comment.