Skip to content

Commit

Permalink
Lint
Browse files Browse the repository at this point in the history
  • Loading branch information
FlorianDeconinck committed Apr 22, 2024
1 parent 48a3ab1 commit ba77632
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 18 deletions.
16 changes: 4 additions & 12 deletions src/gt4py/cartesian/backend/cuda_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,18 +50,14 @@ def __call__(self, stencil_ir: gtir.Stencil) -> Dict[str, Dict[str, str]]:
base_oir = GTIRToOIR().visit(stencil_ir)
oir_pipeline = self.backend.builder.options.backend_opts.get(
"oir_pipeline",
DefaultPipeline(
skip=[NoFieldAccessPruning], add_steps=[FillFlushToLocalKCaches]
),
DefaultPipeline(skip=[NoFieldAccessPruning], add_steps=[FillFlushToLocalKCaches]),
)
oir_node = oir_pipeline.run(base_oir)
cuir_node = OIRToCUIR().visit(oir_node)
cuir_node = kernel_fusion.FuseKernels().visit(cuir_node)
cuir_node = extent_analysis.CacheExtents().visit(cuir_node)
format_source = self.backend.builder.options.format_source
implementation = cuir_codegen.CUIRCodegen.apply(
cuir_node, format_source=format_source
)
implementation = cuir_codegen.CUIRCodegen.apply(cuir_node, format_source=format_source)
bindings = CudaBindingsCodegen.apply_codegen(
cuir_node,
module_name=self.module_name,
Expand Down Expand Up @@ -109,13 +105,9 @@ def visit_FieldDecl(self, node: cuir.FieldDecl, **kwargs):
def visit_ScalarDecl(self, node: cuir.ScalarDecl, **kwargs):
if "external_arg" in kwargs:
if kwargs["external_arg"]:
return "{dtype} {name}".format(
name=node.name, dtype=self.visit(node.dtype)
)
return "{dtype} {name}".format(name=node.name, dtype=self.visit(node.dtype))
else:
return "gridtools::stencil::global_parameter({name})".format(
name=node.name
)
return "gridtools::stencil::global_parameter({name})".format(name=node.name)

def visit_Program(self, node: cuir.Program, **kwargs):
assert "module_name" in kwargs
Expand Down
8 changes: 2 additions & 6 deletions src/gt4py/cartesian/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,7 @@
"extra_compile_args": {"cxx": extra_compile_args, "cuda": extra_compile_args},
"extra_link_args": [],
"parallel_jobs": multiprocessing.cpu_count(),
"cpp_template_depth": os.environ.get(
"GT_CPP_TEMPLATE_DEPTH", GT_CPP_TEMPLATE_DEPTH
),
"cpp_template_depth": os.environ.get("GT_CPP_TEMPLATE_DEPTH", GT_CPP_TEMPLATE_DEPTH),
}
if GT4PY_USE_HIP:
build_settings["cuda_library_path"] = os.path.join(CUDA_ROOT, "lib")
Expand All @@ -80,9 +78,7 @@
"dir_name": os.environ.get("GT_CACHE_DIR_NAME", ".gt_cache"),
"root_path": os.environ.get("GT_CACHE_ROOT", os.path.abspath(".")),
"load_retries": int(os.environ.get("GT_CACHE_LOAD_RETRIES", 3)),
"load_retry_delay": int(
os.environ.get("GT_CACHE_LOAD_RETRY_DELAY", 100)
), # unit milliseconds
"load_retry_delay": int(os.environ.get("GT_CACHE_LOAD_RETRY_DELAY", 100)), # unit milliseconds
}

code_settings: Dict[str, Any] = {"root_package_name": "_GT_"}
Expand Down

0 comments on commit ba77632

Please sign in to comment.