investigate if conditions in torch script #42

Open
1 of 2 tasks
TomMelt opened this issue Feb 5, 2025 · 2 comments


TomMelt commented Feb 5, 2025

Currently, #33 hacks `utils/model_definition.py` to disable an `if` condition. The condition isn't used for the 1x1 model, but it causes an error when trying to script the model with TorchScript:

```python
# if self.fac == 1:
#     x = torch.squeeze(self.dropout0(self.act_cnn(self.conv1(x))))
# elif self.fac == 2:
#     x = torch.squeeze(self.dropout0(self.act_cnn(self.conv1(x))))
#     x = torch.squeeze(self.dropout0_2(self.act_cnn2(self.conv2(x))))
```

Error:

```
RuntimeError:
Module 'ANN_CNN' has no attribute 'dropout0' :
  File "/home/melt/sync/cambridge/projects/current/nlgw-cam/nonlocal_gwfluxes/era5_training/../utils/model_definition.py", line 75
    def forward(self, x):
        if self.fac == 1:
            x = torch.squeeze(self.dropout0(self.act_cnn(self.conv1(x))))
                              ~~~~~~~~~~~~~ <--- HERE
        elif self.fac == 2:
            x = torch.squeeze(self.dropout0(self.act_cnn(self.conv1(x))))
```

- investigate the cause of the error
- find a permanent solution rather than commenting the code out
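A minimal sketch of the failure, assuming (as the error suggests) that `ANN_CNN.__init__` only creates `dropout0` for some values of `fac`. `TinyCNN` and `can_script` are hypothetical names and the layer sizes are arbitrary. TorchScript compiles the body of `if self.fac == 1:` whatever value `self.fac` has at runtime, so an attribute referenced in a branch that is never taken must still exist on the module:

```python
import torch
import torch.nn as nn


class TinyCNN(nn.Module):
    """Hypothetical stand-in for ANN_CNN: dropout0 only exists when fac == 1."""

    def __init__(self, fac: int):
        super().__init__()
        self.fac = fac
        self.conv1 = nn.Conv2d(1, 4, kernel_size=1)
        self.act_cnn = nn.ReLU()
        if fac == 1:
            # Attribute is created conditionally (assumed to mirror the real model)
            self.dropout0 = nn.Dropout(0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.fac == 1:
            x = torch.squeeze(self.dropout0(self.act_cnn(self.conv1(x))))
        return x


def can_script(model: nn.Module) -> bool:
    """Return True if torch.jit.script compiles the model, False otherwise."""
    try:
        torch.jit.script(model)
        return True
    except RuntimeError as err:
        print(err)
        return False


print(can_script(TinyCNN(fac=1)))  # True
print(can_script(TinyCNN(fac=2)))  # prints the RuntimeError message, then False
```

The second call fails even though the `fac == 1` branch could never run for that instance, which matches the `has no attribute 'dropout0'` error above.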

TomMelt commented Feb 6, 2025

According to the PyTorch docs, we will have to modify the code:

TorchScript supports a subset of Python’s variable resolution (i.e. scoping) rules. Local variables behave the same as in Python, except for the restriction that a variable must have the same type along all paths through a function. If a variable has a different type on different branches of an if statement, it is an error to use it after the end of the if statement.
Similarly, a variable is not allowed to be used if it is only defined along some paths through the function.

For example, the following code:

```python
@torch.jit.script
def foo(x):
    if x < 0:
        y = 4
    print(y)
```

will give this error

```
Traceback (most recent call last):
  ...
RuntimeError: ...

y is not defined in the false branch...
@torch.jit.script...
def foo(x):
    if x < 0:
    ~~~~~~~~~
        y = 4
        ~~~~~ <--- HERE
    print(y)
and was used here:
    if x < 0:
        y = 4
    print(y)
          ~ <--- HERE...
```
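Applied to the docs example, the fix is to give `y` a value of the same type (`int`) on every path before the `if`. This sketch returns `y` instead of printing it so the result can be checked, and wraps the tensor comparison in `bool(...)` to make the condition type explicit:

```python
import torch


@torch.jit.script
def foo(x: torch.Tensor) -> int:
    # Define y before the branch so it exists, with type int, on every path
    y = 0
    if bool(x < 0):
        y = 4
    return y


print(foo(torch.tensor(-1.0)))  # 4
print(foo(torch.tensor(1.0)))   # 0
```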


TomMelt commented Feb 20, 2025

@TomMelt to look into a proper workaround by pre-defining the unused variables.
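One way the pre-defining workaround could look, sketched under the same assumptions as before (hypothetical `TinyCNN`, arbitrary layer sizes, attribute names copied from the `forward` snippet). Every attribute referenced in `forward` is created in `__init__`, with `nn.Identity()` placeholders for layers that a given `fac` never calls, so the original `if`/`elif` can stay in place and still be scripted:

```python
import torch
import torch.nn as nn


class TinyCNN(nn.Module):
    """Hypothetical stand-in for ANN_CNN with every branch's attributes defined."""

    def __init__(self, fac: int):
        super().__init__()
        self.fac = fac
        self.conv1 = nn.Conv2d(1, 4, kernel_size=1)
        self.act_cnn = nn.ReLU()
        self.dropout0 = nn.Dropout(0.1)
        if fac == 2:
            self.conv2 = nn.Conv2d(4, 4, kernel_size=1)
            self.act_cnn2 = nn.ReLU()
            self.dropout0_2 = nn.Dropout(0.1)
        else:
            # Placeholders: never called when fac != 2, but they let
            # TorchScript compile the fac == 2 branch
            self.conv2 = nn.Identity()
            self.act_cnn2 = nn.Identity()
            self.dropout0_2 = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.fac == 1:
            x = torch.squeeze(self.dropout0(self.act_cnn(self.conv1(x))))
        elif self.fac == 2:
            x = torch.squeeze(self.dropout0(self.act_cnn(self.conv1(x))))
            x = torch.squeeze(self.dropout0_2(self.act_cnn2(self.conv2(x))))
        return x


x = torch.randn(2, 1, 3, 3)
for fac in (1, 2):
    model = TinyCNN(fac).eval()  # eval() disables dropout for a deterministic comparison
    scripted = torch.jit.script(model)
    print(fac, torch.allclose(scripted(x), model(x)))  # True for both
```

`nn.Identity` has no parameters or buffers, so the placeholders should leave the model's `state_dict` keys unchanged and existing checkpoints loadable.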
