Support loading BF16 models in Cuda on T4 #1581

Open · wants to merge 1 commit into main

Conversation

@ltouati (Contributor) commented on Jan 13, 2024

The NVIDIA T4 does not support the BF16 floating-point type natively. This PR adds a method to cast BF16 to F32 or F16 so that Mistral can be loaded on a T4.
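
A minimal sketch of the intended fallback (illustration only, not code from this PR): it leans on candle's existing Tensor::to_dtype cast and on the Device::support_native_bf16 helper added further down in this diff; the maybe_downcast name is made up.

```rust
// Hypothetical helper, for illustration only: downcast bf16 tensors when the
// current device cannot handle bf16 natively (e.g. a T4, which is sm_75).
use candle_core::{DType, Device, Result, Tensor};

fn maybe_downcast(t: Tensor, device: &Device) -> Result<Tensor> {
    if t.dtype() == DType::BF16 && !device.support_native_bf16()? {
        // Fall back to f16; F32 would also work if more range is needed.
        t.to_dtype(DType::F16)
    } else {
        Ok(t)
    }
}
```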

@Narsil (Collaborator) left a comment

I'm personally not a big fan of all these hoops.

supports_native_bf16 as a function could be interesting (so users could build logic like the one in this PR), but it probably doesn't need to launch an actual kernel.

I fail to see the need for the extra copy-pasted kernel. Maybe we could just relax the __CUDA_ARCH__ >= 800 check instead of creating a new function?

If we implement supports_native_bf16, we need to make it correct on the other devices too (this can be done in a follow-up).
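
One kernel-free way to answer the question is to read the device's compute capability: bf16 arithmetic is native from sm_80 (Ampere) onward, while the T4 is sm_75. The sketch below is distinct from the reviewer's suggested change further down and assumes cudarc exposes cuDeviceGetAttribute through CudaDevice::attribute with driver-style enum names; verify the exact paths against the cudarc version in use.

```rust
// Sketch only: decide bf16 support from the compute capability instead of
// launching a probe kernel.
use cudarc::driver::{sys, CudaDevice, DriverError};

fn supports_native_bf16(dev: &CudaDevice) -> Result<bool, DriverError> {
    // Ampere (sm_80) and newer have native bf16 math; the T4 is sm_75.
    let major = dev.attribute(
        sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
    )?;
    Ok(major >= 8)
}
```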

Comment on lines +123 to +136
```rust
        let elem_count = 1;
        let cfg = LaunchConfig::for_num_elems(elem_count as u32);
        let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
        let func = self.get_or_load_func("support_native_bf16", kernels::CAST)?;
        let v: u8 = 0;
        let params = (&data, v, elem_count);
        unsafe { func.launch(cfg, params) }.w()?;

        let ret = self.dtoh_sync_copy(&data).w()?;
        match ret.first() {
            None => Ok(false),
            Some(r) => Ok(*r == 1u8),
        }
    }
```

Collaborator:

Could we not launch an actual kernel?

We could, for instance, define support_native_bf16 only on those targets and just check for the existence of the function.

Suggested change
```diff
-        let elem_count = 1;
-        let cfg = LaunchConfig::for_num_elems(elem_count as u32);
-        let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
-        let func = self.get_or_load_func("support_native_bf16", kernels::CAST)?;
-        let v: u8 = 0;
-        let params = (&data, v, elem_count);
-        unsafe { func.launch(cfg, params) }.w()?;
-        let ret = self.dtoh_sync_copy(&data).w()?;
-        match ret.first() {
-            None => Ok(false),
-            Some(r) => Ok(*r == 1u8),
-        }
-    }
+        self.get_or_load_func("support_native_bf16", kernels::CAST).is_ok()
+    }
```

```diff
@@ -166,6 +166,13 @@ impl Device {
     pub fn is_cuda(&self) -> bool {
         matches!(self, Self::Cuda(_))
     }
+    pub fn support_native_bf16(&self) -> Result<bool> {
+        match self {
+            Self::Cpu => Ok(true),
```

Collaborator:

Most CPUs actually do not support bf16 natively.

```diff
+        match self {
+            Self::Cpu => Ok(true),
+            Self::Cuda(c) => c.support_native_bf16(),
+            Self::Metal(_) => Ok(true),
```

Collaborator:

Only some Metal versions support bf16.
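
A sketch of the more conservative default these two comments point at, written as a drop-in variant of the Device::support_native_bf16 method from the diff above; the CPU and Metal arms are assumptions drawn from the review, not code from this PR.

```rust
// Sketch only: report native bf16 support pessimistically unless the backend
// can actually verify it.
pub fn support_native_bf16(&self) -> Result<bool> {
    match self {
        // Most CPUs have no native bf16 arithmetic.
        Self::Cpu => Ok(false),
        Self::Cuda(c) => c.support_native_bf16(),
        // Placeholder: only some Metal devices/OS versions support bf16, so a
        // real implementation should query the device instead of hard-coding.
        Self::Metal(_) => Ok(false),
    }
}
```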

```diff
@@ -172,8 +172,10 @@ fn main() -> Result<()> {
     let start = std::time::Instant::now();
     let dtype = if args.use_f32 {
         DType::F32
-    } else {
+    } else if device.support_native_bf16()? {
```

Collaborator:

I think we should just switch that arg into an actual enum: --dtype float16.
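
A sketch of that direction, assuming the example binary uses clap's derive API; the Args and WhichDType names and the default value are illustrative, not from this PR.

```rust
// Sketch only: replace the boolean --use-f32 flag with a --dtype enum.
use candle_core::DType;
use clap::{Parser, ValueEnum};

#[derive(Clone, Copy, Debug, ValueEnum)]
enum WhichDType {
    F32,
    F16,
    Bf16,
}

#[derive(Parser, Debug)]
struct Args {
    /// Data type used for the model weights.
    #[arg(long, value_enum, default_value = "f16")]
    dtype: WhichDType,
}

fn main() {
    let args = Args::parse();
    let dtype = match args.dtype {
        WhichDType::F32 => DType::F32,
        WhichDType::F16 => DType::F16,
        WhichDType::Bf16 => DType::BF16,
    };
    println!("loading weights as {dtype:?}");
}
```

With that in place the caller could pass --dtype bf16 on GPUs with native support and --dtype f16 on a T4.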

```cpp
} \
} \

LOCAL_CAST_OP(__nv_bfloat16, float, cast_bf16_f32)
```

Collaborator:

If I'm not mistaken, this is undefined for low enough CUDA versions. Also, I fail to see the difference between this version and the one in the >= 800 path.
