
SPIRV flags are not set for the second call of atomics fetch_add call #1262

Open
ZzEeKkAa opened this issue Dec 28, 2023 · 2 comments
Labels: bug (Something isn't working), enhancement (New feature or request)
Milestone: 0.23

Comments

@ZzEeKkAa
Contributor

ZzEeKkAa commented Dec 28, 2023

When two different kernels that call the same overloaded function, e.g., fetch_add in the reproducer, are compiled, the extra compilation flags needed for llvm-spirv translation are only applied to the kernel that is compiled first.

For the first kernel, the fetch_add overload is not yet in the compiled cache, so it gets compiled, and during that compilation the intrinsic adds the extra flags to the target context's compilation options. For the second kernel, a compiled version of fetch_add is already available in the dispatcher's overload cache, so the intrinsic is not invoked again. As a result, the extra compilation flags are never set for that kernel, and an internal compiler error is raised during llvm-spirv translation.
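The failure mode can be modeled in plain Python. This is a toy sketch, not numba-dpex code: `TargetContext`, `compile_fetch_add`, `compile_and_finalize_kernel`, and the flag string are all hypothetical. It only illustrates why a flag registered as a side effect inside the intrinsic is lost once the overload is served from the cache and the options are reset after each finalize.

```python
# Toy model of the bug; every name here is hypothetical, not numba-dpex API.
EXTRA_FLAG = "--hypothetical-extra-flag"


class TargetContext:
    def __init__(self):
        # Extra llvm-spirv arguments collected while compiling a kernel.
        self.extra_compile_options = []


# Stands in for the dispatcher's overload cache.
_overload_cache = {}


def compile_fetch_add(ctx):
    if "fetch_add" not in _overload_cache:
        # The intrinsic body, and thus its side effect, only runs on a cache miss.
        ctx.extra_compile_options.append(EXTRA_FLAG)
        _overload_cache["fetch_add"] = "<LLVM IR for fetch_add>"
    return _overload_cache["fetch_add"]


def compile_and_finalize_kernel(ctx):
    compile_fetch_add(ctx)
    flags = list(ctx.extra_compile_options)  # what llvm-spirv would receive
    ctx.extra_compile_options.clear()  # options are reset once the module is finalized
    return flags


ctx = TargetContext()
first = compile_and_finalize_kernel(ctx)   # like atomic_ref_0
second = compile_and_finalize_kernel(ctx)  # like atomic_ref_1
print(first)   # ['--hypothetical-extra-flag']
print(second)  # []
```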

A minimal reproducer:

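```python
import dpnp

# Import paths assumed from the numba-dpex experimental API of this period.
import numba_dpex as dpex
import numba_dpex.experimental as dpex_exp
from numba_dpex.experimental.kernel_iface import AtomicRef


@dpex_exp.kernel
def atomic_ref_0(a):
    i = dpex.get_global_id(0)
    v = AtomicRef(a, index=0)
    v.fetch_add(a[i + 2])


@dpex_exp.kernel
def atomic_ref_1(a):
    i = dpex.get_global_id(0)
    v = AtomicRef(a, index=1)
    v.fetch_add(a[i + 2])


def test_spirv_flags():
    N = 10
    a = dpnp.ones(N, dtype=dpnp.float32)

    # Flags are set here because the fetch_add intrinsic is not cached yet.
    dpex_exp.call_kernel(atomic_ref_0, dpex.Range(N - 2), a)
    # The SPIR-V flags are then cleared in `spirv_generator.Module.finalize`.
    # The intrinsic is already compiled and cached, so the flags are not set again.
    dpex_exp.call_kernel(atomic_ref_1, dpex.Range(N - 2), a)

    # Each of the N - 2 work items adds a[i + 2] == 1 to the initial value 1.
    assert a[0] == N - 1
    assert a[1] == N - 1
```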

I did some investigation, and with the old-style API it works. My guess is that the old style uses `lower` instead of an intrinsic.

To reproduce, update `llvm_spirv_args` to an empty list in `spirv_generator.py`; search for the reference to this issue there.

@ZzEeKkAa ZzEeKkAa added this to the 0.23 milestone Dec 28, 2023
@ZzEeKkAa ZzEeKkAa added bug Something isn't working enhancement New feature or request labels Dec 28, 2023
@diptorupd
Contributor

diptorupd commented Dec 28, 2023

@ZzEeKkAa thank you for the reproducer. The issue with the overload PR is now clear to me.

The issue happens because overloads are not compiled to SPIR-V. We perform SPIR-V compilation only for kernel functions, after all overloads compiled to LLVM IR have been linked into the kernel function at the LLVM bitcode level.

As a solution, when we compile an overload, e.g., fetch_add, any extra compilation flags should be stored as part of the CompileResult for that overload. When the dispatcher finalizes a kernel's code library, it should gather the extra compilation flags of every library linked into that final code library. llvm-spirv can then be invoked with all the needed flags.
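A minimal sketch of this proposal, again as a toy model with hypothetical names (the `CompileResult` dataclass below is a stand-in, not Numba's class). The cache stores a result object that carries its flags, and finalization collects the flags of every linked library, so a cache hit no longer drops them.

```python
from dataclasses import dataclass, field

EXTRA_FLAG = "--hypothetical-extra-flag"


@dataclass
class CompileResult:
    # Hypothetical stand-in for a compiled overload's result.
    name: str
    extra_spirv_flags: list = field(default_factory=list)


# The cache stores the whole result, so the flags survive a cache hit.
_overload_cache = {}


def compile_fetch_add():
    if "fetch_add" not in _overload_cache:
        _overload_cache["fetch_add"] = CompileResult("fetch_add", [EXTRA_FLAG])
    return _overload_cache["fetch_add"]


def finalize_kernel(linked_libraries):
    """Gather flags from every library linked into the kernel, deduplicated in order."""
    flags = []
    for lib in linked_libraries:
        for flag in lib.extra_spirv_flags:
            if flag not in flags:
                flags.append(flag)
    return flags  # these would be passed to llvm-spirv


first = finalize_kernel([compile_fetch_add()])   # cache miss
second = finalize_kernel([compile_fetch_add()])  # cache hit, flags still present
```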

@diptorupd diptorupd self-assigned this Dec 28, 2023
@ZzEeKkAa
Contributor Author

ZzEeKkAa commented Jan 5, 2024

@diptorupd I like the idea. I know it will depend on the implementation, but we need to keep those flags propagated through any other overload wrapper. That is, if an overload calls another overload that has compilation flags, those flags also need to be stored on the caller-level overload.
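Under the same toy assumptions, one way to satisfy this is to merge callee flags into the caller's result at compile time. Because each callee result already carries its own merged flags, merging only the direct callees is enough to propagate flags through any depth of wrappers. `compile_overload` and the wrapper names are hypothetical.

```python
from dataclasses import dataclass, field


@dataclass
class CompileResult:
    # Hypothetical stand-in for a compiled overload's result.
    name: str
    extra_spirv_flags: list = field(default_factory=list)


def compile_overload(name, own_flags=(), callees=()):
    """Store the overload's own flags plus every callee's, so a cached caller keeps them."""
    flags = []
    for flag in list(own_flags) + [f for c in callees for f in c.extra_spirv_flags]:
        if flag not in flags:
            flags.append(flag)
    return CompileResult(name, flags)


fetch_add = compile_overload("fetch_add", own_flags=["--hypothetical-extra-flag"])
wrapper = compile_overload("fetch_add_wrapper", callees=[fetch_add])
outer = compile_overload("outer_wrapper", callees=[wrapper])  # two levels deep
print(outer.extra_spirv_flags)  # ['--hypothetical-extra-flag']
```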
