Skip to content

Commit

Permalink
Regenerate MLIR Bindings (#779)
Browse files Browse the repository at this point in the history
Co-authored-by: enzyme-ci-bot[bot] <78882869+enzyme-ci-bot[bot]@users.noreply.github.com>
  • Loading branch information
github-actions[bot] and enzyme-ci-bot[bot] authored Feb 23, 2025
1 parent 0a7e078 commit 4aaae3c
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/mlir/Dialects/CHLO.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1408,12 +1408,12 @@ the lhs is required to have one ragged dimension, and the rhs may have at
most one group dimension. The op has three modes, depending on the kind of
the lhs ragged dimension.
In mode 1, the shape-signature is `[b,m,k], [g,b,k,n], [g] -> [b,m,n]`.
In mode 1, the shape-signature is `[b,m,k], [g,b,k,n], [b,g] -> [b,m,n]`.
Here the ragged dimension is an lhs non-contracting dimension (`m`). The
dimensions `b` and `k` represent batch and contracting dimensions
respectively. The rhs is required to have a group dimension (`g`).
In mode 2, the shape-signature is `[b,m,k], [b,k,n], [g] -> [g,b,m,n]`.
In mode 2, the shape-signature is `[b,m,k], [b,k,n], [b,g] -> [g,b,m,n]`.
Here the ragged dimension is an lhs/rhs contracting dimension (`k`).
In mode 3, the shape-signature is `[b,m,k], [b,k,n], [g] -> [b,m,n]`. Here
Expand Down
100 changes: 100 additions & 0 deletions src/mlir/Dialects/Nvvm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2990,6 +2990,46 @@ function tcgen05_alloc(addr::Value, nCols::Value; group=nothing, location=Locati
)
end

"""
`tcgen05_commit`

Makes the mbarrier object given by the operand `addr` track the
completion of all prior async-tcgen05 operations initiated by the
executing thread. The multicast variants additionally allow signaling
the mbarrier objects of multiple CTAs within the cluster: when the
optional `multicastMask` operand is present, each bit position of the
16-bit mask selects a destination CTA whose
`nvvm.read.ptx.sreg.ctaid` corresponds to that bit.

[For more information refer to the PTX ISA]
(https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen-async-sync-operations-commit)
"""
function tcgen05_commit(
    addr::Value,
    multicastMask=nothing::Union{Nothing,Value};
    group=nothing,
    location=Location(),
)
    # The mbarrier address is always the first operand; the optional
    # multicast mask is appended only when supplied.
    operands = Value[addr]
    isnothing(multicastMask) || push!(operands, multicastMask)

    # Optional `group` attribute.
    attributes = NamedAttribute[]
    isnothing(group) || push!(attributes, namedattribute("group", group))

    return create_operation(
        "nvvm.tcgen05.commit",
        location;
        operands,
        owned_regions=Region[],
        successors=Block[],
        attributes,
        results=IR.Type[],
        result_inference=false,
    )
end

"""
`tcgen05_dealloc`
Expand Down Expand Up @@ -3020,6 +3060,36 @@ function tcgen05_dealloc(taddr::Value, nCols::Value; group=nothing, location=Loc
)
end

"""
`tcgen05_fence`

The `tcgen05.fence<before>` variant orders all prior async tcgen05
operations with respect to the subsequent tcgen05 and execution
ordering operations; the `tcgen05.fence<after>` variant orders all
subsequent async tcgen05 operations with respect to the prior tcgen05
and execution ordering operations.

[For more information refer to the PTX ISA]
(https://docs.nvidia.com/cuda/parallel-thread-execution/#tensorcore-5th-generation-instructions-tcgen05-fence)
"""
function tcgen05_fence(; kind, location=Location())
    # The op has no operands, regions, successors, or results; the fence
    # flavor is carried entirely by the required `kind` attribute.
    return create_operation(
        "nvvm.tcgen05.fence",
        location;
        operands=Value[],
        owned_regions=Region[],
        successors=Block[],
        attributes=NamedAttribute[namedattribute("kind", kind)],
        results=IR.Type[],
        result_inference=false,
    )
end

"""
`tcgen05_relinquish_alloc_permit`
Expand Down Expand Up @@ -3050,6 +3120,36 @@ function tcgen05_relinquish_alloc_permit(; group=nothing, location=Location())
)
end

"""
`tcgen05_wait`

The `tcgen05.wait<load>` causes the executing thread to block until
all of its prior `tcgen05.ld` operations have completed; likewise,
`tcgen05.wait<store>` blocks the executing thread until all of its
prior `tcgen05.st` operations have completed.

[For more information refer to the PTX ISA]
(https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-wait)
"""
function tcgen05_wait(; kind, location=Location())
    # The op has no operands, regions, successors, or results; the wait
    # flavor (<load>/<store>) is carried by the required `kind` attribute.
    return create_operation(
        "nvvm.tcgen05.wait",
        location;
        operands=Value[],
        owned_regions=Region[],
        successors=Block[],
        attributes=NamedAttribute[namedattribute("kind", kind)],
        results=IR.Type[],
        result_inference=false,
    )
end

function read_ptx_sreg_tid_x(; res::IR.Type, range=nothing, location=Location())
op_ty_results = IR.Type[res,]
operands = Value[]
Expand Down
19 changes: 19 additions & 0 deletions src/mlir/Dialects/TPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1285,6 +1285,25 @@ function vector_store(
)
end

# Builder for the `tpu.wait_dma2` op: three operands (semaphore, src, dst),
# no results, regions, successors, or attributes.
function wait_dma2(semaphore::Value, src::Value, dst::Value; location=Location())
    return create_operation(
        "tpu.wait_dma2",
        location;
        operands=Value[semaphore, src, dst],
        owned_regions=Region[],
        successors=Block[],
        attributes=NamedAttribute[],
        results=IR.Type[],
        result_inference=false,
    )
end

function wait_dma(semaphore::Value, ref::Value; location=Location())
op_ty_results = IR.Type[]
operands = Value[semaphore, ref]
Expand Down
20 changes: 20 additions & 0 deletions src/mlir/libMLIR_h.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6247,6 +6247,10 @@ function mlirLoadIRDLDialects(_module)
@ccall mlir_c.mlirLoadIRDLDialects(_module::MlirModule)::MlirLogicalResult
end

# FFI wrapper for the MLIR C API entry point `mlirGetDialectHandle__index__`.
function mlirGetDialectHandle__index__()
    return ccall(
        (:mlirGetDialectHandle__index__, mlir_c), MlirDialectHandle, ()
    )
end

# FFI wrapper for the MLIR C API entry point `mlirGetDialectHandle__llvm__`.
function mlirGetDialectHandle__llvm__()
    return ccall(
        (:mlirGetDialectHandle__llvm__, mlir_c), MlirDialectHandle, ()
    )
end
Expand Down Expand Up @@ -10112,6 +10116,8 @@ function sdyOpShardingRuleAttrGet(
reductionFactors,
nNeedReplicationFactors,
needReplicationFactors,
nPermutationFactors,
permutationFactors,
isCustomRule,
)
@ccall mlir_c.sdyOpShardingRuleAttrGet(
Expand All @@ -10126,6 +10132,8 @@ function sdyOpShardingRuleAttrGet(
reductionFactors::Ptr{Int64},
nNeedReplicationFactors::intptr_t,
needReplicationFactors::Ptr{Int64},
nPermutationFactors::intptr_t,
permutationFactors::Ptr{Int64},
isCustomRule::Bool,
)::MlirAttribute
end
Expand Down Expand Up @@ -10188,6 +10196,18 @@ function sdyOpShardingRuleAttrGetNeedReplicationFactorsElem(attr, pos)
)::Int64
end

# FFI wrapper: returns the number of permutation factors held by `attr`
# via the C API `sdyOpShardingRuleAttrGetPermutationFactorsSize`.
function sdyOpShardingRuleAttrGetPermutationFactorsSize(attr)
    return ccall(
        (:sdyOpShardingRuleAttrGetPermutationFactorsSize, mlir_c),
        intptr_t,
        (MlirAttribute,),
        attr,
    )
end

# FFI wrapper: returns the permutation factor at index `pos` of `attr`
# via the C API `sdyOpShardingRuleAttrGetPermutationFactorsElem`.
function sdyOpShardingRuleAttrGetPermutationFactorsElem(attr, pos)
    return ccall(
        (:sdyOpShardingRuleAttrGetPermutationFactorsElem, mlir_c),
        Int64,
        (MlirAttribute, intptr_t),
        attr,
        pos,
    )
end

# FFI wrapper: predicate querying the C API whether `attr` is a
# ManualAxes attribute.
function sdyAttributeIsAManualAxesAttr(attr)
    return ccall(
        (:sdyAttributeIsAManualAxesAttr, mlir_c), Bool, (MlirAttribute,), attr
    )
end
Expand Down

0 comments on commit 4aaae3c

Please sign in to comment.