diff --git a/inference.py b/inference.py index 08eb935..54dd3ac 100644 --- a/inference.py +++ b/inference.py @@ -19,8 +19,8 @@ def parse_args(): if __name__ == "__main__": args = parse_args() - model = torch.load("model.pt") - parameters = torch.load("auxiliary.pt") + model = torch.load("models/model.pt") + parameters = torch.load("models/auxiliary.pt") model.eval() test = CNSDataset( args.data_file, transform=DescriptorGenerator(AVAILABLE_DESCRIPTORS)