
[Feat] Decoding refactoring #152

Closed
fedebotu wants to merge 7 commits from the refactor-decoding branch

Conversation

@fedebotu (Member) commented Apr 5, 2024

Description

Several refactorings and new features for decoding in RL4CO.

  1. [Refactoring] Models now return logits by default (e.g. here). Logits are the raw outputs of the model, and we want to decouple the modeling part from how we sample from the resulting distributions. The function handling the conversion from logits to probabilities (what used to be the "log_p") is logit_to_probs; a minimal sketch of this contract follows the list.
  2. [Feat] New decoding strategy: we introduce nucleus sampling (i.e. top-p sampling), which restricts sampling to the smallest set of actions whose cumulative probability exceeds a threshold, discarding the low-probability tail. It can be used by simply passing top_p > 0 to the DecodingStrategy, i.e. to the model decoder (see the top-p sketch below). This is ubiquitous in LLMs and it is about time to have it!
  3. [Refactoring] For simplicity we now default to handling probabilities instead of log probabilities (example here). This is a minor change, but it makes the code more readable and avoids having to call logp.exp() when sampling. It is also more in line with recent work on e.g. LLMs.
  4. [Refactoring, breaking change] By default, every mask now has the same semantics (example here): the value 1 means keep (i.e. feasible action) while 0 means remove (i.e. infeasible action). This matches both TorchRL's action mask and, importantly, PyTorch's scaled_dot_product_attention: "A boolean mask where a value of True indicates that the element should take part in attention." (ref). Masks that previously had inconsistent naming and behavior now all follow this convention (see the masking sketch below).
  5. [Minor] Rename LogitAttention to PointerAttention (for consistency with the pointer mechanism in Vinyals et al., 2015).
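
As mentioned in item 1, the new contract is that the policy returns raw logits and a separate helper converts them to probabilities. The PR names this helper logit_to_probs; the signature and temperature argument below are only assumptions for illustration:

```python
import torch

def logit_to_probs(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    # Hypothetical signature: convert raw model outputs (logits) into a
    # probability distribution over actions; temperature is an assumed knob.
    return (logits / temperature).softmax(dim=-1)

logits = torch.randn(2, 10)        # [batch, num_actions], raw model output
probs = logit_to_probs(logits)     # probabilities, not log-probabilities (see item 3)
action = torch.multinomial(probs, num_samples=1)
```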
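
A minimal sketch of nucleus (top-p) filtering as described in item 2, assuming the usual PyTorch pattern; the helper name top_p_filter and the exact masking details are illustrative, not the PR's actual implementation:

```python
import torch

def top_p_filter(logits: torch.Tensor, top_p: float = 0.9) -> torch.Tensor:
    # Keep the smallest set of actions whose cumulative probability exceeds top_p;
    # everything else is set to -inf so it can never be sampled.
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    sorted_probs = sorted_logits.softmax(dim=-1)
    cum_probs = sorted_probs.cumsum(dim=-1)
    # Remove an action if the cumulative mass *before* it already exceeds top_p,
    # which always keeps at least the most likely action.
    sorted_remove = (cum_probs - sorted_probs) > top_p
    remove = sorted_remove.scatter(-1, sorted_idx, sorted_remove)
    return logits.masked_fill(remove, float("-inf"))

logits = torch.randn(2, 10)                         # [batch, num_actions]
probs = top_p_filter(logits, top_p=0.9).softmax(dim=-1)
action = torch.multinomial(probs, num_samples=1)    # sample only from the nucleus
```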
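
And a small illustration of the unified mask semantics from item 4 (1/True = feasible, 0/False = infeasible), assuming masks are applied to the logits before the softmax; the variable names are made up for the example:

```python
import torch

logits = torch.randn(2, 5)
action_mask = torch.tensor([[True, True, False, True, False],
                            [False, True, True, True, True]])  # True = feasible action

# Infeasible actions get -inf logits, so they receive exactly zero probability
masked_logits = logits.masked_fill(~action_mask, float("-inf"))
probs = masked_logits.softmax(dim=-1)
action = torch.multinomial(probs, num_samples=1)
```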

Warning

Work in progress. Do not merge yet. Some checks and training runs still have bugs that need to be fixed (most probably due to the new masking convention).

Types of changes

  • New feature (non-breaking change which adds core functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)

Checklist

  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

CC @LTluttmann: could you have a look and see if you spot any inefficiencies or have some ideas?
CC @Furffico @cbhua: these changes are what I was talking about yesterday (note that in this case running the softmax normalization inside the Sampling in ACO might not be needed).

@fedebotu (Member, Author) commented Apr 9, 2024

Closed in favor of #161

@fedebotu fedebotu closed this Apr 9, 2024
@fedebotu fedebotu deleted the refactor-decoding branch September 3, 2024 05:58