-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6ce277a
commit 8304188
Showing
46 changed files
with
16,211 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "134e7f9d", | ||
"metadata": {}, | ||
"source": [ | ||
"# Demo 10: Device\n", | ||
"\n", | ||
"All other demos have by default used device = 'cpu'. In case we want to use cuda, we should pass the device argument to model and dataset." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "7a4ac1e1-84ba-4bc3-91b6-a776a5e7711c", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"cpu\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from kan import KAN, create_dataset\n", | ||
"import torch\n", | ||
"\n", | ||
"torch.use_deterministic_algorithms(False)\n", | ||
"\n", | ||
"#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", | ||
"device = 'cpu'\n", | ||
"print(device)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "2075ef56", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"checkpoint directory created: ./model\n", | ||
"saving model version 0.0\n" | ||
] | ||
}, | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"| train_loss: 6.83e-01 | test_loss: 7.21e-01 | reg: 1.04e+03 | : 100%|█| 50/50 [00:19<00:00, 2.62it\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"saving model version 0.1\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"model = KAN(width=[4,100,100,100,1], grid=3, k=3, seed=0).to(device)\n", | ||
"f = lambda x: torch.exp((torch.sin(torch.pi*(x[:,[0]]**2+x[:,[1]]**2))+torch.sin(torch.pi*(x[:,[2]]**2+x[:,[3]]**2)))/2)\n", | ||
"dataset = create_dataset(f, n_var=4, train_num=1000, device=device)\n", | ||
"\n", | ||
"# train the model\n", | ||
"#model.train(dataset, opt=\"LBFGS\", steps=20, lamb=1e-3, lamb_entropy=2.);\n", | ||
"model.fit(dataset, opt=\"Adam\", lr=1e-3, steps=50, lamb=1e-3, lamb_entropy=5., update_grid=False);" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "2f182cc1-51bf-4151-a253-a52fe854919e", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "f6f8125e-d26d-4c97-9e5f-988099bb4737", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"cuda\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"device = 'cuda'\n", | ||
"print(device)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"id": "95017dfa-3a2a-43e0-8b68-fb220ca5abc9", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"checkpoint directory created: ./model\n", | ||
"saving model version 0.0\n" | ||
] | ||
}, | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"| train_loss: 6.83e-01 | test_loss: 7.21e-01 | reg: 1.04e+03 | : 100%|█| 50/50 [00:01<00:00, 26.90it\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"saving model version 0.1\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"model = KAN(width=[4,100,100,100,1], grid=3, k=3, seed=0).to(device)\n", | ||
"f = lambda x: torch.exp((torch.sin(torch.pi*(x[:,[0]]**2+x[:,[1]]**2))+torch.sin(torch.pi*(x[:,[2]]**2+x[:,[3]]**2)))/2)\n", | ||
"dataset = create_dataset(f, n_var=4, train_num=1000, device=device)\n", | ||
"\n", | ||
"# train the model\n", | ||
"#model.train(dataset, opt=\"LBFGS\", steps=20, lamb=1e-3, lamb_entropy=2.);\n", | ||
"model.fit(dataset, opt=\"Adam\", lr=1e-3, steps=50, lamb=1e-3, lamb_entropy=5., update_grid=False);" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "8230d562-2635-4adc-b566-06ac679b166a", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.8.15" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "53ff2e87", | ||
"metadata": {}, | ||
"source": [ | ||
"# API 11: Create dataset" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "25a90774", | ||
"metadata": {}, | ||
"source": [ | ||
"how to use create_dataset in kan.utils" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "2f9ae0c7", | ||
"metadata": {}, | ||
"source": [ | ||
"Standard way" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "3e2b9f8b", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"torch.Size([1000, 1])" | ||
] | ||
}, | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"from kan.utils import create_dataset\n", | ||
"\n", | ||
"f = lambda x: x[:,[0]] * x[:,[1]]\n", | ||
"dataset = create_dataset(f, n_var=2)\n", | ||
"dataset['train_label'].shape" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "877956c9", | ||
"metadata": {}, | ||
"source": [ | ||
"Lazier way. We sometimes forget to add the bracket, i.e., write x[:,[0]] as x[:,0], and this used to lead to an error in training (loss not going down). Now the create_dataset can automatically detect this simplification and produce the correct behavior." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "b14dd4a2", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"torch.Size([1000, 1])" | ||
] | ||
}, | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"f = lambda x: x[:,0] * x[:,1]\n", | ||
"dataset = create_dataset(f, n_var=2)\n", | ||
"dataset['train_label'].shape" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "60230da4", | ||
"metadata": {}, | ||
"source": [ | ||
"Laziest way. If you even want to get rid of the colon symbol, i.e., you want to write x[;,0] as x[0], you can do that but need to pass in f_mode = 'row'." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "e764f415", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"torch.Size([1000, 1])" | ||
] | ||
}, | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"f = lambda x: x[0] * x[1]\n", | ||
"dataset = create_dataset(f, n_var=2, f_mode='row')\n", | ||
"dataset['train_label'].shape" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "8e1f1732", | ||
"metadata": {}, | ||
"source": [ | ||
"if you already have x (inputs) and y (outputs), and you only want to partition them into train/test, use create_dataset_from_data" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "accf900a", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import torch\n", | ||
"from kan.utils import create_dataset_from_data\n", | ||
"\n", | ||
"x = torch.rand(100,2)\n", | ||
"y = torch.rand(100,1)\n", | ||
"dataset = create_dataset_from_data(x, y)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "c45062a8", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.7" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Oops, something went wrong.