-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
MWE fortran python #33
base: main
Are you sure you want to change the base?
Conversation
66a5973
to
4bc755a
Compare
@@ -16,7 +13,7 @@ def Inference_and_Save_ANN_CNN(model, testset, testloader, bs_test, device, sten | |||
nx = len(lon) | |||
|
|||
model.eval() | |||
model.dropout.train() # this enables dropout during inference. By default dropout is OFF when model.eval()=True | |||
# model.dropout.train() # this enables dropout during inference. By default dropout is OFF when model.eval()=True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@amangupta2 , is this ok to disable? or do we need it for inference
quick_inference_script/infer.f90
Outdated
call torch_model_load(model, "saved_nlgw_model_gpu.pt", device_type=torch_kCUDA, device_index=0) | ||
|
||
! Infer | ||
! do i = 1, 100 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this loop can be removed
4bc755a
to
7d91f97
Compare
a5cecba
to
c389d75
Compare
# if self.fac == 1: | ||
# x = torch.squeeze(self.dropout0(self.act_cnn(self.conv1(x)))) | ||
# elif self.fac == 2: | ||
# x = torch.squeeze(self.dropout0(self.act_cnn(self.conv1(x)))) | ||
# x = torch.squeeze(self.dropout0_2(self.act_cnn2(self.conv2(x)))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we should find a proper solution to this for torchscripting
@@ -189,6 +192,22 @@ def Inference_and_Save_ANN_CNN(model, testset, testloader, bs_test, device, sten | |||
OUT = OUT.reshape(T[0] * T[1], -1) | |||
PRED = model(INP) | |||
|
|||
print("saving data...") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
make this part optional via command line args
This PR adds minimum worked examples of inference in python and fortran
README.md
closes #11
closes #25