Store `nn.Parameter` in `entropy_models.py` in `nn.ParameterList` #284
Conversation
Thank you @mmuckley for the PR!
```python
perm = torch.cat(
    (
        torch.tensor([1, 0], dtype=torch.long, device=x.device),
        torch.arange(2, x.ndim, dtype=torch.long, device=x.device),
    )
)
```
Another possibility:

```python
perm = torch.tensor(
    [1, 0, *range(2, x.ndim)], dtype=torch.long, device=x.device
)
```
LGTM, might boil down to the same thing under the hood :)
Hello @YodaEmbedding, thanks for the suggestion. If it's okay, I would like to argue for the current implementation: the modification relies on a Python-level `range` iterator and unpacking it into a list, which can trigger many Python calls under the hood. When working with frameworks like `torch.jit` and `torch.compile`, I often find that these kinds of constructs can be difficult for the compiler, since the shakiest parts of those libraries are the ones that have to understand Python. By keeping everything as PyTorch calls, the compilers seem to perform more stably.
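For concreteness, a minimal sketch of the two constructions side by side (the function names here are illustrative; both yield the same permutation for a 4-D input):

```python
import torch


def perm_torch_only(x: torch.Tensor) -> torch.Tensor:
    # Current implementation: the permutation is built entirely
    # from PyTorch calls, which the compiler handles natively.
    return torch.cat(
        (
            torch.tensor([1, 0], dtype=torch.long, device=x.device),
            torch.arange(2, x.ndim, dtype=torch.long, device=x.device),
        )
    )


def perm_python_unpack(x: torch.Tensor) -> torch.Tensor:
    # Suggested alternative: relies on a Python range iterator and
    # list unpacking, which the compiler must trace through.
    return torch.tensor(
        [1, 0, *range(2, x.ndim)], dtype=torch.long, device=x.device
    )


x = torch.zeros(4, 3, 8, 8)
assert torch.equal(perm_torch_only(x), perm_python_unpack(x))  # same result
```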
Sounds good.
The PR looks good; I tested the current implementation with `torch.compile`. On a semi-related side note, the current ELIC implementation fails with `torch.compile`.
I temporarily limited the supported torch version to <2.3 this week. I'm getting slightly different results from `eval_model` on video, which breaks the CI that compares against expected results produced with earlier versions.
This PR proposes to store parameters in `entropy_models.py` in an `nn.ParameterList` instead of the current string-based lookup. The primary reason to do so is to make `EntropyBottleneck` more friendly for `torch.compile`, where the current implementation fails to compile for certain backends (in my own experience, dynamo). The failure seems to stem from the current implementation relying too heavily on Python strings and class attributes to access the parameters; the new implementation makes the parameter structure explicit at the PyTorch level, which helps the compiler.

A major drawback of merging this PR is that it breaks backwards compatibility. I've included some state_dict adjustments that allow loading old checkpoints, but I understand this may not be ideal. Also, new checkpoints would not be loadable by older versions of `compressai`.
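As a simplified sketch of the idea (this `Toy` module, its parameter shapes, and the `_matrix{i}` key pattern are illustrative stand-ins, not the actual `EntropyBottleneck` code):

```python
import torch
import torch.nn as nn


class Toy(nn.Module):
    """Sketch: parameters in an nn.ParameterList instead of getattr lookup."""

    def __init__(self, num_layers: int = 3) -> None:
        super().__init__()
        # Old style (string-based lookup, opaque to the compiler):
        #     self.register_parameter(f"_matrix{i}", nn.Parameter(...))
        #     matrix = getattr(self, f"_matrix{i}")
        # New style: an nn.ParameterList indexed directly.
        self.matrices = nn.ParameterList(
            nn.Parameter(torch.zeros(1, 3, 3)) for _ in range(num_layers)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for matrix in self.matrices:
            x = torch.matmul(matrix, x)
        return x

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict,
        missing_keys, unexpected_keys, error_msgs,
    ):
        # Remap old-style checkpoint keys (e.g. "_matrix0") onto the
        # ParameterList keys ("matrices.0") so old checkpoints still load.
        for i in range(len(self.matrices)):
            old_key = f"{prefix}_matrix{i}"
            if old_key in state_dict:
                state_dict[f"{prefix}matrices.{i}"] = state_dict.pop(old_key)
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs,
        )
```

The `_load_from_state_dict` override is the kind of adjustment meant by "state_dict adjustments" above: old keys are renamed in place before the default loading logic runs.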
The PR also includes a compile test for verifying the implementation.
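Such a test could look roughly like this (again using the hypothetical `Toy` module above; the actual test targets the `EntropyBottleneck` implementation):

```python
def test_compile() -> None:
    # Compilation should succeed without graph breaks and match eager output.
    model = Toy()
    x = torch.randn(1, 3, 5)
    y_eager = model(x)
    y_compiled = torch.compile(model, fullgraph=True)(x)
    assert torch.allclose(y_eager, y_compiled)
```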
Happy to see this merged or closed, depending on maintainer preference.