
Deepspeed-Domino #929

Open · wants to merge 8 commits into master

Conversation

zhangsmallshark (Author):

Hello team, Deepspeed-Domino contains all related files for the Domino project.

Member:

First thing: rename the folder Deepspeed-Domino to DeepSpeed-Domino.

Author:

Fixed.

Member:

Are we using any function in this file? If not, delete it.

Author:

Removed.

Member:

Again, please remove all ._DS_Store and other irrelevant files.

Author:

Removed.

return buffer_tensor


class DistributedDataParallel(torch.nn.Module):

Member:

Is this different from PyTorch DDP? If so, do we really need the differing part?

Author:

It is different from PyTorch DDP.
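
For context, the native alternative this question points at is torch.nn.parallel.DistributedDataParallel. A minimal sketch of dropping a custom wrapper in favor of it, assuming the launcher has already initialized the default process group (the helper name `wrap_with_native_ddp` is illustrative, not part of this PR):

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as TorchDDP

def wrap_with_native_ddp(model: torch.nn.Module, local_rank: int) -> torch.nn.Module:
    # The launcher (e.g. torchrun) is assumed to have called
    # dist.init_process_group() already.
    assert dist.is_initialized()
    model = model.to(f"cuda:{local_rank}")
    # Native DDP buckets gradients and overlaps all-reduce with backward;
    # a hand-rolled wrapper has to re-implement exactly this machinery.
    return TorchDDP(model, device_ids=[local_rank])
```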

Member:

PyTorch already supports native fp32/fp16 dtype conversion; do we really need these?

Author:

We can use the native function to replace this one.
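
A minimal sketch of the native conversion path being suggested here, using only built-in PyTorch casts (`tensor` and `module` below are placeholders):

```python
import torch

tensor = torch.randn(4, 4)          # fp32 by default
half = tensor.to(torch.float16)     # native fp32 -> fp16 cast
back = half.float()                 # native fp16 -> fp32 cast

module = torch.nn.Linear(4, 4)
module.half()                       # casts all parameters and buffers to fp16 in place
```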

linear_layer.bias.zero_()
return linear_layer

def param_is_not_shared(param):

Member:

Are we supporting a not-shared parameter group?
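
For reference, in Megatron-style code this helper is usually a one-line attribute check; a sketch under that assumption:

```python
def param_is_not_shared(param):
    # A parameter counts as shared only if it carries a truthy `shared`
    # attribute (e.g. tied embedding weights); everything else is not shared.
    return not hasattr(param, 'shared') or not param.shared
```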

_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'


def _set_cuda_rng_state(new_state, device=-1):

Member:

Are we using the CUDA RNG? I remember it cannot be used together with CUDA Graphs, but it works fine when CUDA Graphs are not enabled.

Author:

We are using it. It cannot be used together with CUDA Graphs.
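
A minimal sketch of the save/restore pattern such a helper wraps, using only the public torch.cuda RNG API (requires a CUDA device; the device index is illustrative):

```python
import torch

device = torch.device('cuda:0')
saved_state = torch.cuda.get_rng_state(device)   # snapshot the per-device RNG state

torch.randn(8, device=device)                    # any sampling advances the RNG

torch.cuda.set_rng_state(saved_state, device)    # rewind to the snapshot
# NOTE: mutating RNG state like this is incompatible with CUDA graph capture,
# which is the constraint discussed above.
```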

return get_attr_wrapped_model(model, 'config', allow_none=False)


def param_is_not_shared(param):

Member:

Same question as above: do we support this "param not shared" feature?

return averaged_losses


def _kernel_make_viewless_tensor(inp, requires_grad):

GuanhuaWang (Member), Sep 26, 2024:

I am not sure, but I remember we discussed this before: making tensors viewless slowed the e2e time, so we disabled it? @zhangsmallshark, can you confirm this?

Author:

I have commented out the places where we call the viewless functions. I will remove them.
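
For context on what is being removed: a Megatron-style viewless-tensor kernel typically re-points a fresh tensor's .data at the input's storage, so the result has no view chain for autograd to keep alive. A sketch under that assumption:

```python
import torch

def _kernel_make_viewless_tensor(inp, requires_grad):
    # Create a stub tensor, then alias its .data to the input's storage.
    # The result shares memory with `inp` but has no ._base view link,
    # so autograd does not retain the viewed-from tensor.
    out = torch.empty((1,), dtype=inp.dtype, device=inp.device,
                      requires_grad=requires_grad)
    out.data = inp.data
    return out
```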

# export NCCL_SOCKET_NTHREADS=4
# export NCCL_NSOCKS_PERTHREAD=8

# cd /work/guanhua/domino

Member:

Please clean up more thoroughly.

Author:

Fixed.

GuanhuaWang (Member) left a review comment:

Thanks @zhangsmallshark and @shenzheyu for the great work.

I added a few high-level comments; we need to get both the loss and the iteration time fixed. Thanks!

GuanhuaWang (Member):

@zhangsmallshark, regarding the fix-loss commit da0c63b:

Maybe I am missing something, but I don't see any real code change to fwd/bwd/step. The only changes in this commit add timers and comment out some printed values; I don't see how the loss is fixed by this commit.
