Support loading BF16 models in Cuda on T4 #1581
base: main
Conversation
I'm personally not a big fan of all these hoops.
`supports_native_bf16` as a function could be interesting (so users could build logic like the one presented in this PR), but it probably doesn't need to launch an actual kernel.
I fail to see the need for the extra copy-pasted kernel. Maybe we could just relax the `__CUDA_ARCH__ >= 800` guard instead of creating a new function?
If we implement `supports_native_bf16`, we need to make it correct on other devices too (can be done in a follow-up).
```rust
        let elem_count = 1;
        let cfg = LaunchConfig::for_num_elems(elem_count as u32);
        let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
        let func = self.get_or_load_func("support_native_bf16", kernels::CAST)?;
        let v: u8 = 0;
        let params = (&data, v, elem_count);
        unsafe { func.launch(cfg, params) }.w()?;

        let ret = self.dtoh_sync_copy(&data).w()?;
        match ret.first() {
            None => Ok(false),
            Some(r) => Ok(*r == 1u8),
        }
    }
```
Could we not launch an actual kernel?
We could for instance define `support_native_bf16` only on those targets and just check for the existence of the function.
Suggested change:

```diff
-        let elem_count = 1;
-        let cfg = LaunchConfig::for_num_elems(elem_count as u32);
-        let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
-        let func = self.get_or_load_func("support_native_bf16", kernels::CAST)?;
-        let v: u8 = 0;
-        let params = (&data, v, elem_count);
-        unsafe { func.launch(cfg, params) }.w()?;
-        let ret = self.dtoh_sync_copy(&data).w()?;
-        match ret.first() {
-            None => Ok(false),
-            Some(r) => Ok(*r == 1u8),
-        }
-    }
+        self.get_or_load_func("support_native_bf16", kernels::CAST).is_ok()
+    }
```
```diff
@@ -166,6 +166,13 @@ impl Device {
     pub fn is_cuda(&self) -> bool {
         matches!(self, Self::Cuda(_))
     }
+
+    pub fn support_native_bf16(&self) -> Result<bool> {
+        match self {
+            Self::Cpu => Ok(true),
```
Most CPUs actually do not support bf16.
```rust
        match self {
            Self::Cpu => Ok(true),
            Self::Cuda(c) => c.support_native_bf16(),
            Self::Metal(_) => Ok(true),
```
Only some Metal versions support bf16.
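
A minimal sketch of a more conservative match for the non-CUDA arms, following the two comments above; the defaults and the suggested follow-up probes are assumptions on my part, not candle's actual capability API:

```rust
// Sketch only: default the non-CUDA devices to false, per the review
// comments above. A follow-up could refine this with a real runtime probe
// (e.g. avx512bf16 detection on x86-64, or a Metal GPU-family query).
match self {
    Self::Cpu => Ok(false),
    Self::Cuda(c) => c.support_native_bf16(),
    Self::Metal(_) => Ok(false),
}
```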
```diff
@@ -172,8 +172,10 @@ fn main() -> Result<()> {
     let start = std::time::Instant::now();
     let dtype = if args.use_f32 {
         DType::F32
-    } else {
+    } else if device.support_native_bf16()? {
```
I think we should just switch that arg into an actual enum, e.g. `--dtype float16`.
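
A rough sketch of what that could look like with clap's `ValueEnum`; the `WhichDType` name, its variants, and the optional fallback behaviour are illustrative, not code from this PR:

```rust
use clap::{Parser, ValueEnum};

// Illustrative only: a --dtype enum argument as suggested above.
#[derive(Clone, Copy, Debug, ValueEnum)]
enum WhichDType {
    Float32,
    Float16,
    Bfloat16,
}

#[derive(Parser, Debug)]
struct Args {
    /// Weight dtype, e.g. `--dtype float16`; when omitted, the example could
    /// fall back to a device-based default instead of branching on use_f32.
    #[arg(long, value_enum)]
    dtype: Option<WhichDType>,
}
```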
```cuda
  } \
} \

LOCAL_CAST_OP(__nv_bfloat16, float, cast_bf16_f32)
```
If I'm not mistaken this is undefined for some low enough CUDA versions. Also, I fail to see the difference between this version and the one in the `>= 800` path.
NVIDIA T4 GPUs do not support BF16 floating-point types natively. This adds a method to cast BF16 to F32 or F16 so that Mistral can be loaded on a T4.
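
As a rough usage sketch of that intent: pick BF16 only when the device reports native support, otherwise fall back to F16. `support_native_bf16` is the method added in this PR; the F16 fallback, file list, and loading call are my own illustration, not code from the PR.

```rust
use candle_core::{DType, Device, Result};
use candle_nn::VarBuilder;

fn main() -> Result<()> {
    let device = Device::cuda_if_available(0)?;
    // Only use BF16 when the device handles it natively (e.g. not on a T4);
    // otherwise fall back to F16 so a BF16 checkpoint can still be loaded.
    let dtype = if device.support_native_bf16()? {
        DType::BF16
    } else {
        DType::F16
    };
    // Hypothetical file list; tensors are converted to `dtype` as they are used.
    let files = vec![std::path::PathBuf::from("model.safetensors")];
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&files, dtype, &device)? };
    let _ = vb;
    Ok(())
}
```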