\n",
+ " "
+ ],
+ "text/plain": [
+ "
"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2024-09-04 17:05:46.580091: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
+ "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
+ "2024-09-04 17:05:47.357973: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
+ ]
+ }
+ ],
+ "source": [
+ "import setup\n",
+ "\n",
+ "setup.main()\n",
+ "%load_ext autoreload\n",
+ "%autoreload 2\n",
+ "%load_ext jupyter_black\n",
+ "\n",
+ "import torch\n",
+ "import os\n",
+ "from nnvision.models.ptrmodels import task_core_gauss_readout\n",
+ "from mei.modules import EnsembleModel"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Load model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "weights_path = os.path.join(os.getcwd(), \"datasets/digital_twins/weights\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def load_v4_model_from_weights(base_dir):\n",
+ " model_fn = task_core_gauss_readout\n",
+ " model_config = {\n",
+ " \"input_channels\": 1,\n",
+ " \"model_name\": \"resnet50_l2_eps0_1\",\n",
+ " \"layer_name\": \"layer3.0\",\n",
+ " \"pretrained\": False,\n",
+ " \"bias\": False,\n",
+ " \"final_batchnorm\": True,\n",
+ " \"final_nonlinearity\": True,\n",
+ " \"momentum\": 0.1,\n",
+ " \"fine_tune\": False,\n",
+ " \"init_mu_range\": 0.4,\n",
+ " \"init_sigma_range\": 0.6,\n",
+ " \"readout_bias\": True,\n",
+ " \"gamma_readout\": 3.0,\n",
+ " \"gauss_type\": \"isotropic\",\n",
+ " \"elu_offset\": -1,\n",
+ " }\n",
+ "\n",
+ " data_info = {\n",
+ " \"all_sessions\": {\n",
+ " \"input_dimensions\": torch.Size([64, 1, 100, 100]),\n",
+ " \"input_channels\": 1,\n",
+ " \"output_dimension\": 1244,\n",
+ " \"img_mean\": 124.54466,\n",
+ " \"img_std\": 70.28,\n",
+ " },\n",
+ " }\n",
+ "\n",
+ " # fill the list ensemble names with task driven 01 - 10\n",
+ " ensemble_names = [\n",
+ " \"task_driven_ensemble_model_01.pth.tar\",\n",
+ " \"task_driven_ensemble_model_02.pth.tar\",\n",
+ " \"task_driven_ensemble_model_03.pth.tar\",\n",
+ " \"task_driven_ensemble_model_04.pth.tar\",\n",
+ " \"task_driven_ensemble_model_05.pth.tar\",\n",
+ " ]\n",
+ "\n",
+ " ensemble_models = []\n",
+ "\n",
+ " for f in ensemble_names:\n",
+ " ensemble_filename = os.path.join(base_dir, f)\n",
+ " ensemble_state_dict = torch.load(ensemble_filename)\n",
+ " ensemble_model = model_fn(\n",
+ " seed=0,\n",
+ " dataloaders=None,\n",
+ " **model_config,\n",
+ " data_info=data_info,\n",
+ " )\n",
+ " ensemble_model.load_state_dict(ensemble_state_dict)\n",
+ " ensemble_models.append(ensemble_model)\n",
+ "\n",
+ " task_driven_ensemble = EnsembleModel(*ensemble_models)\n",
+ " return task_driven_ensemble"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "EnsembleModel(EncoderShifter(\n",
+ " (core): TaskDrivenCore3(\n",
+ " (features): Sequential(\n",
+ " (TaskDriven): Sequential(\n",
+ " (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
+ " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
+ " (layer1): Sequential(\n",
+ " (0): Bottleneck(\n",
+ " (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " (downsample): Sequential(\n",
+ " (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (1): Bottleneck(\n",
+ " (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " )\n",
+ " (2): Bottleneck(\n",
+ " (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " )\n",
+ " )\n",
+ " (layer2): Sequential(\n",
+ " (0): Bottleneck(\n",
+ " (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " (downsample): Sequential(\n",
+ " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
+ " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (1): Bottleneck(\n",
+ " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " )\n",
+ " (2): Bottleneck(\n",
+ " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " )\n",
+ " (3): Bottleneck(\n",
+ " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " )\n",
+ " )\n",
+ " (0): Bottleneck(\n",
+ " (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " (downsample): Sequential(\n",
+ " (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
+ " (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (OutBatchNorm): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (OutNonlin): ReLU(inplace=True)\n",
+ " )\n",
+ " ) [TaskDrivenCore3 regularizers: ]\n",
+ " \n",
+ " (readout): MultipleFullGaussian2d(\n",
+ " (all_sessions): isotropic FullGaussian2d (1024 x 7 x 7 -> 1244) with bias\n",
+ " )\n",
+ "), EncoderShifter(\n",
+ " (core): TaskDrivenCore3(\n",
+ " (features): Sequential(\n",
+ " (TaskDriven): Sequential(\n",
+ " (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
+ " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
+ " (layer1): Sequential(\n",
+ " (0): Bottleneck(\n",
+ " (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " (downsample): Sequential(\n",
+ " (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (1): Bottleneck(\n",
+ " (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " )\n",
+ " (2): Bottleneck(\n",
+ " (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " )\n",
+ " )\n",
+ " (layer2): Sequential(\n",
+ " (0): Bottleneck(\n",
+ " (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " (downsample): Sequential(\n",
+ " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
+ " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (1): Bottleneck(\n",
+ " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " )\n",
+ " (2): Bottleneck(\n",
+ " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " )\n",
+ " (3): Bottleneck(\n",
+ " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " )\n",
+ " )\n",
+ " (0): Bottleneck(\n",
+ " (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " (downsample): Sequential(\n",
+ " (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
+ " (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (OutBatchNorm): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (OutNonlin): ReLU(inplace=True)\n",
+ " )\n",
+ " ) [TaskDrivenCore3 regularizers: ]\n",
+ " \n",
+ " (readout): MultipleFullGaussian2d(\n",
+ " (all_sessions): isotropic FullGaussian2d (1024 x 7 x 7 -> 1244) with bias\n",
+ " )\n",
+ "), EncoderShifter(\n",
+ " (core): TaskDrivenCore3(\n",
+ " (features): Sequential(\n",
+ " (TaskDriven): Sequential(\n",
+ " (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
+ " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
+ " (layer1): Sequential(\n",
+ " (0): Bottleneck(\n",
+ " (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " (downsample): Sequential(\n",
+ " (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (1): Bottleneck(\n",
+ " (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " )\n",
+ " (2): Bottleneck(\n",
+ " (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " )\n",
+ " )\n",
+ " (layer2): Sequential(\n",
+ " (0): Bottleneck(\n",
+ " (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " (downsample): Sequential(\n",
+ " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
+ " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (1): Bottleneck(\n",
+ " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " )\n",
+ " (2): Bottleneck(\n",
+ " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " )\n",
+ " (3): Bottleneck(\n",
+ " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " )\n",
+ " )\n",
+ " (0): Bottleneck(\n",
+ " (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " (downsample): Sequential(\n",
+ " (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
+ " (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (OutBatchNorm): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (OutNonlin): ReLU(inplace=True)\n",
+ " )\n",
+ " ) [TaskDrivenCore3 regularizers: ]\n",
+ " \n",
+ " (readout): MultipleFullGaussian2d(\n",
+ " (all_sessions): isotropic FullGaussian2d (1024 x 7 x 7 -> 1244) with bias\n",
+ " )\n",
+ "), EncoderShifter(\n",
+ " (core): TaskDrivenCore3(\n",
+ " (features): Sequential(\n",
+ " (TaskDriven): Sequential(\n",
+ " (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
+ " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
+ " (layer1): Sequential(\n",
+ " (0): Bottleneck(\n",
+ " (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " (downsample): Sequential(\n",
+ " (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (1): Bottleneck(\n",
+ " (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " )\n",
+ " (2): Bottleneck(\n",
+ " (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " )\n",
+ " )\n",
+ " (layer2): Sequential(\n",
+ " (0): Bottleneck(\n",
+ " (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " (downsample): Sequential(\n",
+ " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
+ " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (1): Bottleneck(\n",
+ " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " )\n",
+ " (2): Bottleneck(\n",
+ " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " )\n",
+ " (3): Bottleneck(\n",
+ " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " )\n",
+ " )\n",
+ " (0): Bottleneck(\n",
+ " (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " (downsample): Sequential(\n",
+ " (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
+ " (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (OutBatchNorm): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (OutNonlin): ReLU(inplace=True)\n",
+ " )\n",
+ " ) [TaskDrivenCore3 regularizers: ]\n",
+ " \n",
+ " (readout): MultipleFullGaussian2d(\n",
+ " (all_sessions): isotropic FullGaussian2d (1024 x 7 x 7 -> 1244) with bias\n",
+ " )\n",
+ "), EncoderShifter(\n",
+ " (core): TaskDrivenCore3(\n",
+ " (features): Sequential(\n",
+ " (TaskDriven): Sequential(\n",
+ " (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
+ " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
+ " (layer1): Sequential(\n",
+ " (0): Bottleneck(\n",
+ " (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " (downsample): Sequential(\n",
+ " (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (1): Bottleneck(\n",
+ " (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " )\n",
+ " (2): Bottleneck(\n",
+ " (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " )\n",
+ " )\n",
+ " (layer2): Sequential(\n",
+ " (0): Bottleneck(\n",
+ " (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " (downsample): Sequential(\n",
+ " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
+ " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (1): Bottleneck(\n",
+ " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " )\n",
+ " (2): Bottleneck(\n",
+ " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " )\n",
+ " (3): Bottleneck(\n",
+ " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " )\n",
+ " )\n",
+ " (0): Bottleneck(\n",
+ " (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
+ " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " (downsample): Sequential(\n",
+ " (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
+ " (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (OutBatchNorm): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (OutNonlin): ReLU(inplace=True)\n",
+ " )\n",
+ " ) [TaskDrivenCore3 regularizers: ]\n",
+ " \n",
+ " (readout): MultipleFullGaussian2d(\n",
+ " (all_sessions): isotropic FullGaussian2d (1024 x 7 x 7 -> 1244) with bias\n",
+ " )\n",
+ "))\n"
+ ]
+ }
+ ],
+ "source": [
+ "model = load_v4_model_from_weights(weights_path)\n",
+ "print(model)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "neurometry",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}