diff --git a/.sqlx/query-ccea8776d7cdc2d6c87ecd54c36113b3e829aad2a0cdbfa00b83a81158b9aa96.json b/.sqlx/query-24ed36f7df12252f37652518c35c4aa2ffde87118e17ed90f0ed72622eae6c99.json similarity index 85% rename from .sqlx/query-ccea8776d7cdc2d6c87ecd54c36113b3e829aad2a0cdbfa00b83a81158b9aa96.json rename to .sqlx/query-24ed36f7df12252f37652518c35c4aa2ffde87118e17ed90f0ed72622eae6c99.json index 7b82f6643..73cf19c5f 100644 --- a/.sqlx/query-ccea8776d7cdc2d6c87ecd54c36113b3e829aad2a0cdbfa00b83a81158b9aa96.json +++ b/.sqlx/query-24ed36f7df12252f37652518c35c4aa2ffde87118e17ed90f0ed72622eae6c99.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "UPDATE openidprovider SET name = $1, base_url = $2, client_id = $3, client_secret = $4, display_name = $5, google_service_account_key = $6, google_service_account_email = $7, admin_email = $8, directory_sync_enabled = $9, directory_sync_interval = $10, directory_sync_user_behavior = $11, directory_sync_admin_behavior = $12, directory_sync_target = $13 WHERE id = $14", + "query": "UPDATE openidprovider SET name = $1, base_url = $2, client_id = $3, client_secret = $4, display_name = $5, google_service_account_key = $6, google_service_account_email = $7, admin_email = $8, directory_sync_enabled = $9, directory_sync_interval = $10, directory_sync_user_behavior = $11, directory_sync_admin_behavior = $12, directory_sync_target = $13, okta_private_jwk = $14, okta_dirsync_client_id = $15 WHERE id = $16", "describe": { "columns": [], "parameters": { @@ -51,10 +51,12 @@ } } }, + "Text", + "Text", "Int8" ] }, "nullable": [] }, - "hash": "ccea8776d7cdc2d6c87ecd54c36113b3e829aad2a0cdbfa00b83a81158b9aa96" + "hash": "24ed36f7df12252f37652518c35c4aa2ffde87118e17ed90f0ed72622eae6c99" } diff --git a/.sqlx/query-42ea85e353deb1555b4e442a2fcdf366eb24bf7b907011a01313b77bed572176.json b/.sqlx/query-42ea85e353deb1555b4e442a2fcdf366eb24bf7b907011a01313b77bed572176.json new file mode 100644 index 000000000..2df04febf --- /dev/null +++ b/.sqlx/query-42ea85e353deb1555b4e442a2fcdf366eb24bf7b907011a01313b77bed572176.json @@ -0,0 +1,53 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT device_id, wireguard_network_id, wireguard_ip \"wireguard_ip: IpAddr\", preshared_key, is_authorized, authorized_at FROM wireguard_network_device WHERE wireguard_network_id = $1 AND device_id IN (SELECT id FROM device WHERE user_id = $2 AND device_type = 'user'::device_type)", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "device_id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "wireguard_network_id", + "type_info": "Int8" + }, + { + "ordinal": 2, + "name": "wireguard_ip: IpAddr", + "type_info": "Inet" + }, + { + "ordinal": 3, + "name": "preshared_key", + "type_info": "Text" + }, + { + "ordinal": 4, + "name": "is_authorized", + "type_info": "Bool" + }, + { + "ordinal": 5, + "name": "authorized_at", + "type_info": "Timestamp" + } + ], + "parameters": { + "Left": [ + "Int8", + "Int8" + ] + }, + "nullable": [ + false, + false, + false, + true, + false, + true + ] + }, + "hash": "42ea85e353deb1555b4e442a2fcdf366eb24bf7b907011a01313b77bed572176" +} diff --git a/.sqlx/query-7e0df54ab4876f49960b16747ffa05c748bd2b09683251c45728e227a320b4bc.json b/.sqlx/query-7f09a8e817e8e2df2a5c4896b4d0fad03c0cac68d8b597b55a2fb0ccc4d2cb15.json similarity index 86% rename from .sqlx/query-7e0df54ab4876f49960b16747ffa05c748bd2b09683251c45728e227a320b4bc.json rename to .sqlx/query-7f09a8e817e8e2df2a5c4896b4d0fad03c0cac68d8b597b55a2fb0ccc4d2cb15.json index d7ef4b0fa..5b269476f 100644 --- a/.sqlx/query-7e0df54ab4876f49960b16747ffa05c748bd2b09683251c45728e227a320b4bc.json +++ b/.sqlx/query-7f09a8e817e8e2df2a5c4896b4d0fad03c0cac68d8b597b55a2fb0ccc4d2cb15.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "SELECT id, \"name\",\"base_url\",\"client_id\",\"client_secret\",\"display_name\",\"google_service_account_key\",\"google_service_account_email\",\"admin_email\",\"directory_sync_enabled\",\"directory_sync_interval\",\"directory_sync_user_behavior\" \"directory_sync_user_behavior: _\",\"directory_sync_admin_behavior\" \"directory_sync_admin_behavior: _\",\"directory_sync_target\" \"directory_sync_target: _\" FROM \"openidprovider\" WHERE id = $1", + "query": "SELECT id, \"name\",\"base_url\",\"client_id\",\"client_secret\",\"display_name\",\"google_service_account_key\",\"google_service_account_email\",\"admin_email\",\"directory_sync_enabled\",\"directory_sync_interval\",\"directory_sync_user_behavior\" \"directory_sync_user_behavior: _\",\"directory_sync_admin_behavior\" \"directory_sync_admin_behavior: _\",\"directory_sync_target\" \"directory_sync_target: _\",\"okta_private_jwk\",\"okta_dirsync_client_id\" FROM \"openidprovider\" WHERE id = $1", "describe": { "columns": [ { @@ -105,6 +105,16 @@ } } } + }, + { + "ordinal": 14, + "name": "okta_private_jwk", + "type_info": "Text" + }, + { + "ordinal": 15, + "name": "okta_dirsync_client_id", + "type_info": "Text" } ], "parameters": { @@ -126,8 +136,10 @@ false, false, false, - false + false, + true, + true ] }, - "hash": "7e0df54ab4876f49960b16747ffa05c748bd2b09683251c45728e227a320b4bc" + "hash": "7f09a8e817e8e2df2a5c4896b4d0fad03c0cac68d8b597b55a2fb0ccc4d2cb15" } diff --git a/.sqlx/query-bd1ba52cce3e0529d93f54da70cc77b3a2445b02bba1b6de399fa576acd82025.json b/.sqlx/query-8837d69c8bdc3223611c936a21c5e0f68aa815e09d48e2cd79985d62bd02711a.json similarity index 88% rename from .sqlx/query-bd1ba52cce3e0529d93f54da70cc77b3a2445b02bba1b6de399fa576acd82025.json rename to .sqlx/query-8837d69c8bdc3223611c936a21c5e0f68aa815e09d48e2cd79985d62bd02711a.json index 3dd332649..510bbe99b 100644 --- a/.sqlx/query-bd1ba52cce3e0529d93f54da70cc77b3a2445b02bba1b6de399fa576acd82025.json +++ b/.sqlx/query-8837d69c8bdc3223611c936a21c5e0f68aa815e09d48e2cd79985d62bd02711a.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "SELECT id, name, base_url, client_id, client_secret, display_name, google_service_account_key, google_service_account_email, admin_email, directory_sync_enabled, directory_sync_interval, directory_sync_user_behavior \"directory_sync_user_behavior: DirectorySyncUserBehavior\", directory_sync_admin_behavior \"directory_sync_admin_behavior: DirectorySyncUserBehavior\", directory_sync_target \"directory_sync_target: DirectorySyncTarget\" FROM openidprovider LIMIT 1", + "query": "SELECT id, name, base_url, client_id, client_secret, display_name, google_service_account_key, google_service_account_email, admin_email, directory_sync_enabled, directory_sync_interval, directory_sync_user_behavior \"directory_sync_user_behavior: DirectorySyncUserBehavior\", directory_sync_admin_behavior \"directory_sync_admin_behavior: DirectorySyncUserBehavior\", directory_sync_target \"directory_sync_target: DirectorySyncTarget\", okta_private_jwk, okta_dirsync_client_id FROM openidprovider LIMIT 1", "describe": { "columns": [ { @@ -105,6 +105,16 @@ } } } + }, + { + "ordinal": 14, + "name": "okta_private_jwk", + "type_info": "Text" + }, + { + "ordinal": 15, + "name": "okta_dirsync_client_id", + "type_info": "Text" } ], "parameters": { @@ -124,8 +134,10 @@ false, false, false, - false + false, + true, + true ] }, - "hash": "bd1ba52cce3e0529d93f54da70cc77b3a2445b02bba1b6de399fa576acd82025" + "hash": "8837d69c8bdc3223611c936a21c5e0f68aa815e09d48e2cd79985d62bd02711a" } diff --git a/.sqlx/query-93ba83260f46d538c2063ec49bc12dd0f2a64cb2eb7f2dd13b9cba4ae441f70e.json b/.sqlx/query-99321159e98b8e4c4c3c8210b9f230776bab957e26654cc62ca2a9874df6c942.json similarity index 87% rename from .sqlx/query-93ba83260f46d538c2063ec49bc12dd0f2a64cb2eb7f2dd13b9cba4ae441f70e.json rename to .sqlx/query-99321159e98b8e4c4c3c8210b9f230776bab957e26654cc62ca2a9874df6c942.json index e4436fe45..104705bbd 100644 --- a/.sqlx/query-93ba83260f46d538c2063ec49bc12dd0f2a64cb2eb7f2dd13b9cba4ae441f70e.json +++ b/.sqlx/query-99321159e98b8e4c4c3c8210b9f230776bab957e26654cc62ca2a9874df6c942.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "SELECT id, name, base_url, client_id, client_secret, display_name, google_service_account_key, google_service_account_email, admin_email, directory_sync_enabled, \n directory_sync_interval, directory_sync_user_behavior \"directory_sync_user_behavior: DirectorySyncUserBehavior\", directory_sync_admin_behavior \"directory_sync_admin_behavior: DirectorySyncUserBehavior\", directory_sync_target \"directory_sync_target: DirectorySyncTarget\" FROM openidprovider WHERE name = $1", + "query": "SELECT id, name, base_url, client_id, client_secret, display_name, google_service_account_key, google_service_account_email, admin_email, directory_sync_enabled, \n directory_sync_interval, directory_sync_user_behavior \"directory_sync_user_behavior: DirectorySyncUserBehavior\", directory_sync_admin_behavior \"directory_sync_admin_behavior: DirectorySyncUserBehavior\", directory_sync_target \"directory_sync_target: DirectorySyncTarget\", okta_private_jwk, okta_dirsync_client_id FROM openidprovider WHERE name = $1", "describe": { "columns": [ { @@ -105,6 +105,16 @@ } } } + }, + { + "ordinal": 14, + "name": "okta_private_jwk", + "type_info": "Text" + }, + { + "ordinal": 15, + "name": "okta_dirsync_client_id", + "type_info": "Text" } ], "parameters": { @@ -126,8 +136,10 @@ false, false, false, - false + false, + true, + true ] }, - "hash": "93ba83260f46d538c2063ec49bc12dd0f2a64cb2eb7f2dd13b9cba4ae441f70e" + "hash": "99321159e98b8e4c4c3c8210b9f230776bab957e26654cc62ca2a9874df6c942" } diff --git a/.sqlx/query-a2321eb640aa9302dd1ca51498a5d9c6147ceca80eb638037e548fb821c12968.json b/.sqlx/query-a2321eb640aa9302dd1ca51498a5d9c6147ceca80eb638037e548fb821c12968.json new file mode 100644 index 000000000..aaedff761 --- /dev/null +++ b/.sqlx/query-a2321eb640aa9302dd1ca51498a5d9c6147ceca80eb638037e548fb821c12968.json @@ -0,0 +1,75 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT DISTINCT ON (d.id) d.id, d.name, d.wireguard_pubkey, d.user_id, d.created, d.description, d.device_type \"device_type: DeviceType\", configured\n FROM device d JOIN \"user\" u ON d.user_id = u.id JOIN group_user gu ON u.id = gu.user_id JOIN \"group\" g ON gu.group_id = g.id WHERE g.\"name\" IN (SELECT * FROM UNNEST($1::text[])) AND u.is_active = true AND d.device_type = 'user'::device_type AND d.user_id = $2 ORDER BY d.id ASC", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "name", + "type_info": "Text" + }, + { + "ordinal": 2, + "name": "wireguard_pubkey", + "type_info": "Text" + }, + { + "ordinal": 3, + "name": "user_id", + "type_info": "Int8" + }, + { + "ordinal": 4, + "name": "created", + "type_info": "Timestamp" + }, + { + "ordinal": 5, + "name": "description", + "type_info": "Text" + }, + { + "ordinal": 6, + "name": "device_type: DeviceType", + "type_info": { + "Custom": { + "name": "device_type", + "kind": { + "Enum": [ + "user", + "network" + ] + } + } + } + }, + { + "ordinal": 7, + "name": "configured", + "type_info": "Bool" + } + ], + "parameters": { + "Left": [ + "TextArray", + "Int8" + ] + }, + "nullable": [ + false, + false, + false, + false, + false, + true, + false, + false + ] + }, + "hash": "a2321eb640aa9302dd1ca51498a5d9c6147ceca80eb638037e548fb821c12968" +} diff --git a/.sqlx/query-77a6abaf34dd64741a1075a700bc510630fa89a072f81ccfe21b6e986f7b2552.json b/.sqlx/query-b657f2e85d3d880ee2d247591c57cbecaa2fe8897d73fba2c8301410853712e7.json similarity index 84% rename from .sqlx/query-77a6abaf34dd64741a1075a700bc510630fa89a072f81ccfe21b6e986f7b2552.json rename to .sqlx/query-b657f2e85d3d880ee2d247591c57cbecaa2fe8897d73fba2c8301410853712e7.json index 1e37c06b1..348f2fb9b 100644 --- a/.sqlx/query-77a6abaf34dd64741a1075a700bc510630fa89a072f81ccfe21b6e986f7b2552.json +++ b/.sqlx/query-b657f2e85d3d880ee2d247591c57cbecaa2fe8897d73fba2c8301410853712e7.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "INSERT INTO \"openidprovider\" (\"name\",\"base_url\",\"client_id\",\"client_secret\",\"display_name\",\"google_service_account_key\",\"google_service_account_email\",\"admin_email\",\"directory_sync_enabled\",\"directory_sync_interval\",\"directory_sync_user_behavior\",\"directory_sync_admin_behavior\",\"directory_sync_target\") VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13) RETURNING id", + "query": "INSERT INTO \"openidprovider\" (\"name\",\"base_url\",\"client_id\",\"client_secret\",\"display_name\",\"google_service_account_key\",\"google_service_account_email\",\"admin_email\",\"directory_sync_enabled\",\"directory_sync_interval\",\"directory_sync_user_behavior\",\"directory_sync_admin_behavior\",\"directory_sync_target\",\"okta_private_jwk\",\"okta_dirsync_client_id\") VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15) RETURNING id", "describe": { "columns": [ { @@ -56,12 +56,14 @@ ] } } - } + }, + "Text", + "Text" ] }, "nullable": [ false ] }, - "hash": "77a6abaf34dd64741a1075a700bc510630fa89a072f81ccfe21b6e986f7b2552" + "hash": "b657f2e85d3d880ee2d247591c57cbecaa2fe8897d73fba2c8301410853712e7" } diff --git a/.sqlx/query-c5691cac4edea09b4cfabd1105be053bb4bb030489f8084707fa0225a2bddce6.json b/.sqlx/query-c5691cac4edea09b4cfabd1105be053bb4bb030489f8084707fa0225a2bddce6.json new file mode 100644 index 000000000..920e9824e --- /dev/null +++ b/.sqlx/query-c5691cac4edea09b4cfabd1105be053bb4bb030489f8084707fa0225a2bddce6.json @@ -0,0 +1,74 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT d.id, d.name, d.wireguard_pubkey, d.user_id, d.created, d.description, d.device_type \"device_type: DeviceType\", configured FROM device d JOIN \"user\" u ON d.user_id = u.id WHERE u.is_active = true AND d.device_type = 'user'::device_type AND d.user_id = $1 ORDER BY d.id ASC", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "name", + "type_info": "Text" + }, + { + "ordinal": 2, + "name": "wireguard_pubkey", + "type_info": "Text" + }, + { + "ordinal": 3, + "name": "user_id", + "type_info": "Int8" + }, + { + "ordinal": 4, + "name": "created", + "type_info": "Timestamp" + }, + { + "ordinal": 5, + "name": "description", + "type_info": "Text" + }, + { + "ordinal": 6, + "name": "device_type: DeviceType", + "type_info": { + "Custom": { + "name": "device_type", + "kind": { + "Enum": [ + "user", + "network" + ] + } + } + } + }, + { + "ordinal": 7, + "name": "configured", + "type_info": "Bool" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + false, + false, + false, + false, + false, + true, + false, + false + ] + }, + "hash": "c5691cac4edea09b4cfabd1105be053bb4bb030489f8084707fa0225a2bddce6" +} diff --git a/.sqlx/query-c858bafd3e74f99d5720a9627f68677e44dea7faf04115dbe8200041b662040c.json b/.sqlx/query-d7745f7087791dcbf8e7ec504ed88ed6ad81d5a2f00e7307d8d338bc2269ba91.json similarity index 86% rename from .sqlx/query-c858bafd3e74f99d5720a9627f68677e44dea7faf04115dbe8200041b662040c.json rename to .sqlx/query-d7745f7087791dcbf8e7ec504ed88ed6ad81d5a2f00e7307d8d338bc2269ba91.json index b42f08b95..e50fc0055 100644 --- a/.sqlx/query-c858bafd3e74f99d5720a9627f68677e44dea7faf04115dbe8200041b662040c.json +++ b/.sqlx/query-d7745f7087791dcbf8e7ec504ed88ed6ad81d5a2f00e7307d8d338bc2269ba91.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "UPDATE \"openidprovider\" SET \"name\" = $2,\"base_url\" = $3,\"client_id\" = $4,\"client_secret\" = $5,\"display_name\" = $6,\"google_service_account_key\" = $7,\"google_service_account_email\" = $8,\"admin_email\" = $9,\"directory_sync_enabled\" = $10,\"directory_sync_interval\" = $11,\"directory_sync_user_behavior\" = $12,\"directory_sync_admin_behavior\" = $13,\"directory_sync_target\" = $14 WHERE id = $1", + "query": "UPDATE \"openidprovider\" SET \"name\" = $2,\"base_url\" = $3,\"client_id\" = $4,\"client_secret\" = $5,\"display_name\" = $6,\"google_service_account_key\" = $7,\"google_service_account_email\" = $8,\"admin_email\" = $9,\"directory_sync_enabled\" = $10,\"directory_sync_interval\" = $11,\"directory_sync_user_behavior\" = $12,\"directory_sync_admin_behavior\" = $13,\"directory_sync_target\" = $14,\"okta_private_jwk\" = $15,\"okta_dirsync_client_id\" = $16 WHERE id = $1", "describe": { "columns": [], "parameters": { @@ -51,10 +51,12 @@ ] } } - } + }, + "Text", + "Text" ] }, "nullable": [] }, - "hash": "c858bafd3e74f99d5720a9627f68677e44dea7faf04115dbe8200041b662040c" + "hash": "d7745f7087791dcbf8e7ec504ed88ed6ad81d5a2f00e7307d8d338bc2269ba91" } diff --git a/.sqlx/query-9bbdc8988c6ab347d5e0b1808eb547357763cf9a1372f418c57ddc66ffab3697.json b/.sqlx/query-e756cd7ed9f2695631f9162e8c7e49921e469e97b86196b5608b3c79e4c7a7df.json similarity index 86% rename from .sqlx/query-9bbdc8988c6ab347d5e0b1808eb547357763cf9a1372f418c57ddc66ffab3697.json rename to .sqlx/query-e756cd7ed9f2695631f9162e8c7e49921e469e97b86196b5608b3c79e4c7a7df.json index 093e10db2..4c82a7f0e 100644 --- a/.sqlx/query-9bbdc8988c6ab347d5e0b1808eb547357763cf9a1372f418c57ddc66ffab3697.json +++ b/.sqlx/query-e756cd7ed9f2695631f9162e8c7e49921e469e97b86196b5608b3c79e4c7a7df.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "SELECT id, \"name\",\"base_url\",\"client_id\",\"client_secret\",\"display_name\",\"google_service_account_key\",\"google_service_account_email\",\"admin_email\",\"directory_sync_enabled\",\"directory_sync_interval\",\"directory_sync_user_behavior\" \"directory_sync_user_behavior: _\",\"directory_sync_admin_behavior\" \"directory_sync_admin_behavior: _\",\"directory_sync_target\" \"directory_sync_target: _\" FROM \"openidprovider\"", + "query": "SELECT id, \"name\",\"base_url\",\"client_id\",\"client_secret\",\"display_name\",\"google_service_account_key\",\"google_service_account_email\",\"admin_email\",\"directory_sync_enabled\",\"directory_sync_interval\",\"directory_sync_user_behavior\" \"directory_sync_user_behavior: _\",\"directory_sync_admin_behavior\" \"directory_sync_admin_behavior: _\",\"directory_sync_target\" \"directory_sync_target: _\",\"okta_private_jwk\",\"okta_dirsync_client_id\" FROM \"openidprovider\"", "describe": { "columns": [ { @@ -105,6 +105,16 @@ } } } + }, + { + "ordinal": 14, + "name": "okta_private_jwk", + "type_info": "Text" + }, + { + "ordinal": 15, + "name": "okta_dirsync_client_id", + "type_info": "Text" } ], "parameters": { @@ -124,8 +134,10 @@ false, false, false, - false + false, + true, + true ] }, - "hash": "9bbdc8988c6ab347d5e0b1808eb547357763cf9a1372f418c57ddc66ffab3697" + "hash": "e756cd7ed9f2695631f9162e8c7e49921e469e97b86196b5608b3c79e4c7a7df" } diff --git a/Cargo.lock b/Cargo.lock index f65d3c8f9..8e366d7d0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1048,7 +1048,7 @@ dependencies = [ [[package]] name = "defguard" -version = "1.2.2" +version = "1.2.3" dependencies = [ "anyhow", "argon2", @@ -1064,6 +1064,7 @@ dependencies = [ "dotenvy", "humantime", "ipnetwork", + "jsonwebkey", "jsonwebtoken", "ldap3", "lettre", @@ -1072,6 +1073,7 @@ dependencies = [ "mime_guess", "model_derive", "openidconnect", + "parse_link_header", "paste", "pgp", "prost", @@ -2410,6 +2412,23 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "jsonwebkey" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c57c852b14147e2bd58c14fde40398864453403ef632b1101db130282ee6e2cc" +dependencies = [ + "base64 0.13.1", + "bitflags 1.3.2", + "generic-array", + "num-bigint", + "serde", + "serde_json", + "thiserror 1.0.69", + "yasna", + "zeroize", +] + [[package]] name = "jsonwebtoken" version = "9.3.0" @@ -3094,6 +3113,17 @@ dependencies = [ "regex", ] +[[package]] +name = "parse_link_header" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbc23fdb8bbf668d582b0c17120bf6b7f91d85ccad3a5b39706f019a4efda005" +dependencies = [ + "http 1.2.0", + "lazy_static", + "regex", +] + [[package]] name = "password-hash" version = "0.5.0" @@ -5965,6 +5995,15 @@ dependencies = [ "time", ] +[[package]] +name = "yasna" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e262a29d0e61ccf2b6190d7050d4b237535fc76ce4c1210d9caa316f71dffa75" +dependencies = [ + "num-bigint", +] + [[package]] name = "yoke" version = "0.7.5" diff --git a/Cargo.toml b/Cargo.toml index e339fca50..fe5402a7d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "defguard" -version = "1.2.2" +version = "1.2.3" edition = "2021" license-file = "LICENSE.md" homepage = "https://defguard.net/" @@ -31,6 +31,7 @@ dotenvy = "0.15" humantime = "2.1" # match ipnetwork version from sqlx ipnetwork = { version = "0.20", features = ["serde"] } +jsonwebkey = { version = "0.3.5", features = ["pkcs-convert"] } jsonwebtoken = "9.3" ldap3 = { version = "0.11", default-features = false, features = ["tls"] } lettre = { version = "0.11", features = ["tokio1", "tokio1-native-tls"] } @@ -40,6 +41,7 @@ model_derive = { path = "model-derive" } openidconnect = { version = "3.5", default-features = false, optional = true, features = [ "reqwest", ] } +parse_link_header = "0.4" paste = "1.0.15" pgp = "0.14" prost = "0.13" diff --git a/migrations/20250129114956_okta_dirsync.down.sql b/migrations/20250129114956_okta_dirsync.down.sql new file mode 100644 index 000000000..27d5f9e6d --- /dev/null +++ b/migrations/20250129114956_okta_dirsync.down.sql @@ -0,0 +1,2 @@ +ALTER TABLE openidprovider DROP COLUMN okta_private_jwk; +ALTER TABLE openidprovider DROP COLUMN okta_dirsync_client_id; diff --git a/migrations/20250129114956_okta_dirsync.up.sql b/migrations/20250129114956_okta_dirsync.up.sql new file mode 100644 index 000000000..af375f09b --- /dev/null +++ b/migrations/20250129114956_okta_dirsync.up.sql @@ -0,0 +1,2 @@ +ALTER TABLE openidprovider ADD COLUMN okta_private_jwk TEXT DEFAULT NULL; +ALTER TABLE openidprovider ADD COLUMN okta_dirsync_client_id TEXT DEFAULT NULL; diff --git a/src/appstate.rs b/src/appstate.rs index 8a1b7788a..11573bbdf 100644 --- a/src/appstate.rs +++ b/src/appstate.rs @@ -18,6 +18,7 @@ use webauthn_rs::prelude::*; use crate::{ auth::failed_login::FailedLoginMap, db::{AppEvent, GatewayEvent, WebHook}, + grpc::gateway::{send_multiple_wireguard_events, send_wireguard_event}, mail::Mail, server_config, }; @@ -26,7 +27,7 @@ use crate::{ pub struct AppState { pub pool: PgPool, tx: UnboundedSender, - wireguard_tx: Sender, + pub wireguard_tx: Sender, pub mail_tx: UnboundedSender, pub webauthn: Arc, pub failed_logins: Arc>, @@ -79,19 +80,16 @@ impl AppState { } } - /// Sends given `GatewayEvent` to be handled by gateway GRPC server + /// Sends given `GatewayEvent` to be handled by gateway GRPC server. + /// Convenience wrapper around [`send_wireguard_event`] pub fn send_wireguard_event(&self, event: GatewayEvent) { - if let Err(err) = self.wireguard_tx.send(event) { - error!("Error sending WireGuard event {err}"); - } + send_wireguard_event(event, &self.wireguard_tx); } - /// Sends multiple events to be handled by gateway GRPC server + /// Sends multiple events to be handled by gateway GRPC server. + /// Convenience wrapper around [`send_multiple_wireguard_events`] pub fn send_multiple_wireguard_events(&self, events: Vec) { - debug!("Sending {} wireguard events", events.len()); - for event in events { - self.send_wireguard_event(event); - } + send_multiple_wireguard_events(events, &self.wireguard_tx); } /// Create application state diff --git a/src/bin/defguard.rs b/src/bin/defguard.rs index b74ddc79e..c9e2feb36 100644 --- a/src/bin/defguard.rs +++ b/src/bin/defguard.rs @@ -127,10 +127,10 @@ async fn main() -> Result<(), anyhow::Error> { res = run_grpc_server(Arc::clone(&worker_state), pool.clone(), Arc::clone(&gateway_state), wireguard_tx.clone(), mail_tx.clone(), grpc_cert, grpc_key, failed_logins.clone()) => error!("gRPC server returned early: {res:?}"), res = run_web_server(worker_state, gateway_state, webhook_tx, webhook_rx, wireguard_tx.clone(), mail_tx, pool.clone(), failed_logins) => error!("Web server returned early: {res:?}"), res = run_mail_handler(mail_rx, pool.clone()) => error!("Mail handler returned early: {res:?}"), - res = run_periodic_peer_disconnect(pool.clone(), wireguard_tx) => error!("Periodic peer disconnect task returned early: {res:?}"), + res = run_periodic_peer_disconnect(pool.clone(), wireguard_tx.clone()) => error!("Periodic peer disconnect task returned early: {res:?}"), res = run_periodic_stats_purge(pool.clone(), config.stats_purge_frequency.into(), config.stats_purge_threshold.into()), if !config.disable_stats_purge => error!("Periodic stats purge task returned early: {res:?}"), res = run_periodic_license_check(&pool) => error!("Periodic license check task returned early: {res:?}"), - res = run_utility_thread(&pool) => error!("Utility thread returned early: {res:?}"), + res = run_utility_thread(&pool, wireguard_tx) => error!("Utility thread returned early: {res:?}"), } Ok(()) } diff --git a/src/db/models/device.rs b/src/db/models/device.rs index ad7622b89..ddb8a7638 100644 --- a/src/db/models/device.rs +++ b/src/db/models/device.rs @@ -401,6 +401,33 @@ impl WireguardNetworkDevice { Ok(res) } + /// Get all devices for a given network and user + /// Note: doesn't return network devices added by the user + /// as they are not considered to be bound to the user + pub(crate) async fn all_for_network_and_user<'e, E>( + executor: E, + network_id: Id, + user_id: Id, + ) -> Result, SqlxError> + where + E: PgExecutor<'e>, + { + let res = query_as!( + Self, + "SELECT device_id, wireguard_network_id, wireguard_ip \"wireguard_ip: IpAddr\", \ + preshared_key, is_authorized, authorized_at \ + FROM wireguard_network_device \ + WHERE wireguard_network_id = $1 AND device_id IN \ + (SELECT id FROM device WHERE user_id = $2 AND device_type = 'user'::device_type)", + network_id, + user_id + ) + .fetch_all(executor) + .await?; + + Ok(res) + } + pub(crate) async fn network<'e, E>( &self, executor: E, @@ -855,6 +882,8 @@ impl Device { #[cfg(test)] mod test { + use std::str::FromStr; + use claims::{assert_err, assert_ok}; use super::*; @@ -945,4 +974,123 @@ mod test { let valid_test_key = "sejIy0WCLvOR7vWNchP9Elsayp3UTK/QCnEJmhsHKTc="; assert_ok!(Device::validate_pubkey(valid_test_key)); } + + #[sqlx::test] + fn test_all_for_network_and_user(pool: PgPool) { + let user = User::new( + "testuser", + Some("hunter2"), + "Tester", + "Test", + "email@email.com", + None, + ) + .save(&pool) + .await + .unwrap(); + + let user2 = User::new( + "testuser2", + Some("hunter2"), + "Tester", + "Test", + "email2@email.com", + None, + ) + .save(&pool) + .await + .unwrap(); + + let mut network = WireguardNetwork::default(); + network.try_set_address("10.1.1.1/24").unwrap(); + let network = network.save(&pool).await.unwrap(); + let mut network2 = WireguardNetwork::default(); + network2.name = "testnetwork2".into(); + network2.try_set_address("10.1.2.1/24").unwrap(); + let network2 = network2.save(&pool).await.unwrap(); + + let device = Device::new( + "testdevice".into(), + "key".into(), + user.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .unwrap(); + + let device2 = Device::new( + "testdevice2".into(), + "key2".into(), + user.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .unwrap(); + + let device3 = Device::new( + "testdevice3".into(), + "key3".into(), + user2.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .unwrap(); + + let device4 = Device::new( + "testdevice4".into(), + "key4".into(), + user.id, + DeviceType::Network, + None, + true, + ) + .save(&pool) + .await + .unwrap(); + + let mut transaction = pool.begin().await.unwrap(); + + network + .add_device_to_network(&mut transaction, &device, None) + .await + .unwrap(); + network2 + .add_device_to_network(&mut transaction, &device, None) + .await + .unwrap(); + network2 + .add_device_to_network(&mut transaction, &device2, None) + .await + .unwrap(); + network + .add_device_to_network(&mut transaction, &device3, None) + .await + .unwrap(); + WireguardNetworkDevice::new( + network.id, + device4.id, + IpAddr::from_str("10.1.1.10").unwrap(), + ) + .insert(&mut *transaction) + .await + .unwrap(); + + transaction.commit().await.unwrap(); + + let devices = WireguardNetworkDevice::all_for_network_and_user(&pool, network.id, user.id) + .await + .unwrap(); + + assert_eq!(devices.len(), 1); + assert_eq!(devices[0].device_id, device.id); + } } diff --git a/src/db/models/oauth2token.rs b/src/db/models/oauth2token.rs index f65fdf1dd..1639750ec 100644 --- a/src/db/models/oauth2token.rs +++ b/src/db/models/oauth2token.rs @@ -135,6 +135,7 @@ impl OAuth2Token { Err(err) => Err(err), } } + // Find by authorized app id pub async fn find_by_authorized_app_id( pool: &PgPool, diff --git a/src/db/models/settings.rs b/src/db/models/settings.rs index b204efb97..a05dc74e2 100644 --- a/src/db/models/settings.rs +++ b/src/db/models/settings.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, str::FromStr}; +use std::collections::HashMap; use sqlx::{query, query_as, PgExecutor, PgPool, Type}; use struct_patch::Patch; @@ -257,19 +257,15 @@ impl Settings { } /// Check if all required SMTP options are configured. + /// User & password can be empty for no-auth servers. /// /// Meant to be used to check if sending emails is enabled in current instance. #[must_use] pub fn smtp_configured(&self) -> bool { self.smtp_server.is_some() && self.smtp_port.is_some() - && self.smtp_user.is_some() - && self.smtp_password.is_some() && self.smtp_sender.is_some() && self.smtp_server != Some(String::new()) - && self.smtp_user != Some(String::new()) - && self.smtp_password - != Some(SecretStringWrapper::from_str("").expect("Failed to convert empty string")) && self.smtp_sender != Some(String::new()) } } @@ -353,3 +349,34 @@ Star us on GitHub! https://github.com/defguard/defguard\ pub static WELCOME_EMAIL_SUBJECT: &str = "[defguard] Welcome message after enrollment"; } + +#[cfg(test)] +mod test { + use std::str::FromStr; + + use super::*; + + #[test] + fn test_smtp_config() { + let mut settings = Settings::default(); + assert!(!settings.smtp_configured()); + + // incomplete SMTP config + settings.smtp_server = Some("localhost".into()); + settings.smtp_port = Some(587); + assert!(!settings.smtp_configured()); + + // no-auth SMTP config + settings.smtp_sender = Some("no-reply@defguard.net".into()); + assert!(settings.smtp_configured()); + + // add non-default encryption + settings.smtp_encryption = SmtpEncryption::StartTls; + assert!(settings.smtp_configured()); + + // add auth info + settings.smtp_user = Some("smtp_user".into()); + settings.smtp_password = Some(SecretStringWrapper::from_str("hunter2").unwrap()); + assert!(settings.smtp_configured()); + } +} diff --git a/src/db/models/user.rs b/src/db/models/user.rs index 892700882..4da345996 100644 --- a/src/db/models/user.rs +++ b/src/db/models/user.rs @@ -9,20 +9,26 @@ use argon2::{ }; use axum::http::StatusCode; use model_derive::Model; -use sqlx::{query, query_as, query_scalar, Error as SqlxError, FromRow, PgExecutor, PgPool, Type}; +use sqlx::{ + query, query_as, query_scalar, Error as SqlxError, FromRow, PgConnection, PgExecutor, PgPool, + Type, +}; +use tokio::sync::broadcast::Sender; use totp_lite::{totp_custom, Sha1}; use utoipa::ToSchema; use super::{ - device::{Device, DeviceType, UserDevice}, + device::{Device, DeviceInfo, DeviceType, UserDevice}, group::Group, webauthn::WebAuthn, MFAInfo, OAuth2AuthorizedAppInfo, SecurityKey, }; use crate::{ auth::{EMAIL_CODE_DIGITS, TOTP_CODE_DIGITS, TOTP_CODE_VALIDITY_PERIOD}, - db::{models::group::Permission, Id, NoId, Session}, + db::{models::group::Permission, GatewayEvent, Id, NoId, Session, WireguardNetwork}, error::WebError, + grpc::gateway::send_multiple_wireguard_events, + ldap::utils::ldap_delete_user, random::{gen_alphanumeric, gen_totp_secret}, server_config, }; @@ -308,6 +314,65 @@ impl User { Ok(()) } + /// Disable user, log out all his sessions and update gateways state. + pub async fn disable( + &mut self, + transaction: &mut PgConnection, + wg_tx: &Sender, + ) -> Result<(), WebError> { + self.is_active = false; + self.save(&mut *transaction).await?; + self.logout_all_sessions(&mut *transaction).await?; + self.sync_allowed_devices(transaction, wg_tx).await?; + Ok(()) + } + + /// Update gateway state based on this user device access rights + pub async fn sync_allowed_devices( + &self, + transaction: &mut PgConnection, + wg_tx: &Sender, + ) -> Result<(), WebError> { + debug!("Syncing allowed devices of {}", self.username); + let networks = WireguardNetwork::all(&mut *transaction).await?; + for network in networks { + let gateway_events = network + .sync_allowed_devices_for_user(transaction, self, None) + .await?; + send_multiple_wireguard_events(gateway_events, wg_tx); + } + info!("Allowed devices of {} synced", self.username); + Ok(()) + } + + /// Deletes the user and cleans up his devices from gateways + pub async fn delete_and_cleanup( + self, + transaction: &mut PgConnection, + wg_tx: &Sender, + ) -> Result<(), WebError> { + let username = self.username.clone(); + debug!( + "Deleting user {}, removing his devices from gateways and updating ldap...", + &username + ); + let devices = self.devices(&mut *transaction).await?; + let mut events = Vec::new(); + for device in devices { + events.push(GatewayEvent::DeviceDeleted( + DeviceInfo::from_device(&mut *transaction, device).await?, + )); + } + self.delete(&mut *transaction).await?; + send_multiple_wireguard_events(events, wg_tx); + let _result = ldap_delete_user(&username).await; + info!( + "The user {} has been deleted and his devices removed from gateways.", + &username + ); + Ok(()) + } + /// Enable MFA. At least one of the authenticator factors must be configured. pub async fn enable_mfa(&mut self, pool: &PgPool) -> Result<(), WebError> { if !self.mfa_enabled { diff --git a/src/db/models/wireguard.rs b/src/db/models/wireguard.rs index 037685390..3788ced34 100644 --- a/src/db/models/wireguard.rs +++ b/src/db/models/wireguard.rs @@ -29,7 +29,10 @@ use super::{ use crate::{ appstate::AppState, db::{Id, NoId}, - grpc::{gateway::Peer, GatewayState}, + grpc::{ + gateway::{send_multiple_wireguard_events, Peer}, + GatewayState, + }, wg_config::ImportedDevice, }; @@ -213,7 +216,7 @@ impl WireguardNetwork { let networks = Self::all(&mut *transaction).await?; for network in networks { let gateway_events = network.sync_allowed_devices(&mut transaction, None).await?; - app.send_multiple_wireguard_events(gateway_events); + send_multiple_wireguard_events(gateway_events, &app.wireguard_tx); } transaction.commit().await?; Ok(()) @@ -255,6 +258,7 @@ impl WireguardNetwork { /// Get a list of all devices belonging to users in allowed groups. /// Admin users should always be allowed to access a network. + /// Note: Doesn't check if the devices are really in the network. async fn get_allowed_devices( &self, transaction: &mut PgConnection, @@ -300,6 +304,57 @@ impl WireguardNetwork { Ok(devices) } + /// Get a list of devices belonging to a user which are also in the network's allowed groups. + /// Admin users should always be allowed to access a network. + /// Note: Doesn't check if the devices are really in the network. + async fn get_allowed_devices_for_user( + &self, + transaction: &mut PgConnection, + user_id: Id, + ) -> Result>, ModelError> { + debug!("Fetching all allowed devices for network {}", self); + let devices = match self.get_allowed_groups(&mut *transaction).await? { + // devices need to be filtered by allowed group + Some(allowed_groups) => { + query_as!( + Device, + "SELECT DISTINCT ON (d.id) d.id, d.name, d.wireguard_pubkey, d.user_id, d.created, d.description, d.device_type \"device_type: DeviceType\", \ + configured + FROM device d \ + JOIN \"user\" u ON d.user_id = u.id \ + JOIN group_user gu ON u.id = gu.user_id \ + JOIN \"group\" g ON gu.group_id = g.id \ + WHERE g.\"name\" IN (SELECT * FROM UNNEST($1::text[])) \ + AND u.is_active = true \ + AND d.device_type = 'user'::device_type \ + AND d.user_id = $2 \ + ORDER BY d.id ASC", + &allowed_groups, user_id + ) + .fetch_all(&mut *transaction) + .await? + } + // all devices of enabled users are allowed + None => { + query_as!( + Device, + "SELECT d.id, d.name, d.wireguard_pubkey, d.user_id, d.created, d.description, d.device_type \"device_type: DeviceType\", \ + configured \ + FROM device d \ + JOIN \"user\" u ON d.user_id = u.id \ + WHERE u.is_active = true \ + AND d.device_type = 'user'::device_type \ + AND d.user_id = $1 \ + ORDER BY d.id ASC", user_id + ) + .fetch_all(&mut *transaction) + .await? + } + }; + + Ok(devices) + } + /// Generate network IPs for all existing devices /// If `allowed_groups` is set, devices should be filtered accordingly pub(crate) async fn add_all_allowed_devices( @@ -355,42 +410,19 @@ impl WireguardNetwork { Ok(wireguard_network_device) } - /// Refresh network IPs for all relevant devices - /// If the list of allowed devices has changed add/remove devices accordingly - /// If the network address has changed readdress existing devices - pub(crate) async fn sync_allowed_devices( + /// Works out which devices need to be added, removed, or readdressed + /// based on the list of currently configured devices and the list of devices which should be allowed + async fn process_device_access_changes( &self, transaction: &mut PgConnection, + mut allowed_devices: HashMap>, + currently_configured_devices: Vec, reserved_ips: Option<&[IpAddr]>, ) -> Result, WireguardNetworkError> { - info!("Synchronizing IPs in network {self} for all allowed devices "); - // list all allowed devices - let mut allowed_devices = self.get_allowed_devices(&mut *transaction).await?; - // network devices are always allowed, make sure to take only network devices already assigned to that network - let network_devices = - Device::find_by_type_and_network(&mut *transaction, DeviceType::Network, self.id) - .await?; - allowed_devices.extend(network_devices); - - // convert to a map for easier processing - let mut allowed_devices: HashMap> = allowed_devices - .into_iter() - .map(|dev| (dev.id, dev)) - .collect(); - - // check if all devices can fit within network - // include address, network, and broadcast in the calculation - let count = allowed_devices.len() + 3; - self.validate_network_size(count)?; - - // list all assigned IPs - let assigned_ips = - WireguardNetworkDevice::all_for_network(&mut *transaction, self.id).await?; - - // loop through assigned IPs; remove no longer allowed, readdress when necessary; remove processed entry from all devices list + // loop through current device configurations; remove no longer allowed, readdress when necessary; remove processed entry from all devices list // initial list should now contain only devices to be added - let mut events = Vec::new(); - for device_network_config in assigned_ips { + let mut events: Vec = Vec::new(); + for device_network_config in currently_configured_devices { // device is allowed and an IP was already assigned if let Some(device) = allowed_devices.remove(&device_network_config.device_id) { // network address changed and IP needs to be updated @@ -454,6 +486,93 @@ impl WireguardNetwork { Ok(events) } + /// Refresh network IPs for all relevant devices of a given user + /// If the list of allowed devices has changed add/remove devices accordingly + /// If the network address has changed readdress existing devices + pub(crate) async fn sync_allowed_devices_for_user( + &self, + transaction: &mut PgConnection, + user: &User, + reserved_ips: Option<&[IpAddr]>, + ) -> Result, WireguardNetworkError> { + info!("Synchronizing IPs in network {self} for all allowed devices "); + // list all allowed devices + let allowed_devices = self + .get_allowed_devices_for_user(&mut *transaction, user.id) + .await?; + + // convert to a map for easier processing + let allowed_devices: HashMap> = allowed_devices + .into_iter() + .map(|dev| (dev.id, dev)) + .collect(); + + // check if all devices can fit within network + // include address, network, and broadcast in the calculation + let count = allowed_devices.len() + 3; + self.validate_network_size(count)?; + + // list all assigned IPs + let assigned_ips = + WireguardNetworkDevice::all_for_network_and_user(&mut *transaction, self.id, user.id) + .await?; + + let events = self + .process_device_access_changes( + &mut *transaction, + allowed_devices, + assigned_ips, + reserved_ips, + ) + .await?; + + Ok(events) + } + + /// Refresh network IPs for all relevant devices + /// If the list of allowed devices has changed add/remove devices accordingly + /// If the network address has changed readdress existing devices + pub(crate) async fn sync_allowed_devices( + &self, + transaction: &mut PgConnection, + reserved_ips: Option<&[IpAddr]>, + ) -> Result, WireguardNetworkError> { + info!("Synchronizing IPs in network {self} for all allowed devices "); + // list all allowed devices + let mut allowed_devices = self.get_allowed_devices(&mut *transaction).await?; + // network devices are always allowed, make sure to take only network devices already assigned to that network + let network_devices = + Device::find_by_type_and_network(&mut *transaction, DeviceType::Network, self.id) + .await?; + allowed_devices.extend(network_devices); + + // convert to a map for easier processing + let allowed_devices: HashMap> = allowed_devices + .into_iter() + .map(|dev| (dev.id, dev)) + .collect(); + + // check if all devices can fit within network + // include address, network, and broadcast in the calculation + let count = allowed_devices.len() + 3; + self.validate_network_size(count)?; + + // list all assigned IPs + let assigned_ips = + WireguardNetworkDevice::all_for_network(&mut *transaction, self.id).await?; + + let events = self + .process_device_access_changes( + &mut *transaction, + allowed_devices, + assigned_ips, + reserved_ips, + ) + .await?; + + Ok(events) + } + /// Check if devices found in an imported config file exist already, /// if they do assign a specified IP. /// Return a list of imported devices which need to be manually mapped to a user @@ -1041,6 +1160,7 @@ mod test { use chrono::{SubsecRound, TimeDelta}; use super::*; + use crate::db::Group; #[sqlx::test] async fn test_connected_at_reconnection(pool: PgPool) { @@ -1165,4 +1285,421 @@ mod test { (now - TimeDelta::minutes(samples)).trunc_subsecs(6), ); } + + #[sqlx::test] + async fn test_get_allowed_devices_for_user(pool: PgPool) { + let mut network = WireguardNetwork::default(); + network.try_set_address("10.1.1.1/29").unwrap(); + let network = network.save(&pool).await.unwrap(); + + let user1 = User::new( + "user1", + Some("pass1"), + "Test", + "User1", + "user1@test.com", + None, + ) + .save(&pool) + .await + .unwrap(); + + let user2 = User::new( + "user2", + Some("pass2"), + "Test", + "User2", + "user2@test.com", + None, + ) + .save(&pool) + .await + .unwrap(); + + let device1 = Device::new( + "device1".into(), + "key1".into(), + user1.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .unwrap(); + + let device2 = Device::new( + "device2".into(), + "key2".into(), + user1.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .unwrap(); + + let device3 = Device::new( + "device3".into(), + "key3".into(), + user2.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .unwrap(); + + let devices = network + .get_allowed_devices_for_user(&mut pool.acquire().await.unwrap(), user1.id) + .await + .unwrap(); + assert_eq!(devices.len(), 2); + assert!(devices.iter().any(|d| d.id == device1.id)); + assert!(devices.iter().any(|d| d.id == device2.id)); + + let devices = network + .get_allowed_devices_for_user(&mut pool.acquire().await.unwrap(), user2.id) + .await + .unwrap(); + assert_eq!(devices.len(), 1); + assert!(devices.iter().any(|d| d.id == device3.id)); + + let devices = network + .get_allowed_devices_for_user(&mut pool.acquire().await.unwrap(), Id::from(999)) + .await + .unwrap(); + assert!(devices.is_empty()); + } + + #[sqlx::test] + async fn test_get_allowed_devices_for_user_with_groups(pool: PgPool) { + let mut network = WireguardNetwork::default(); + network.try_set_address("10.1.1.1/29").unwrap(); + let network = network.save(&pool).await.unwrap(); + + let user1 = User::new( + "user1", + Some("pass1"), + "Test", + "User1", + "user1@test.com", + None, + ) + .save(&pool) + .await + .unwrap(); + + let user2 = User::new( + "user2", + Some("pass2"), + "Test", + "User2", + "user2@test.com", + None, + ) + .save(&pool) + .await + .unwrap(); + + let group1 = Group::new("group1").save(&pool).await.unwrap(); + let group2 = Group::new("group2").save(&pool).await.unwrap(); + + user1.add_to_group(&pool, &group1).await.unwrap(); + user2.add_to_group(&pool, &group2).await.unwrap(); + + let device1 = Device::new( + "device1".into(), + "key1".into(), + user1.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .unwrap(); + + Device::new( + "device2".into(), + "key2".into(), + user2.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .unwrap(); + + let mut transaction = pool.begin().await.unwrap(); + + network + .set_allowed_groups(&mut transaction, vec![group1.name]) + .await + .unwrap(); + + let devices = network + .get_allowed_devices_for_user(&mut transaction, user1.id) + .await + .unwrap(); + assert_eq!(devices.len(), 1); + assert_eq!(devices[0].id, device1.id); + + let devices = network + .get_allowed_devices_for_user(&mut transaction, user2.id) + .await + .unwrap(); + assert!(devices.is_empty()); + } + + #[sqlx::test] + async fn test_sync_allowed_devices_for_user(pool: PgPool) { + let mut network = WireguardNetwork::default(); + network.try_set_address("10.1.1.1/29").unwrap(); + let network = network.save(&pool).await.unwrap(); + + let user1 = User::new( + "testuser1", + Some("pass1"), + "Tester1", + "Test1", + "test1@test.com", + None, + ) + .save(&pool) + .await + .unwrap(); + + let user2 = User::new( + "testuser2", + Some("pass2"), + "Tester2", + "Test2", + "test2@test.com", + None, + ) + .save(&pool) + .await + .unwrap(); + + let device1 = Device::new( + "device1".into(), + "key1".into(), + user1.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .unwrap(); + + let device2 = Device::new( + "device2".into(), + "key2".into(), + user1.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .unwrap(); + + let device3 = Device::new( + "device3".into(), + "key3".into(), + user2.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .unwrap(); + + let mut transaction = pool.begin().await.unwrap(); + + // user1 sync + let events = network + .sync_allowed_devices_for_user(&mut transaction, &user1, None) + .await + .unwrap(); + + assert_eq!(events.len(), 2); + assert!(events.iter().any(|e| match e { + GatewayEvent::DeviceCreated(info) => info.device.id == device1.id, + _ => false, + })); + assert!(events.iter().any(|e| match e { + GatewayEvent::DeviceCreated(info) => info.device.id == device2.id, + _ => false, + })); + + // user 2 sync + let events = network + .sync_allowed_devices_for_user(&mut transaction, &user2, None) + .await + .unwrap(); + + assert_eq!(events.len(), 1); + match &events[0] { + GatewayEvent::DeviceCreated(info) => { + assert_eq!(info.device.id, device3.id); + } + _ => panic!("Expected DeviceCreated event"), + } + + // Second sync should not generate any events + let events = network + .sync_allowed_devices_for_user(&mut transaction, &user1, None) + .await + .unwrap(); + assert_eq!(events.len(), 0); + + transaction.commit().await.unwrap(); + } + + #[sqlx::test] + async fn test_sync_allowed_devices_for_user_with_groups(pool: PgPool) { + let mut network = WireguardNetwork::default(); + network.try_set_address("10.1.1.1/29").unwrap(); + let network = network.save(&pool).await.unwrap(); + + let user1 = User::new( + "testuser1", + Some("pass1"), + "Tester1", + "Test1", + "test1@test.com", + None, + ) + .save(&pool) + .await + .unwrap(); + + let user2 = User::new( + "testuser2", + Some("pass2"), + "Tester2", + "Test2", + "test2@test.com", + None, + ) + .save(&pool) + .await + .unwrap(); + + let user3 = User::new( + "testuser3", + Some("pass3"), + "Tester3", + "Test3", + "test3@test.com", + None, + ) + .save(&pool) + .await + .unwrap(); + + let device1 = Device::new( + "device1".into(), + "key1".into(), + user1.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .unwrap(); + + let device2 = Device::new( + "device2".into(), + "key2".into(), + user2.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .unwrap(); + + let device3 = Device::new( + "device3".into(), + "key3".into(), + user3.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .unwrap(); + + let group1 = Group::new("group1").save(&pool).await.unwrap(); + let group2 = Group::new("group2").save(&pool).await.unwrap(); + + let mut transaction = pool.begin().await.unwrap(); + + network + .set_allowed_groups( + &mut transaction, + vec![group1.name.clone(), group2.name.clone()], + ) + .await + .unwrap(); + + let events = network + .sync_allowed_devices_for_user(&mut transaction, &user1, None) + .await + .unwrap(); + assert_eq!(events.len(), 0); + + user1.add_to_group(&pool, &group1).await.unwrap(); + user2.add_to_group(&pool, &group1).await.unwrap(); + user3.add_to_group(&pool, &group2).await.unwrap(); + + let events = network + .sync_allowed_devices_for_user(&mut transaction, &user1, None) + .await + .unwrap(); + assert_eq!(events.len(), 1); + match &events[0] { + GatewayEvent::DeviceCreated(info) => { + assert_eq!(info.device.id, device1.id); + } + _ => panic!("Expected DeviceCreated event"), + } + + let events = network + .sync_allowed_devices_for_user(&mut transaction, &user2, None) + .await + .unwrap(); + assert_eq!(events.len(), 1); + match &events[0] { + GatewayEvent::DeviceCreated(info) => { + assert_eq!(info.device.id, device2.id); + } + _ => panic!("Expected DeviceCreated event"), + } + + let events = network + .sync_allowed_devices_for_user(&mut transaction, &user3, None) + .await + .unwrap(); + assert_eq!(events.len(), 1); + match &events[0] { + GatewayEvent::DeviceCreated(info) => { + assert_eq!(info.device.id, device3.id); + } + _ => panic!("Expected DeviceCreated event"), + } + + transaction.commit().await.unwrap(); + } } diff --git a/src/enterprise/db/models/openid_provider.rs b/src/enterprise/db/models/openid_provider.rs index ed6d09c20..ca511eb0a 100644 --- a/src/enterprise/db/models/openid_provider.rs +++ b/src/enterprise/db/models/openid_provider.rs @@ -106,6 +106,10 @@ pub struct OpenIdProvider { pub directory_sync_admin_behavior: DirectorySyncUserBehavior, #[model(enum)] pub directory_sync_target: DirectorySyncTarget, + // Specific stuff for Okta + pub okta_private_jwk: Option, + // The client ID of the directory sync app specifically + pub okta_dirsync_client_id: Option, } impl OpenIdProvider { @@ -124,6 +128,8 @@ impl OpenIdProvider { directory_sync_user_behavior: DirectorySyncUserBehavior, directory_sync_admin_behavior: DirectorySyncUserBehavior, directory_sync_target: DirectorySyncTarget, + okta_private_jwk: Option, + okta_dirsync_client_id: Option, ) -> Self { Self { id: NoId, @@ -140,6 +146,8 @@ impl OpenIdProvider { directory_sync_user_behavior, directory_sync_admin_behavior, directory_sync_target, + okta_private_jwk, + okta_dirsync_client_id, } } @@ -149,9 +157,10 @@ impl OpenIdProvider { "UPDATE openidprovider SET name = $1, \ base_url = $2, client_id = $3, client_secret = $4, \ display_name = $5, google_service_account_key = $6, google_service_account_email = $7, admin_email = $8, \ - directory_sync_enabled = $9, directory_sync_interval = $10, directory_sync_user_behavior = $11, directory_sync_admin_behavior = $12, \ - directory_sync_target = $13 \ - WHERE id = $14", + directory_sync_enabled = $9, directory_sync_interval = $10, directory_sync_user_behavior = $11, \ + directory_sync_admin_behavior = $12, directory_sync_target = $13, \ + okta_private_jwk = $14, okta_dirsync_client_id = $15 \ + WHERE id = $16", self.name, self.base_url, self.client_id, @@ -165,6 +174,8 @@ impl OpenIdProvider { self.directory_sync_user_behavior as DirectorySyncUserBehavior, self.directory_sync_admin_behavior as DirectorySyncUserBehavior, self.directory_sync_target as DirectorySyncTarget, + self.okta_private_jwk, + self.okta_dirsync_client_id, provider.id, ) .execute(pool) @@ -185,7 +196,8 @@ impl OpenIdProvider { google_service_account_key, google_service_account_email, admin_email, directory_sync_enabled, directory_sync_interval, directory_sync_user_behavior \"directory_sync_user_behavior: DirectorySyncUserBehavior\", \ directory_sync_admin_behavior \"directory_sync_admin_behavior: DirectorySyncUserBehavior\", \ - directory_sync_target \"directory_sync_target: DirectorySyncTarget\" \ + directory_sync_target \"directory_sync_target: DirectorySyncTarget\", \ + okta_private_jwk, okta_dirsync_client_id \ FROM openidprovider WHERE name = $1", name ) @@ -200,7 +212,8 @@ impl OpenIdProvider { google_service_account_key, google_service_account_email, admin_email, directory_sync_enabled, \ directory_sync_interval, directory_sync_user_behavior \"directory_sync_user_behavior: DirectorySyncUserBehavior\", \ directory_sync_admin_behavior \"directory_sync_admin_behavior: DirectorySyncUserBehavior\", \ - directory_sync_target \"directory_sync_target: DirectorySyncTarget\" \ + directory_sync_target \"directory_sync_target: DirectorySyncTarget\", \ + okta_private_jwk, okta_dirsync_client_id \ FROM openidprovider LIMIT 1" ) .fetch_optional(pool) diff --git a/src/enterprise/directory_sync/google.rs b/src/enterprise/directory_sync/google.rs index b7b226740..b44e61621 100644 --- a/src/enterprise/directory_sync/google.rs +++ b/src/enterprise/directory_sync/google.rs @@ -1,21 +1,22 @@ -use std::{str::FromStr, time::Duration}; +use std::collections::HashMap; use chrono::{DateTime, TimeDelta, Utc}; -#[cfg(not(test))] use jsonwebtoken::{encode, Algorithm, EncodingKey, Header}; -use reqwest::{header::AUTHORIZATION, Url}; +use tokio::time::sleep; -use super::{parse_response, DirectoryGroup, DirectorySync, DirectorySyncError, DirectoryUser}; +use super::{ + make_get_request, parse_response, DirectoryGroup, DirectorySync, DirectorySyncError, + DirectoryUser, REQUEST_PAGINATION_SLOWDOWN, REQUEST_TIMEOUT, +}; -#[cfg(not(test))] const SCOPES: &str = "openid email profile https://www.googleapis.com/auth/admin.directory.customer.readonly https://www.googleapis.com/auth/admin.directory.group.readonly https://www.googleapis.com/auth/admin.directory.user.readonly"; const ACCESS_TOKEN_URL: &str = "https://oauth2.googleapis.com/token"; const GROUPS_URL: &str = "https://admin.googleapis.com/admin/directory/v1/groups"; const GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:jwt-bearer"; -#[cfg(not(test))] const AUD: &str = "https://oauth2.googleapis.com/token"; const ALL_USERS_URL: &str = "https://admin.googleapis.com/admin/directory/v1/users"; -const REQUEST_TIMEOUT: Duration = Duration::from_secs(10); +const MAX_REQUESTS: usize = 50; +const MAX_RESULTS: &str = "200"; #[derive(Debug, Serialize, Deserialize)] struct Claims { @@ -27,10 +28,8 @@ struct Claims { iat: i64, } -#[cfg(not(test))] impl Claims { #[must_use] - #[cfg(not(test))] fn new(iss: &str, sub: &str) -> Self { let now = Utc::now(); let now_timestamp = now.timestamp(); @@ -60,9 +59,7 @@ pub(crate) struct GoogleDirectorySync { admin_email: String, } -/// /// Google Directory API responses -/// #[derive(Debug, Serialize, Deserialize)] struct AccessTokenResponse { @@ -77,9 +74,24 @@ struct GroupMember { status: Option, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Default)] struct GroupMembersResponse { members: Option>, + #[serde(rename = "nextPageToken")] + page_token: Option, +} + +impl From for Vec { + fn from(val: GroupMembersResponse) -> Self { + val.members + .unwrap_or_default() + .into_iter() + // There may be arbitrary members in the group, we want only one that are also directory members + // Members without a status field don't belong to the directory + .filter(|m| m.status.is_some()) + .map(|m| m.email) + .collect() + } } #[derive(Debug, Serialize, Deserialize)] @@ -91,21 +103,31 @@ struct User { impl From for DirectoryUser { fn from(val: User) -> Self { - DirectoryUser { + Self { email: val.primary_email, active: !val.suspended, } } } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Default)] struct UsersResponse { users: Vec, + #[serde(rename = "nextPageToken")] + page_token: Option, } -#[derive(Debug, Serialize, Deserialize)] +impl From for Vec { + fn from(val: UsersResponse) -> Self { + val.users.into_iter().map(Into::into).collect() + } +} + +#[derive(Debug, Serialize, Deserialize, Default)] struct GroupsResponse { groups: Vec, + #[serde(rename = "nextPageToken")] + page_token: Option, } impl GoogleDirectorySync { @@ -141,26 +163,21 @@ impl GoogleDirectorySync { .access_token .as_ref() .ok_or(DirectorySyncError::AccessTokenExpired)?; - let mut url = Url::from_str(ALL_USERS_URL).unwrap(); - url.query_pairs_mut() - .append_pair("customer", "my_customer") - .append_pair("maxResults", "1") - .append_pair("showDeleted", "false"); - let client = reqwest::Client::builder().build()?; - let result = client - .get(url) - .header(AUTHORIZATION, format!("Bearer {access_token}")) - .timeout(REQUEST_TIMEOUT) - .send() - .await?; + let response = make_get_request( + ALL_USERS_URL, + access_token, + Some(&[ + ("customer", "my_customer"), + ("maxResults", MAX_RESULTS), + ("showDeleted", "false"), + ]), + ) + .await?; let _result: UsersResponse = - parse_response(result, "Failed to test connection to Google API.").await?; + parse_response(response, "Failed to test connection to Google API.").await?; Ok(()) } -} -#[cfg(not(test))] -impl GoogleDirectorySync { async fn query_user_groups(&self, user_id: &str) -> Result { if self.is_token_expired() { return Err(DirectorySyncError::AccessTokenExpired); @@ -169,18 +186,45 @@ impl GoogleDirectorySync { .access_token .as_ref() .ok_or(DirectorySyncError::AccessTokenExpired)?; - let mut url = Url::from_str(GROUPS_URL).unwrap(); - url.query_pairs_mut() - .append_pair("userKey", user_id) - .append_pair("maxResults", "500"); - let client = reqwest::Client::new(); - let response = client - .get(url) - .header(AUTHORIZATION, format!("Bearer {access_token}")) - .timeout(REQUEST_TIMEOUT) - .send() + let mut combined_response = GroupsResponse::default(); + let mut query = HashMap::from([ + ("userKey".to_string(), user_id.to_string()), + ("maxResults".to_string(), MAX_RESULTS.to_string()), + ]); + + for _ in 0..MAX_REQUESTS { + let response = make_get_request( + GROUPS_URL, + access_token, + Some( + &query + .iter() + .map(|(k, v)| (k.as_str(), v.as_str())) + .collect::>(), + ), + ) .await?; - parse_response(response, "Failed to query user groups from Google API.").await + let response: GroupsResponse = + parse_response(response, "Failed to query user groups from Google API.").await?; + + if combined_response.groups.is_empty() { + combined_response.groups = response.groups; + } else { + combined_response.groups.extend(response.groups); + } + + if let Some(next_page_token) = response.page_token { + debug!("Found next page of results, using the following token to query it: {next_page_token}"); + query.insert("pageToken".to_string(), next_page_token); + } else { + debug!("No more pages of results found, finishing query."); + break; + } + + sleep(REQUEST_PAGINATION_SLOWDOWN).await; + } + + Ok(combined_response) } async fn query_groups(&self) -> Result { @@ -192,20 +236,45 @@ impl GoogleDirectorySync { .access_token .as_ref() .ok_or(DirectorySyncError::AccessTokenExpired)?; - let mut url = Url::from_str(GROUPS_URL).unwrap(); - - url.query_pairs_mut() - .append_pair("customer", "my_customer") - .append_pair("maxResults", "500"); - - let client = reqwest::Client::builder().build()?; - let response = client - .get(url) - .header(AUTHORIZATION, format!("Bearer {access_token}")) - .timeout(REQUEST_TIMEOUT) - .send() + let mut combined_response = GroupsResponse::default(); + let mut query = HashMap::from([ + ("customer".to_string(), "my_customer".to_string()), + ("maxResults".to_string(), MAX_RESULTS.to_string()), + ]); + + for _ in 0..MAX_REQUESTS { + let response = make_get_request( + GROUPS_URL, + access_token, + Some( + &query + .iter() + .map(|(k, v)| (k.as_str(), v.as_str())) + .collect::>(), + ), + ) .await?; - parse_response(response, "Failed to query groups from Google API.").await + let response: GroupsResponse = + parse_response(response, "Failed to query groups from Google API.").await?; + + if combined_response.groups.is_empty() { + combined_response.groups = response.groups; + } else { + combined_response.groups.extend(response.groups); + } + + if let Some(next_page_token) = response.page_token { + debug!("Found next page of results, using the following token to query it: {next_page_token}"); + query.insert("pageToken".to_string(), next_page_token); + } else { + debug!("No more pages of results found, finishing query."); + break; + } + + sleep(REQUEST_PAGINATION_SLOWDOWN).await; + } + + Ok(combined_response) } async fn query_group_members( @@ -220,26 +289,54 @@ impl GoogleDirectorySync { .as_ref() .ok_or(DirectorySyncError::AccessTokenExpired)?; - let url_str = format!( + let url = format!( "https://admin.googleapis.com/admin/directory/v1/groups/{}/members", group.id ); - let mut url = - Url::parse(&url_str).map_err(|err| DirectorySyncError::InvalidUrl(err.to_string()))?; - url.query_pairs_mut() - .append_pair("includeDerivedMembership", "true") - .append_pair("maxResults", "500"); - let client = reqwest::Client::builder().build()?; - let response = client - .get(url) - .header(AUTHORIZATION, format!("Bearer {access_token}")) - .timeout(REQUEST_TIMEOUT) - .send() + let mut combined_response = GroupMembersResponse::default(); + let mut query = HashMap::from([ + ("includeDerivedMembership".to_string(), "true".to_string()), + ("maxResults".to_string(), MAX_RESULTS.to_string()), + ]); + + for _ in 0..MAX_REQUESTS { + let response = make_get_request( + &url, + access_token, + Some( + &query + .iter() + .map(|(k, v)| (k.as_str(), v.as_str())) + .collect::>(), + ), + ) .await?; - parse_response(response, "Failed to query group members from Google API.").await + let response: GroupMembersResponse = + parse_response(response, "Failed to query group members from Google API.").await?; + + if combined_response.members.is_none() { + combined_response.members = response.members; + } else { + combined_response.members = combined_response.members.map(|mut members| { + members.extend(response.members.unwrap_or_default()); + members + }); + } + + if let Some(next_page_token) = response.page_token { + debug!("Found next page of results, using the following token to query it: {next_page_token}"); + query.insert("pageToken".to_string(), next_page_token); + } else { + debug!("No more pages of results found, finishing query."); + break; + } + + sleep(REQUEST_PAGINATION_SLOWDOWN).await; + } + + Ok(combined_response) } - #[cfg(not(test))] fn build_token(&self) -> Result { let claims = Claims::new(&self.service_account_config.client_email, &self.admin_email); let key = EncodingKey::from_rsa_pem(self.service_account_config.private_key.as_bytes())?; @@ -249,14 +346,12 @@ impl GoogleDirectorySync { async fn query_access_token(&self) -> Result { let token = self.build_token()?; - let mut url = Url::parse(ACCESS_TOKEN_URL).unwrap(); - url.query_pairs_mut() - .append_pair("grant_type", GRANT_TYPE) - .append_pair("assertion", &token); - let client = reqwest::Client::builder().build()?; + let client = reqwest::Client::new(); let response = client - .post(url) + .post(ACCESS_TOKEN_URL) + .query(&[("grant_type", GRANT_TYPE), ("assertion", &token)]) .header(reqwest::header::CONTENT_LENGTH, 0) + .timeout(REQUEST_TIMEOUT) .send() .await?; parse_response(response, "Failed to get access token from Google API.").await @@ -270,19 +365,46 @@ impl GoogleDirectorySync { .access_token .as_ref() .ok_or(DirectorySyncError::AccessTokenExpired)?; - let mut url = Url::from_str(ALL_USERS_URL).unwrap(); - url.query_pairs_mut() - .append_pair("customer", "my_customer") - .append_pair("maxResults", "500") - .append_pair("showDeleted", "false"); - let client = reqwest::Client::builder().build()?; - let response = client - .get(url) - .header(AUTHORIZATION, format!("Bearer {access_token}")) - .timeout(REQUEST_TIMEOUT) - .send() + let mut combined_response = UsersResponse::default(); + let mut query = HashMap::from([ + ("customer".to_string(), "my_customer".to_string()), + ("maxResults".to_string(), MAX_RESULTS.to_string()), + ("showDeleted".to_string(), "false".to_string()), + ]); + + for _ in 0..MAX_REQUESTS { + let response = make_get_request( + ALL_USERS_URL, + access_token, + Some( + &query + .iter() + .map(|(k, v)| (k.as_str(), v.as_str())) + .collect::>(), + ), + ) .await?; - parse_response(response, "Failed to query all users in the Google API.").await + let response: UsersResponse = + parse_response(response, "Failed to query all users in the Google API.").await?; + + if combined_response.users.is_empty() { + combined_response.users = response.users; + } else { + combined_response.users.extend(response.users); + } + + if let Some(next_page_token) = response.page_token { + debug!("Found next page of results, using the following token to query it: {next_page_token}"); + query.insert("pageToken".to_string(), next_page_token); + } else { + debug!("No more pages of results found, finishing query."); + break; + } + + sleep(REQUEST_PAGINATION_SLOWDOWN).await; + } + + Ok(combined_response) } } @@ -310,16 +432,11 @@ impl DirectorySync for GoogleDirectorySync { ) -> Result, DirectorySyncError> { debug!("Getting group members of group {}", group.name); let response = self.query_group_members(group).await?; - debug!("Got group members response for group {}", group.name); - Ok(response - .members - .unwrap_or_default() - .into_iter() - // There may be arbitrary members in the group, we want only one that are also directory members - // Members without a status field don't belong to the directory - .filter(|m| m.status.is_some()) - .map(|m| m.email) - .collect()) + debug!( + "Got group members response for group {}. Extracting their email addresses...", + group.name + ); + Ok(response.into()) } async fn prepare(&mut self) -> Result<(), DirectorySyncError> { @@ -339,7 +456,7 @@ impl DirectorySync for GoogleDirectorySync { debug!("Getting all users"); let response = self.query_all_users().await?; debug!("Got all users response"); - Ok(response.users.into_iter().map(Into::into).collect()) + Ok(response.into()) } async fn test_connection(&self) -> Result<(), DirectorySyncError> { @@ -351,222 +468,84 @@ impl DirectorySync for GoogleDirectorySync { } #[cfg(test)] -impl GoogleDirectorySync { - async fn query_user_groups(&self, user_id: &str) -> Result { - if self.is_token_expired() { - return Err(DirectorySyncError::AccessTokenExpired); - } - - let _access_token = self - .access_token - .as_ref() - .ok_or(DirectorySyncError::AccessTokenExpired)?; - let mut url = Url::from_str(GROUPS_URL).expect("Invalid USER_GROUPS_URL has been set."); - - url.query_pairs_mut() - .append_pair("userKey", user_id) - .append_pair("max_results", "999"); - - Ok(GroupsResponse { - groups: vec![DirectoryGroup { - id: "1".into(), - name: "group1".into(), - }], - }) - } +mod tests { + use super::*; - async fn query_groups(&self) -> Result { - if self.is_token_expired() { - return Err(DirectorySyncError::AccessTokenExpired); - } + #[tokio::test] + async fn test_token() { + let mut dirsync = GoogleDirectorySync::new("private_key", "client_email", "admin_email"); - let _access_token = self - .access_token - .as_ref() - .ok_or(DirectorySyncError::AccessTokenExpired)?; - let mut url = Url::from_str(GROUPS_URL).expect("Invalid USER_GROUPS_URL has been set."); + // no token + assert!(dirsync.is_token_expired()); - url.query_pairs_mut() - .append_pair("customer", "my_customer") - .append_pair("max_results", "999"); + // expired token + dirsync.access_token = Some("test_token".into()); + dirsync.token_expiry = Some(Utc::now() - TimeDelta::seconds(10000)); + assert!(dirsync.is_token_expired()); - Ok(GroupsResponse { - groups: vec![ - DirectoryGroup { - id: "1".into(), - name: "group1".into(), - }, - DirectoryGroup { - id: "2".into(), - name: "group2".into(), - }, - DirectoryGroup { - id: "3".into(), - name: "group3".into(), - }, - ], - }) + // valid token + dirsync.access_token = Some("test_token".into()); + dirsync.token_expiry = Some(Utc::now() + TimeDelta::seconds(10000)); + assert!(!dirsync.is_token_expired()); } - async fn query_group_members( - &self, - group: &DirectoryGroup, - ) -> Result { - if self.is_token_expired() { - return Err(DirectorySyncError::AccessTokenExpired); - } - - let _access_token = self - .access_token - .as_ref() - .ok_or(DirectorySyncError::AccessTokenExpired)?; - - let url_str = format!( - "https://admin.googleapis.com/admin/directory/v1/groups/{}/members", - group.id - ); - let mut url = Url::from_str(&url_str).expect("Invalid GROUP_MEMBERS_URL has been set."); - url.query_pairs_mut() - .append_pair("includeDerivedMembership", "true"); - - Ok(GroupMembersResponse { + #[tokio::test] + async fn test_group_members_parse() { + let response = GroupMembersResponse { members: Some(vec![ GroupMember { - email: "testuser@email.com".into(), - status: Some("ACTIVE".into()), + email: "email@email.com".into(), + status: Some("active".into()), + }, + GroupMember { + email: "email2@email.com".into(), + status: Some("active".into()), }, GroupMember { - email: "testuserdisabled@email.com".into(), - status: Some("SUSPENDED".into()), + email: "email3@email.com".into(), + status: Some("suspended".into()), }, GroupMember { - email: "testuser2@email.com".into(), - status: Some("ACTIVE".into()), + email: "email4@email.com".into(), + status: None, }, ]), - }) - } + page_token: None, + }; - async fn query_access_token(&self) -> Result { - let mut url: Url = ACCESS_TOKEN_URL - .parse() - .expect("Invalid ACCESS_TOKEN_URL has been set."); - url.query_pairs_mut() - .append_pair("grant_type", GRANT_TYPE) - .append_pair("assertion", "test_assertion"); - Ok(AccessTokenResponse { - token: "test_token_refreshed".into(), - expires_in: 3600, - }) + let members: Vec = response.into(); + assert_eq!(members.len(), 3); + assert!(members.contains(&"email@email.com".into())); + assert!(members.contains(&"email2@email.com".into())); + assert!(members.contains(&"email3@email.com".into())); } - async fn query_all_users(&self) -> Result { - if self.is_token_expired() { - return Err(DirectorySyncError::AccessTokenExpired); - } - let _access_token = self - .access_token - .as_ref() - .ok_or(DirectorySyncError::AccessTokenExpired)?; - let mut url = Url::from_str("https://admin.googleapis.com/admin/directory/v1/users") - .expect("Invalid USERS_URL has been set."); - url.query_pairs_mut().append_pair("customer", "my_customer"); - - Ok(UsersResponse { + #[tokio::test] + async fn test_all_users_parse() { + let response = UsersResponse { users: vec![ User { - primary_email: "testuser@email.com".into(), + primary_email: "email@email.com".into(), suspended: false, }, User { - primary_email: "testuserdisabled@email.com".into(), + primary_email: "email2@email.com".into(), suspended: true, }, User { - primary_email: "testuser2@email.com".into(), + primary_email: "email3@email.com".into(), suspended: false, }, ], - }) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn test_token() { - let mut dirsync = GoogleDirectorySync::new("private_key", "client_email", "admin_email"); - - // no token - assert!(dirsync.is_token_expired()); - - // expired token - dirsync.access_token = Some("test_token".into()); - dirsync.token_expiry = Some(Utc::now() - TimeDelta::seconds(10000)); - assert!(dirsync.is_token_expired()); - - // valid token - dirsync.access_token = Some("test_token".into()); - dirsync.token_expiry = Some(Utc::now() + TimeDelta::seconds(10000)); - assert!(!dirsync.is_token_expired()); - - // no token - dirsync.access_token = Some("test_token".into()); - dirsync.token_expiry = Some(Utc::now() - TimeDelta::seconds(10000)); - dirsync.refresh_access_token().await.unwrap(); - assert!(!dirsync.is_token_expired()); - assert_eq!(dirsync.access_token, Some("test_token_refreshed".into())); - } - - #[tokio::test] - async fn test_all_users() { - let mut dirsync = GoogleDirectorySync::new("private_key", "client_email", "admin_email"); - dirsync.refresh_access_token().await.unwrap(); - - let users = dirsync.get_all_users().await.unwrap(); + page_token: None, + }; + let users: Vec = response.into(); assert_eq!(users.len(), 3); - assert_eq!(users[1].email, "testuserdisabled@email.com"); - assert!(!users[1].active); - } - - #[tokio::test] - async fn test_groups() { - let mut dirsync = GoogleDirectorySync::new("private_key", "client_email", "admin_email"); - dirsync.refresh_access_token().await.unwrap(); - - let groups = dirsync.get_groups().await.unwrap(); - - assert_eq!(groups.len(), 3); - - for (i, group) in groups.iter().enumerate().take(3) { - assert_eq!(group.id, (i + 1).to_string()); - assert_eq!(group.name, format!("group{}", i + 1)); - } - } - - #[tokio::test] - async fn test_user_groups() { - let mut dirsync = GoogleDirectorySync::new("private_key", "client_email", "admin_email"); - dirsync.refresh_access_token().await.unwrap(); - - let groups = dirsync.get_user_groups("testuser").await.unwrap(); - assert_eq!(groups.len(), 1); - assert_eq!(groups[0].id, "1"); - assert_eq!(groups[0].name, "group1"); - } - - #[tokio::test] - async fn test_group_members() { - let mut dirsync = GoogleDirectorySync::new("private_key", "client_email", "admin_email"); - dirsync.refresh_access_token().await.unwrap(); - - let groups = dirsync.get_groups().await.unwrap(); - let members = dirsync.get_group_members(&groups[0]).await.unwrap(); - - assert_eq!(members.len(), 3); - assert_eq!(members[0], "testuser@email.com"); + let disabled_user = users + .iter() + .find(|u| u.email == "email2@email.com") + .unwrap(); + assert!(!disabled_user.active); } } diff --git a/src/enterprise/directory_sync/microsoft.rs b/src/enterprise/directory_sync/microsoft.rs index c9c2036e2..13f4776a7 100644 --- a/src/enterprise/directory_sync/microsoft.rs +++ b/src/enterprise/directory_sync/microsoft.rs @@ -1,10 +1,12 @@ -use std::time::Duration; - use chrono::{TimeDelta, Utc}; -use reqwest::{header::AUTHORIZATION, Url}; use serde::Deserialize; +use tokio::time::sleep; -use super::{parse_response, DirectoryGroup, DirectorySync, DirectorySyncError, DirectoryUser}; +use super::{ + make_get_request, parse_response, DirectoryGroup, DirectorySync, DirectorySyncError, + DirectoryUser, REQUEST_PAGINATION_SLOWDOWN, +}; +use crate::enterprise::directory_sync::REQUEST_TIMEOUT; #[allow(dead_code)] pub(crate) struct MicrosoftDirectorySync { @@ -15,21 +17,16 @@ pub(crate) struct MicrosoftDirectorySync { url: String, } -#[cfg(not(test))] const ACCESS_TOKEN_URL: &str = "https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/token"; -#[cfg(not(test))] -const GROUPS_URL: &str = "https://graph.microsoft.com/v1.0/groups?$top=999"; -#[cfg(not(test))] -const USER_GROUPS: &str = "https://graph.microsoft.com/v1.0/users/{user_id}/memberOf?$top=999"; -#[cfg(not(test))] -const GROUP_MEMBERS: &str = "https://graph.microsoft.com/v1.0/groups/{group_id}/members?$select=accountEnabled,displayName,mail,otherMails&$top=999"; -const REQUEST_TIMEOUT: Duration = Duration::from_secs(10); -const ALL_USERS_URL: &str = - "https://graph.microsoft.com/v1.0/users?$select=accountEnabled,displayName,mail,otherMails&$top=999"; -#[cfg(not(test))] +const GROUPS_URL: &str = "https://graph.microsoft.com/v1.0/groups"; +const USER_GROUPS: &str = "https://graph.microsoft.com/v1.0/users/{user_id}/memberOf"; +const GROUP_MEMBERS: &str = "https://graph.microsoft.com/v1.0/groups/{group_id}/members"; +const ALL_USERS_URL: &str = "https://graph.microsoft.com/v1.0/users"; const MICROSOFT_DEFAULT_SCOPE: &str = "https://graph.microsoft.com/.default"; -#[cfg(not(test))] const GRANT_TYPE: &str = "client_credentials"; +const MAX_RESULTS: &str = "200"; +const MAX_REQUESTS: usize = 50; +const USER_QUERY_FIELDS: &str = "accountEnabled,displayName,mail,otherMails"; #[derive(Deserialize)] struct TokenResponse { @@ -45,16 +42,53 @@ struct GroupDetails { id: String, } -#[derive(Deserialize)] +#[derive(Deserialize, Default)] struct GroupsResponse { + #[serde(rename = "@odata.nextLink")] + next_page: Option, value: Vec, } -#[derive(Debug, Serialize, Deserialize)] +impl From for Vec { + fn from(response: GroupsResponse) -> Self { + response + .value + .into_iter() + .map(|group| DirectoryGroup { + id: group.id, + name: group.display_name, + }) + .collect() + } +} + +#[derive(Debug, Serialize, Deserialize, Default)] struct GroupMembersResponse { + #[serde(rename = "@odata.nextLink")] + next_page: Option, value: Vec, } +impl From for Vec { + fn from(response: GroupMembersResponse) -> Self { + response + .value + .into_iter() + .filter_map(|user| { + if let Some(email) = user.mail { + Some(email) + } else if let Some(email) = user.other_mails.into_iter().next() { + warn!("User {} doesn't have a primary email address set, his first additional email address will be used: {email}", user.display_name); + Some(email) + } else { + warn!("User {} doesn't have any email address and will be skipped in synchronization.", user.display_name); + None + } + }) + .collect() + } +} + #[derive(Debug, Serialize, Deserialize)] struct User { #[serde(rename = "displayName")] @@ -66,27 +100,98 @@ struct User { other_mails: Vec, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Default)] struct UsersResponse { + #[serde(rename = "@odata.nextLink")] + next_page: Option, value: Vec, } -async fn make_get_request( - url: Url, - token: String, -) -> Result { - let client = reqwest::Client::new(); - let response = client - .get(url) - .header(AUTHORIZATION, format!("Bearer {token}")) - .timeout(REQUEST_TIMEOUT) - .send() - .await?; - Ok(response) +impl From for Vec { + fn from(response: UsersResponse) -> Self { + response + .value + .into_iter() + .filter_map(|user| { + if let Some(email) = user.mail { + Some(DirectoryUser { email, active: user.account_enabled }) + } else if let Some(email) = user.other_mails.into_iter().next() { + warn!("User {} doesn't have a primary email address set, his first additional email address will be used: {email}", user.display_name); + Some(DirectoryUser { email, active: user.account_enabled }) + } else { + warn!("User {} doesn't have any email address and will be skipped in synchronization.", user.display_name); + None + } + }) + .collect() + } } -#[cfg(not(test))] impl MicrosoftDirectorySync { + pub(crate) const fn new(client_id: String, client_secret: String, url: String) -> Self { + Self { + access_token: None, + client_id, + client_secret, + url, + token_expiry: None, + } + } + + fn extract_tenant(&self) -> Result { + debug!("Extracting tenant ID from Microsoft base URL: {}", self.url); + let parts: Vec<&str> = self.url.split('/').collect(); + debug!("Split Microsoft base URL into the following parts: {parts:?}",); + let tenant_id = + parts + .get(parts.len() - 2) + .ok_or(DirectorySyncError::InvalidProviderConfiguration(format!( + "Couldn't extract tenant ID from the provided Microsoft API base URL: {}", + self.url + )))?; + debug!("Tenant ID extracted successfully: {tenant_id}",); + Ok(tenant_id.to_string()) + } + + async fn refresh_access_token(&mut self) -> Result<(), DirectorySyncError> { + debug!("Refreshing Microsoft directory sync access token."); + let token_response = self.query_access_token().await?; + let expires_in = TimeDelta::seconds(token_response.expires_in); + self.access_token = Some(token_response.token); + self.token_expiry = Some(Utc::now() + expires_in); + debug!( + "Microsoft directory sync access token refreshed, the new token expires at: {:?}", + self.token_expiry + ); + Ok(()) + } + + fn is_token_expired(&self) -> bool { + debug!( + "Checking if Microsoft directory sync token is expired, expiry date: {:?}", + self.token_expiry + ); + let result = self.token_expiry.map_or(true, |expiry| expiry < Utc::now()); + debug!("Token expiry check result: {result}"); + result + } + + async fn query_test_connection(&self) -> Result<(), DirectorySyncError> { + let access_token = self + .access_token + .as_ref() + .ok_or(DirectorySyncError::AccessTokenExpired)?; + let response = make_get_request( + ALL_USERS_URL, + access_token, + Some(&[("$top", "1"), ("$select", USER_QUERY_FIELDS)]), + ) + .await?; + let _result: UsersResponse = + parse_response(response, "Failed to test connection to Microsoft API.").await?; + Ok(()) + } + async fn query_access_token(&self) -> Result { debug!("Querying Microsoft directory sync access token."); let tenant_id = self.extract_tenant()?; @@ -100,6 +205,7 @@ impl MicrosoftDirectorySync { ("scope", &MICROSOFT_DEFAULT_SCOPE.to_string()), ("grant_type", &GRANT_TYPE.to_string()), ]) + .timeout(REQUEST_TIMEOUT) .send() .await?; let token_response: TokenResponse = response.json().await?; @@ -116,10 +222,29 @@ impl MicrosoftDirectorySync { .access_token .as_ref() .ok_or(DirectorySyncError::AccessTokenExpired)?; - let url = Url::parse(GROUPS_URL) - .map_err(|err| DirectorySyncError::InvalidUrl(err.to_string()))?; - let response = make_get_request(url, access_token.to_string()).await?; - parse_response(response, "Failed to query all Microsoft groups.").await + let mut combined_response = GroupsResponse::default(); + let mut url = GROUPS_URL.to_string(); + let mut query = Some([("$top", MAX_RESULTS)].as_slice()); + + for _ in 0..MAX_REQUESTS { + let response = make_get_request(&url, access_token, query).await?; + let response: GroupsResponse = + parse_response(response, "Failed to query Microsoft groups.").await?; + combined_response.value.extend(response.value); + + if let Some(next_page) = response.next_page { + url = next_page; + query = None; + debug!("Found next page of results, querying it: {url}"); + } else { + debug!("No more pages of results found, finishing query."); + break; + } + + sleep(REQUEST_PAGINATION_SLOWDOWN).await; + } + + Ok(combined_response) } async fn query_user_groups(&self, user_id: &str) -> Result { @@ -133,10 +258,29 @@ impl MicrosoftDirectorySync { .access_token .as_ref() .ok_or(DirectorySyncError::AccessTokenExpired)?; - let url = Url::parse(&USER_GROUPS.replace("{user_id}", user_id)) - .map_err(|err| DirectorySyncError::InvalidUrl(err.to_string()))?; - let response = make_get_request(url, access_token.to_string()).await?; - parse_response(response, "Failed to query user groups from Microsoft API.").await + let mut url = USER_GROUPS.replace("{user_id}", user_id); + let mut combined_response = GroupsResponse::default(); + let mut query = Some([("$top", MAX_RESULTS)].as_slice()); + + for _ in 0..MAX_REQUESTS { + let response = make_get_request(&url, access_token, query).await?; + let response: GroupsResponse = + parse_response(response, "Failed to query user groups from Microsoft API.").await?; + combined_response.value.extend(response.value); + + if let Some(next_page) = response.next_page { + url = next_page; + query = None; + debug!("Found next page of results, querying it: {url}"); + } else { + debug!("No more pages of results found, finishing query."); + break; + } + + sleep(REQUEST_PAGINATION_SLOWDOWN).await; + } + + Ok(combined_response) } async fn query_group_members( @@ -153,15 +297,32 @@ impl MicrosoftDirectorySync { .access_token .as_ref() .ok_or(DirectorySyncError::AccessTokenExpired)?; + let mut combined_response = GroupMembersResponse::default(); + let mut url = GROUP_MEMBERS.replace("{group_id}", &group.id); + let mut query = Some([("$top", MAX_RESULTS), ("$select", USER_QUERY_FIELDS)].as_slice()); + + for _ in 0..MAX_REQUESTS { + let response = make_get_request(&url, access_token, query).await?; + let response: GroupMembersResponse = parse_response( + response, + "Failed to query group members from Microsoft API.", + ) + .await?; + combined_response.value.extend(response.value); + + if let Some(next_page) = response.next_page { + url = next_page; + query = None; + debug!("Found next page of results, querying it: {url}"); + } else { + debug!("No more pages of results found, finishing query."); + break; + } + + sleep(REQUEST_PAGINATION_SLOWDOWN).await; + } - let url = Url::parse(&GROUP_MEMBERS.replace("{group_id}", &group.id)) - .map_err(|err| DirectorySyncError::InvalidUrl(err.to_string()))?; - let response = make_get_request(url, access_token.to_string()).await?; - parse_response( - response, - "Failed to query group members from Microsoft API.", - ) - .await + Ok(combined_response) } async fn query_all_users(&self) -> Result { @@ -173,90 +334,38 @@ impl MicrosoftDirectorySync { .access_token .as_ref() .ok_or(DirectorySyncError::AccessTokenExpired)?; - let url = Url::parse(ALL_USERS_URL) - .map_err(|err| DirectorySyncError::InvalidUrl(err.to_string()))?; - let response = make_get_request(url, access_token.to_string()).await?; - parse_response(response, "Failed to query all users in the Microsoft API.").await - } -} - -impl MicrosoftDirectorySync { - pub(crate) const fn new(client_id: String, client_secret: String, url: String) -> Self { - Self { - access_token: None, - client_id, - client_secret, - url, - token_expiry: None, + let mut combined_response = UsersResponse::default(); + let mut url = ALL_USERS_URL.to_string(); + let mut query = Some([("$top", MAX_RESULTS), ("$select", USER_QUERY_FIELDS)].as_slice()); + + for _ in 0..MAX_REQUESTS { + let response = make_get_request(&url, access_token, query).await?; + let response: UsersResponse = + parse_response(response, "Failed to query all users in the Microsoft API.").await?; + combined_response.value.extend(response.value); + + if let Some(next_page) = response.next_page { + url = next_page; + query = None; + debug!("Found next page of results, querying it: {url}"); + } else { + debug!("No more pages of results found, finishing query."); + break; + } + + sleep(REQUEST_PAGINATION_SLOWDOWN).await; } - } - fn extract_tenant(&self) -> Result { - debug!("Extracting tenant ID from Microsoft base URL: {}", self.url); - let parts: Vec<&str> = self.url.split('/').collect(); - debug!("Split Microsoft base URL into the following parts: {parts:?}",); - let tenant_id = - parts - .get(parts.len() - 2) - .ok_or(DirectorySyncError::InvalidProviderConfiguration(format!( - "Couldn't extract tenant ID from the provided Microsoft API base URL: {}", - self.url - )))?; - debug!("Tenant ID extracted successfully: {tenant_id}",); - Ok(tenant_id.to_string()) - } - - async fn refresh_access_token(&mut self) -> Result<(), DirectorySyncError> { - debug!("Refreshing Microsoft directory sync access token."); - let token_response = self.query_access_token().await?; - let expires_in = TimeDelta::seconds(token_response.expires_in); - self.access_token = Some(token_response.token); - self.token_expiry = Some(Utc::now() + expires_in); - debug!( - "Microsoft directory sync access token refreshed, the new token expires at: {:?}", - self.token_expiry - ); - Ok(()) - } - - fn is_token_expired(&self) -> bool { - debug!( - "Checking if Microsoft directory sync token is expired, expiry date: {:?}", - self.token_expiry - ); - let result = self.token_expiry.map_or(true, |expiry| expiry < Utc::now()); - debug!("Token expiry check result: {result}"); - result - } - - async fn query_test_connection(&self) -> Result<(), DirectorySyncError> { - let access_token = self - .access_token - .as_ref() - .ok_or(DirectorySyncError::AccessTokenExpired)?; - let url = Url::parse(&format!("{ALL_USERS_URL}?$top=1")) - .map_err(|err| DirectorySyncError::InvalidUrl(err.to_string()))?; - let response = make_get_request(url, access_token.to_string()).await?; - let _result: UsersResponse = - parse_response(response, "Failed to test connection to Microsoft API.").await?; - Ok(()) + Ok(combined_response) } } impl DirectorySync for MicrosoftDirectorySync { async fn get_groups(&self) -> Result, DirectorySyncError> { debug!("Querying all groups from Microsoft API."); - let groups = self - .query_groups() - .await? - .value - .into_iter() - .map(|group| DirectoryGroup { - id: group.id, - name: group.display_name, - }); + let groups = self.query_groups().await?; debug!("All groups queried successfully."); - Ok(groups.collect()) + Ok(groups.into()) } async fn get_user_groups( @@ -264,17 +373,9 @@ impl DirectorySync for MicrosoftDirectorySync { user_id: &str, ) -> Result, DirectorySyncError> { debug!("Querying groups of user: {user_id}"); - let groups = self - .query_user_groups(user_id) - .await? - .value - .into_iter() - .map(|group| DirectoryGroup { - id: group.id, - name: group.display_name, - }); + let groups = self.query_user_groups(user_id).await?; debug!("User groups queried successfully."); - Ok(groups.collect()) + Ok(groups.into()) } async fn get_group_members( @@ -282,21 +383,9 @@ impl DirectorySync for MicrosoftDirectorySync { group: &DirectoryGroup, ) -> Result, DirectorySyncError> { debug!("Querying members of group: {}", group.name); - let members = self - .query_group_members(group) - .await? - .value - .into_iter() - .filter_map(|user| { - if let Some(email) = user.mail { - Some(email) - } else { - warn!("User {} doesn't have an email address and will be skipped in synchronization.", user.display_name); - None - } - }); + let members = self.query_group_members(group).await?; debug!("Group members queried successfully."); - Ok(members.collect()) + Ok(members.into()) } async fn prepare(&mut self) -> Result<(), DirectorySyncError> { @@ -314,24 +403,9 @@ impl DirectorySync for MicrosoftDirectorySync { async fn get_all_users(&self) -> Result, DirectorySyncError> { debug!("Querying all users from Microsoft API."); - let users = self - .query_all_users() - .await? - .value - .into_iter() - .filter_map(|user| { - if let Some(email) = user.mail { - Some(DirectoryUser { email, active: user.account_enabled }) - } else if let Some(mail) = user.other_mails.first() { - warn!("User {} doesn't have a primary email address set, his first additional email address will be used: {mail}", user.display_name); - Some(DirectoryUser { email: mail.clone(), active: user.account_enabled }) - } else { - warn!("User {} doesn't have any email address and will be skipped in synchronization.", user.display_name); - None - } - }); + let users = self.query_all_users().await?; debug!("All users queried successfully."); - Ok(users.collect()) + Ok(users.into()) } async fn test_connection(&self) -> Result<(), DirectorySyncError> { @@ -342,136 +416,6 @@ impl DirectorySync for MicrosoftDirectorySync { } } -#[cfg(test)] -impl MicrosoftDirectorySync { - async fn query_user_groups( - &self, - _user_id: &str, - ) -> Result { - if self.is_token_expired() { - return Err(DirectorySyncError::AccessTokenExpired); - } - let _access_token = self - .access_token - .as_ref() - .ok_or(DirectorySyncError::AccessTokenExpired)?; - - Ok(GroupsResponse { - value: vec![GroupDetails { - display_name: "group1".into(), - id: "1".into(), - }], - }) - } - - async fn query_groups(&self) -> Result { - if self.is_token_expired() { - return Err(DirectorySyncError::AccessTokenExpired); - } - - let _access_token = self - .access_token - .as_ref() - .ok_or(DirectorySyncError::AccessTokenExpired)?; - - Ok(GroupsResponse { - value: vec![ - GroupDetails { - display_name: "group1".into(), - id: "1".into(), - }, - GroupDetails { - display_name: "group2".into(), - id: "2".into(), - }, - GroupDetails { - display_name: "group3".into(), - id: "3".into(), - }, - ], - }) - } - - async fn query_group_members( - &self, - _group: &DirectoryGroup, - ) -> Result { - if self.is_token_expired() { - return Err(DirectorySyncError::AccessTokenExpired); - } - let _access_token = self - .access_token - .as_ref() - .ok_or(DirectorySyncError::AccessTokenExpired)?; - - Ok(GroupMembersResponse { - value: vec![ - User { - display_name: "testuser".into(), - mail: Some("testuser@email.com".into()), - account_enabled: true, - other_mails: vec![], - }, - User { - display_name: "testuserdisabled".into(), - mail: Some("testuserdisabled@email.com".into()), - account_enabled: false, - other_mails: vec![], - }, - User { - display_name: "testuser2".into(), - mail: Some( - "testuser2@email.com - " - .into(), - ), - account_enabled: true, - other_mails: vec![], - }, - ], - }) - } - - async fn query_access_token(&self) -> Result { - Ok(TokenResponse { - token: "test_token_refreshed".into(), - expires_in: 3600, - }) - } - - async fn query_all_users(&self) -> Result { - if self.is_token_expired() { - return Err(DirectorySyncError::AccessTokenExpired); - } - let _access_token = self - .access_token - .as_ref() - .ok_or(DirectorySyncError::AccessTokenExpired)?; - Ok(UsersResponse { - value: vec![ - User { - display_name: "testuser".into(), - mail: Some("testuser@email.com".into()), - account_enabled: true, - other_mails: vec![], - }, - User { - display_name: "testuserdisabled".into(), - mail: Some("testuserdisabled@email.com".into()), - account_enabled: false, - other_mails: vec![], - }, - User { - display_name: "testuser2".into(), - mail: Some("testuser2@email.com".into()), - account_enabled: true, - other_mails: vec![], - }, - ], - }) - } -} - #[cfg(test)] mod tests { use super::*; @@ -507,78 +451,94 @@ mod tests { dirsync.access_token = Some("test_token".into()); dirsync.token_expiry = Some(Utc::now() + TimeDelta::seconds(10000)); assert!(!dirsync.is_token_expired()); - - // no token - dirsync.access_token = Some("test_token".into()); - dirsync.token_expiry = Some(Utc::now() - TimeDelta::seconds(10000)); - dirsync.refresh_access_token().await.unwrap(); - assert!(!dirsync.is_token_expired()); - assert_eq!(dirsync.access_token, Some("test_token_refreshed".into())); - } - - #[tokio::test] - async fn test_all_users() { - let mut dirsync = MicrosoftDirectorySync::new( - "id".to_string(), - "secret".to_string(), - "https://login.microsoftonline.com/tenant-id-123/v2.0".to_string(), - ); - dirsync.refresh_access_token().await.unwrap(); - - let users = dirsync.get_all_users().await.unwrap(); - - assert_eq!(users.len(), 3); - assert_eq!(users[1].email, "testuserdisabled@email.com"); - assert!(!users[1].active); } #[tokio::test] - async fn test_groups() { - let mut dirsync = MicrosoftDirectorySync::new( - "id".to_string(), - "secret".to_string(), - "https://login.microsoftonline.com/tenant-id-123/v2.0".to_string(), - ); - dirsync.refresh_access_token().await.unwrap(); - - let groups = dirsync.get_groups().await.unwrap(); + async fn test_groups_parse() { + let groups_response = GroupsResponse { + next_page: None, + value: vec![ + GroupDetails { + display_name: "Group 1".to_string(), + id: "1".to_string(), + }, + GroupDetails { + display_name: "Group 2".to_string(), + id: "2".to_string(), + }, + ], + }; - assert_eq!(groups.len(), 3); + let groups: Vec = groups_response.into(); - for (i, group) in groups.iter().enumerate().take(3) { - assert_eq!(group.id, (i + 1).to_string()); - assert_eq!(group.name, format!("group{}", i + 1)); - } + assert_eq!(groups.len(), 2); + assert_eq!(groups[0].name, "Group 1"); + assert_eq!(groups[0].id, "1"); + assert_eq!(groups[1].name, "Group 2"); + assert_eq!(groups[1].id, "2"); } #[tokio::test] - async fn test_user_groups() { - let mut dirsync = MicrosoftDirectorySync::new( - "id".to_string(), - "secret".to_string(), - "https://login.microsoftonline.com/tenant-id-123/v2.0".to_string(), - ); - dirsync.refresh_access_token().await.unwrap(); + async fn test_members_parse() { + let members_response = GroupMembersResponse { + next_page: None, + value: vec![ + User { + display_name: "User 1".to_string(), + mail: Some("email@email.com".to_string()), + account_enabled: true, + other_mails: vec![], + }, + User { + display_name: "User 2".to_string(), + mail: None, + account_enabled: true, + other_mails: vec!["email2@email.com".to_string()], + }, + User { + display_name: "User 3".to_string(), + mail: None, + account_enabled: true, + other_mails: vec![], + }, + ], + }; - let groups = dirsync.get_user_groups("testuser").await.unwrap(); - assert_eq!(groups.len(), 1); - assert_eq!(groups[0].id, "1"); - assert_eq!(groups[0].name, "group1"); + let members: Vec = members_response.into(); + assert_eq!(members.len(), 2); + assert_eq!(members[0], "email@email.com".to_string()); + assert_eq!(members[1], "email2@email.com".to_string()); } #[tokio::test] - async fn test_group_members() { - let mut dirsync = MicrosoftDirectorySync::new( - "id".to_string(), - "secret".to_string(), - "https://login.microsoftonline.com/tenant-id-123/v2.0".to_string(), - ); - dirsync.refresh_access_token().await.unwrap(); - - let groups = dirsync.get_groups().await.unwrap(); - let members = dirsync.get_group_members(&groups[0]).await.unwrap(); + async fn test_users_parse() { + let users_response = UsersResponse { + next_page: None, + value: vec![ + User { + display_name: "User 1".to_string(), + mail: Some("email@email.com".to_string()), + account_enabled: true, + other_mails: vec![], + }, + User { + display_name: "User 2".to_string(), + mail: None, + account_enabled: true, + other_mails: vec!["email2@email.com".to_string()], + }, + User { + display_name: "User 3".to_string(), + mail: None, + account_enabled: true, + other_mails: vec![], + }, + ], + }; - assert_eq!(members.len(), 3); - assert_eq!(members[0], "testuser@email.com"); + let users: Vec = users_response.into(); + assert_eq!(users.len(), 2); + assert_eq!(users[0].email, "email@email.com".to_string()); + assert_eq!(users[1].email, "email2@email.com".to_string()); } } diff --git a/src/enterprise/directory_sync/mod.rs b/src/enterprise/directory_sync/mod.rs index 8e256c778..d69b21ae0 100644 --- a/src/enterprise/directory_sync/mod.rs +++ b/src/enterprise/directory_sync/mod.rs @@ -1,18 +1,25 @@ -use std::collections::{HashMap, HashSet}; +use std::{ + collections::{HashMap, HashSet}, + time::Duration, +}; use paste::paste; -use sqlx::error::Error as SqlxError; -use sqlx::PgPool; +use reqwest::header::AUTHORIZATION; +use sqlx::{error::Error as SqlxError, PgPool}; use thiserror::Error; +use tokio::sync::broadcast::Sender; use super::db::models::openid_provider::{DirectorySyncTarget, OpenIdProvider}; #[cfg(not(test))] use super::is_enterprise_enabled; use crate::{ - db::{Group, Id, User}, + db::{GatewayEvent, Group, Id, User}, enterprise::db::models::openid_provider::DirectorySyncUserBehavior, }; +const REQUEST_TIMEOUT: Duration = Duration::from_secs(10); +const REQUEST_PAGINATION_SLOWDOWN: Duration = Duration::from_millis(100); + #[derive(Debug, Error)] pub enum DirectorySyncError { #[error("Database error: {0}")] @@ -35,24 +42,31 @@ pub enum DirectorySyncError { InvalidProviderConfiguration(String), #[error("Couldn't construct URL from the given string: {0}")] InvalidUrl(String), + #[error("Failed to update network state: {0}")] + NetworkUpdateError(String), + #[error("Failed to update user state: {0}")] + UserUpdateError(String), } impl From for DirectorySyncError { fn from(err: reqwest::Error) -> Self { if err.is_decode() { - DirectorySyncError::RequestError(format!("There was an error while trying to decode provider's response, it may be malformed: {err}")) + Self::RequestError(format!("There was an error while trying to decode provider's response, it may be malformed: {err}")) } else if err.is_timeout() { - DirectorySyncError::RequestError(format!( + Self::RequestError(format!( "The request to the provider's API timed out: {err}" )) } else { - DirectorySyncError::RequestError(err.to_string()) + Self::RequestError(err.to_string()) } } } pub mod google; pub mod microsoft; +pub mod okta; +#[cfg(test)] +pub mod testprovider; #[derive(Debug, Serialize, Deserialize)] pub struct DirectoryGroup { @@ -68,6 +82,7 @@ pub struct DirectoryUser { } #[trait_variant::make(Send)] +#[trait_variant::make(Sync)] trait DirectorySync { /// Get all groups in a directory async fn get_groups(&self) -> Result, DirectorySyncError>; @@ -104,7 +119,6 @@ trait DirectorySync { /// - You implemented some way to initialize the provider client and added an initialization step in the [`DirectorySyncClient::build`] function /// - You added the provider name to the macro invocation below the macro definition /// - You've implemented your provider logic in a file called the same as your provider but lowercase, e.g. google.rs -/// // If you have time to refactor the whole thing to use boxes instead, go ahead. macro_rules! dirsync_clients { ($($variant:ident),*) => { @@ -168,7 +182,11 @@ macro_rules! dirsync_clients { }; } -dirsync_clients!(Google, Microsoft); +#[cfg(test)] +dirsync_clients!(Google, Microsoft, Okta, TestProvider); + +#[cfg(not(test))] +dirsync_clients!(Google, Microsoft, Okta); impl DirectorySyncClient { /// Builds the current directory sync client based on the current provider settings (provider name), if possible. @@ -204,6 +222,25 @@ impl DirectorySyncClient { debug!("Microsoft directory sync client created"); Ok(Self::Microsoft(client)) } + "Okta" => { + if let (Some(jwk), Some(client_id)) = ( + provider_settings.okta_private_jwk.as_ref(), + provider_settings.okta_dirsync_client_id.as_ref(), + ) { + debug!("Okta directory has all the configuration needed, proceeding with creating the sync client"); + let client = + okta::OktaDirectorySync::new(jwk, client_id, &provider_settings.base_url); + debug!("Okta directory sync client created"); + Ok(Self::Okta(client)) + } else { + Err(DirectorySyncError::InvalidProviderConfiguration( + "Okta provider is not configured correctly for Directory Sync. Okta private key or client id is missing." + .to_string(), + )) + } + } + #[cfg(test)] + "Test" => Ok(Self::TestProvider(testprovider::TestProviderDirectorySync)), _ => Err(DirectorySyncError::UnsupportedProvider( provider_settings.name.clone(), )), @@ -215,6 +252,7 @@ async fn sync_user_groups( directory_sync: &T, user: &User, pool: &PgPool, + wg_tx: &Sender, ) -> Result<(), DirectorySyncError> { info!("Syncing groups of user {} with the directory", user.email); let directory_groups = directory_sync.get_user_groups(&user.email).await?; @@ -257,6 +295,14 @@ async fn sync_user_groups( } } + user.sync_allowed_devices(&mut transaction, wg_tx) + .await + .map_err(|err| { + DirectorySyncError::NetworkUpdateError(format!( + "Failed to sync allowed devices for user {} during directory synchronization: {err}", + user.email + )) + })?; transaction.commit().await?; Ok(()) @@ -288,6 +334,7 @@ pub(crate) async fn test_directory_sync_connection( pub(crate) async fn sync_user_groups_if_configured( user: &User, pool: &PgPool, + wg_tx: &Sender, ) -> Result<(), DirectorySyncError> { #[cfg(not(test))] if !is_enterprise_enabled() { @@ -304,7 +351,7 @@ pub(crate) async fn sync_user_groups_if_configured( match DirectorySyncClient::build(pool).await { Ok(mut dir_sync) => { dir_sync.prepare().await?; - sync_user_groups(&dir_sync, user, pool).await?; + sync_user_groups(&dir_sync, user, pool, wg_tx).await?; } Err(err) => { error!("Failed to build directory sync client: {err}"); @@ -314,6 +361,7 @@ pub(crate) async fn sync_user_groups_if_configured( Ok(()) } +/// Create a group if it doesn't exist and add a user to it if they are not already a member async fn create_and_add_to_group( user: &User, group_name: &str, @@ -349,6 +397,7 @@ async fn create_and_add_to_group( async fn sync_all_users_groups( directory_sync: &T, pool: &PgPool, + wg_tx: &Sender, ) -> Result<(), DirectorySyncError> { info!("Syncing all users' groups with the directory, this may take a while..."); let directory_groups = directory_sync.get_groups().await?; @@ -435,30 +484,40 @@ async fn sync_all_users_groups( for group in groups { create_and_add_to_group(&user, group, pool).await?; } + + user.sync_allowed_devices(&mut transaction, wg_tx).await.map_err(|err| { + DirectorySyncError::NetworkUpdateError(format!( + "Failed to sync allowed devices for user {} during directory synchronization: {err}", + user.email + )) + })?; } transaction.commit().await?; - info!("Syncing all users' groups done."); Ok(()) } fn is_directory_sync_enabled(provider: Option<&OpenIdProvider>) -> bool { debug!("Checking if directory sync is enabled"); - if let Some(provider_settings) = provider { - debug!( - "Directory sync enabled state: {}", + provider.map_or_else( + || { + debug!("No openid provider found, directory sync is disabled"); + false + }, + |provider_settings| { + debug!( + "Directory sync enabled state: {}", + provider_settings.directory_sync_enabled + ); provider_settings.directory_sync_enabled - ); - provider_settings.directory_sync_enabled - } else { - debug!("No openid provider found, directory sync is disabled"); - false - } + }, + ) } async fn sync_all_users_state( directory_sync: &T, pool: &PgPool, + wg_tx: &Sender, ) -> Result<(), DirectorySyncError> { info!("Syncing all users' state with the directory, this may take a while..."); let mut transaction = pool.begin().await?; @@ -472,7 +531,6 @@ async fn sync_all_users_state( let emails = all_users .iter() - // We want to filter out the main admin user, as he shouldn't be deleted .map(|u| u.email.as_str()) .collect::>(); let missing_users = User::exclude(&mut *transaction, &emails) @@ -507,8 +565,12 @@ async fn sync_all_users_state( "Disabling user {} because they are disabled in the directory", user.email ); - user.is_active = false; - user.save(&mut *transaction).await?; + user.disable(&mut transaction, wg_tx).await.map_err(|err| { + DirectorySyncError::UserUpdateError(format!( + "Failed to disable user {} during directory synchronization: {err}", + user.email + )) + })?; } else { debug!("User {} is already disabled, skipping", user.email); } @@ -522,6 +584,7 @@ async fn sync_all_users_state( user_behavior, admin_behavior ); + // Keep the admin count to prevent deleting the last admin let mut admin_count = User::find_admins(&mut *transaction).await?.len(); for mut user in missing_users { if user.is_admin(&mut *transaction).await? { @@ -546,8 +609,12 @@ async fn sync_all_users_state( the admin behavior setting is set to disable", user.email ); - user.is_active = false; - user.save(&mut *transaction).await?; + user.disable(&mut transaction, wg_tx).await.map_err(|err| { + DirectorySyncError::UserUpdateError(format!( + "Failed to disable admin {} during directory synchronization: {err}", + user.email + )) + })?; admin_count -= 1; } else { debug!( @@ -568,7 +635,13 @@ async fn sync_all_users_state( "Deleting admin {} because they are not present in the directory", user.email ); - user.delete(&mut *transaction).await?; + user.delete_and_cleanup(&mut transaction, wg_tx) + .await + .map_err(|err| { + DirectorySyncError::UserUpdateError(format!( + "Failed to delete admin during directory synchronization: {err}" + )) + })?; admin_count -= 1; } } @@ -586,8 +659,12 @@ async fn sync_all_users_state( "Disabling user {} because they are not present in the directory and the user behavior setting is set to disable", user.email ); - user.is_active = false; - user.save(&mut *transaction).await?; + user.disable(&mut transaction, wg_tx).await.map_err(|err| { + DirectorySyncError::UserUpdateError(format!( + "Failed to disable user {} during directory synchronization: {err}", + user.email + )) + })?; } else { debug!( "User {} is already disabled in Defguard, skipping", @@ -600,7 +677,13 @@ async fn sync_all_users_state( "Deleting user {} because they are not present in the directory", user.email ); - user.delete(&mut *transaction).await?; + user.delete_and_cleanup(&mut transaction, wg_tx) + .await + .map_err(|err| { + DirectorySyncError::UserUpdateError(format!( + "Failed to delete user during directory synchronization: {err}" + )) + })?; } } } @@ -632,6 +715,7 @@ async fn sync_all_users_state( // The default inverval for the directory sync job const DIRECTORY_SYNC_INTERVAL: u64 = 60 * 10; +/// Used to inform the utility thread how often it should perform the directory sync job. See [`run_utility_thread`] for more details. pub(crate) async fn get_directory_sync_interval(pool: &PgPool) -> u64 { if let Ok(Some(provider_settings)) = OpenIdProvider::get_current(pool).await { provider_settings @@ -643,14 +727,18 @@ pub(crate) async fn get_directory_sync_interval(pool: &PgPool) -> u64 { } } -pub(crate) async fn do_directory_sync(pool: &PgPool) -> Result<(), DirectorySyncError> { +// Performs the directory sync job. This function is called by the utility thread. +pub(crate) async fn do_directory_sync( + pool: &PgPool, + wireguard_tx: &Sender, +) -> Result<(), DirectorySyncError> { #[cfg(not(test))] if !is_enterprise_enabled() { debug!("Enterprise is not enabled, skipping performing directory sync"); return Ok(()); } - // TODO: The settings are retrieved many times + // TODO: Reduce the amount of times those settings are retrieved in the whole directory sync process let provider = OpenIdProvider::get_current(pool).await?; if !is_directory_sync_enabled(provider.as_ref()) { @@ -672,13 +760,13 @@ pub(crate) async fn do_directory_sync(pool: &PgPool) -> Result<(), DirectorySync sync_target, DirectorySyncTarget::All | DirectorySyncTarget::Users ) { - sync_all_users_state(&dir_sync, pool).await?; + sync_all_users_state(&dir_sync, pool, wireguard_tx).await?; } if matches!( sync_target, DirectorySyncTarget::All | DirectorySyncTarget::Groups ) { - sync_all_users_groups(&dir_sync, pool).await?; + sync_all_users_groups(&dir_sync, pool, wireguard_tx).await?; } } Err(err) => { @@ -689,6 +777,9 @@ pub(crate) async fn do_directory_sync(pool: &PgPool) -> Result<(), DirectorySync Ok(()) } +// Helpers shared between the directory sync providers +// + /// Parse a reqwest response and return the JSON body if the response is OK, otherwise map an error to a DirectorySyncError::RequestError /// The context_message is used to provide more context to the error message. async fn parse_response( @@ -713,30 +804,85 @@ where } } +/// Make a GET request to the given URL with the given token and query parameters +async fn make_get_request( + url: &str, + token: &str, + query: Option<&[(&str, &str)]>, +) -> Result { + let client = reqwest::Client::new(); + let query = query.unwrap_or_default(); + let response = client + .get(url) + .query(query) + .header(AUTHORIZATION, format!("Bearer {token}")) + .timeout(REQUEST_TIMEOUT) + .send() + .await?; + Ok(response) +} + #[cfg(test)] mod test { + use std::str::FromStr; + + use ipnetwork::IpNetwork; use secrecy::ExposeSecret; + use tokio::sync::broadcast; use super::*; use crate::{ - config::DefGuardConfig, enterprise::db::models::openid_provider::DirectorySyncTarget, + config::DefGuardConfig, + db::{ + models::{device::DeviceType, settings::initialize_current_settings}, + Device, Session, SessionState, Settings, WireguardNetwork, + }, + enterprise::db::models::openid_provider::DirectorySyncTarget, SERVER_CONFIG, }; + async fn get_test_network(pool: &PgPool) -> WireguardNetwork { + WireguardNetwork::find_by_name(pool, "test") + .await + .unwrap() + .unwrap() + .pop() + .unwrap() + } + async fn make_test_provider( pool: &PgPool, user_behavior: DirectorySyncUserBehavior, admin_behavior: DirectorySyncUserBehavior, target: DirectorySyncTarget, ) -> OpenIdProvider { + Settings::init_defaults(pool).await.unwrap(); + initialize_current_settings(pool).await.unwrap(); + let current = OpenIdProvider::get_current(pool).await.unwrap(); if let Some(provider) = current { provider.delete(pool).await.unwrap(); } + WireguardNetwork::new( + "test".to_string(), + vec![IpNetwork::from_str("10.10.10.1/24").unwrap()], + 1234, + "123.123.123.123".to_string(), + None, + vec![], + false, + 32, + 32, + ) + .unwrap() + .save(pool) + .await + .unwrap(); + OpenIdProvider::new( - "Google".to_string(), + "Test".to_string(), "base_url".to_string(), "client_id".to_string(), "client_secret".to_string(), @@ -749,14 +895,16 @@ mod test { user_behavior, admin_behavior, target, + None, + None, ) .save(pool) .await .unwrap() } - async fn make_test_user(name: &str, pool: &PgPool) -> User { - User::new( + async fn make_test_user_and_device(name: &str, pool: &PgPool) -> User { + let user = User::new( name, None, "lastname", @@ -766,7 +914,25 @@ mod test { ) .save(pool) .await - .unwrap() + .unwrap(); + + let dev = Device::new( + format!("{name}-device"), + format!("{name}-key"), + user.id, + DeviceType::User, + None, + true, + ) + .save(pool) + .await + .unwrap(); + + let mut transaction = pool.begin().await.unwrap(); + dev.add_to_all_networks(&mut transaction).await.unwrap(); + transaction.commit().await.unwrap(); + + user } async fn get_test_user(pool: &PgPool, name: &str) -> Option> { @@ -783,6 +949,7 @@ mod test { async fn test_users_state_keep_both(pool: PgPool) { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); + let (wg_tx, mut wg_rx) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Keep, @@ -792,20 +959,23 @@ mod test { .await; let mut client = DirectorySyncClient::build(&pool).await.unwrap(); client.prepare().await.unwrap(); - let user1 = make_test_user("user1", &pool).await; - make_test_user("user2", &pool).await; - make_test_user("testuser", &pool).await; + let user1 = make_test_user_and_device("user1", &pool).await; + make_test_user_and_device("user2", &pool).await; + make_test_user_and_device("testuser", &pool).await; make_admin(&pool, &user1).await; assert!(get_test_user(&pool, "user1").await.is_some()); assert!(get_test_user(&pool, "user2").await.is_some()); assert!(get_test_user(&pool, "testuser").await.is_some()); - sync_all_users_state(&client, &pool).await.unwrap(); + sync_all_users_state(&client, &pool, &wg_tx).await.unwrap(); assert!(get_test_user(&pool, "user1").await.is_some()); assert!(get_test_user(&pool, "user2").await.is_some()); assert!(get_test_user(&pool, "testuser").await.is_some()); + + // No events + assert!(wg_rx.try_recv().is_err()); } // Delete users, keep admins @@ -813,6 +983,7 @@ mod test { async fn test_users_state_delete_users(pool: PgPool) { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); + let (wg_tx, mut wg_rx) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -823,30 +994,37 @@ mod test { let mut client = DirectorySyncClient::build(&pool).await.unwrap(); client.prepare().await.unwrap(); - let user1 = make_test_user("user1", &pool).await; - make_test_user("user2", &pool).await; - make_test_user("testuser", &pool).await; + let user1 = make_test_user_and_device("user1", &pool).await; + let user2 = make_test_user_and_device("user2", &pool).await; + make_test_user_and_device("testuser", &pool).await; make_admin(&pool, &user1).await; assert!(get_test_user(&pool, "user1").await.is_some()); assert!(get_test_user(&pool, "user2").await.is_some()); assert!(get_test_user(&pool, "testuser").await.is_some()); - sync_all_users_state(&client, &pool).await.unwrap(); + sync_all_users_state(&client, &pool, &wg_tx).await.unwrap(); assert!(get_test_user(&pool, "user1").await.is_some()); assert!(get_test_user(&pool, "user2").await.is_none()); assert!(get_test_user(&pool, "testuser").await.is_some()); - } - // Delete admins, keep users + let event = wg_rx.try_recv(); + if let Ok(GatewayEvent::DeviceDeleted(dev)) = event { + assert_eq!(dev.device.user_id, user2.id); + } else { + panic!("Expected a DeviceDeleted event"); + } + } #[sqlx::test] async fn test_users_state_delete_admins(pool: PgPool) { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); + let (wg_tx, mut wg_rx) = broadcast::channel::(16); User::init_admin_user(&pool, config.default_admin_password.expose_secret()) .await .unwrap(); + let _ = make_test_provider( &pool, DirectorySyncUserBehavior::Keep, @@ -857,17 +1035,17 @@ mod test { let mut client = DirectorySyncClient::build(&pool).await.unwrap(); client.prepare().await.unwrap(); - let user1 = make_test_user("user1", &pool).await; - make_test_user("user2", &pool).await; - let user3 = make_test_user("user3", &pool).await; - make_test_user("testuser", &pool).await; + let user1 = make_test_user_and_device("user1", &pool).await; + make_test_user_and_device("user2", &pool).await; + let user3 = make_test_user_and_device("user3", &pool).await; + make_test_user_and_device("testuser", &pool).await; make_admin(&pool, &user1).await; make_admin(&pool, &user3).await; assert!(get_test_user(&pool, "user1").await.is_some()); assert!(get_test_user(&pool, "user2").await.is_some()); assert!(get_test_user(&pool, "testuser").await.is_some()); - sync_all_users_state(&client, &pool).await.unwrap(); + sync_all_users_state(&client, &pool, &wg_tx).await.unwrap(); assert!( get_test_user(&pool, "user1").await.is_none() @@ -879,12 +1057,21 @@ mod test { ); assert!(get_test_user(&pool, "user2").await.is_some()); assert!(get_test_user(&pool, "testuser").await.is_some()); + + // Check that we received a device deleted event for whichever admin was removed + let event = wg_rx.try_recv(); + if let Ok(GatewayEvent::DeviceDeleted(dev)) = event { + assert!(dev.device.user_id == user1.id || dev.device.user_id == user3.id); + } else { + panic!("Expected a DeviceDeleted event"); + } } #[sqlx::test] async fn test_users_state_delete_both(pool: PgPool) { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); + let (wg_tx, mut wg_rx) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -898,17 +1085,17 @@ mod test { let mut client = DirectorySyncClient::build(&pool).await.unwrap(); client.prepare().await.unwrap(); - let user1 = make_test_user("user1", &pool).await; - make_test_user("user2", &pool).await; - let user3 = make_test_user("user3", &pool).await; - make_test_user("testuser", &pool).await; + let user1 = make_test_user_and_device("user1", &pool).await; + let user2 = make_test_user_and_device("user2", &pool).await; + let user3 = make_test_user_and_device("user3", &pool).await; + make_test_user_and_device("testuser", &pool).await; make_admin(&pool, &user1).await; make_admin(&pool, &user3).await; assert!(get_test_user(&pool, "user1").await.is_some()); assert!(get_test_user(&pool, "user2").await.is_some()); assert!(get_test_user(&pool, "testuser").await.is_some()); - sync_all_users_state(&client, &pool).await.unwrap(); + sync_all_users_state(&client, &pool, &wg_tx).await.unwrap(); assert!( get_test_user(&pool, "user1").await.is_none() @@ -920,12 +1107,36 @@ mod test { ); assert!(get_test_user(&pool, "user2").await.is_none()); assert!(get_test_user(&pool, "testuser").await.is_some()); + + // Check for device deletion events + let event1 = wg_rx.try_recv(); + if let Ok(GatewayEvent::DeviceDeleted(dev)) = event1 { + assert!( + dev.device.user_id == user1.id + || dev.device.user_id == user2.id + || dev.device.user_id == user3.id + ); + } else { + panic!("Expected a DeviceDeleted event"); + } + + let event2 = wg_rx.try_recv(); + if let Ok(GatewayEvent::DeviceDeleted(dev)) = event2 { + assert!( + dev.device.user_id == user1.id + || dev.device.user_id == user2.id + || dev.device.user_id == user3.id + ); + } else { + panic!("Expected a DeviceDeleted event"); + } } #[sqlx::test] async fn test_users_state_disable_users(pool: PgPool) { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); + let (wg_tx, mut wg_rx) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Disable, @@ -936,39 +1147,69 @@ mod test { let mut client = DirectorySyncClient::build(&pool).await.unwrap(); client.prepare().await.unwrap(); - let user1 = make_test_user("user1", &pool).await; - make_test_user("user2", &pool).await; - make_test_user("testuser", &pool).await; - make_test_user("testuserdisabled", &pool).await; + let user1 = make_test_user_and_device("user1", &pool).await; + make_test_user_and_device("user2", &pool).await; + make_test_user_and_device("testuser", &pool).await; + make_test_user_and_device("testuserdisabled", &pool).await; make_admin(&pool, &user1).await; let user1 = get_test_user(&pool, "user1").await.unwrap(); let user2 = get_test_user(&pool, "user2").await.unwrap(); let testuser = get_test_user(&pool, "testuser").await.unwrap(); let testuserdisabled = get_test_user(&pool, "testuserdisabled").await.unwrap(); + let disabled_user_session = Session::new( + testuserdisabled.id, + SessionState::PasswordVerified, + "127.0.0.1".into(), + None, + ); + disabled_user_session.save(&pool).await.unwrap(); + assert!(Session::find_by_id(&pool, &disabled_user_session.id) + .await + .unwrap() + .is_some()); assert!(user1.is_active); assert!(user2.is_active); assert!(testuser.is_active); assert!(testuserdisabled.is_active); - sync_all_users_state(&client, &pool).await.unwrap(); + sync_all_users_state(&client, &pool, &wg_tx).await.unwrap(); + + // Check for device disconnection events + let event1 = wg_rx.try_recv(); + if let Ok(GatewayEvent::DeviceDeleted(dev)) = event1 { + assert!(dev.device.user_id == user2.id || dev.device.user_id == testuserdisabled.id); + } else { + panic!("Expected a DeviceDisconnected event"); + } + + let event2 = wg_rx.try_recv(); + if let Ok(GatewayEvent::DeviceDeleted(dev)) = event2 { + assert!(dev.device.user_id == user2.id || dev.device.user_id == testuserdisabled.id); + } else { + panic!("Expected a DeviceDisconnected event"); + } let user1 = get_test_user(&pool, "user1").await.unwrap(); let user2 = get_test_user(&pool, "user2").await.unwrap(); let testuser = get_test_user(&pool, "testuser").await.unwrap(); let testuserdisabled = get_test_user(&pool, "testuserdisabled").await.unwrap(); + assert!(Session::find_by_id(&pool, &disabled_user_session.id) + .await + .unwrap() + .is_none()); assert!(user1.is_active); assert!(!user2.is_active); assert!(testuser.is_active); assert!(!testuserdisabled.is_active); } - #[sqlx::test] async fn test_users_state_disable_admins(pool: PgPool) { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); + let (wg_tx, mut wg_rx) = broadcast::channel::(16); // Added mut wg_rx make_test_provider( &pool, DirectorySyncUserBehavior::Keep, @@ -979,11 +1220,11 @@ mod test { let mut client = DirectorySyncClient::build(&pool).await.unwrap(); client.prepare().await.unwrap(); - let user1 = make_test_user("user1", &pool).await; - make_test_user("user2", &pool).await; - let user3 = make_test_user("user3", &pool).await; - make_test_user("testuser", &pool).await; - make_test_user("testuserdisabled", &pool).await; + let user1 = make_test_user_and_device("user1", &pool).await; + make_test_user_and_device("user2", &pool).await; + let user3 = make_test_user_and_device("user3", &pool).await; + make_test_user_and_device("testuser", &pool).await; + make_test_user_and_device("testuserdisabled", &pool).await; make_admin(&pool, &user1).await; make_admin(&pool, &user3).await; @@ -998,7 +1239,30 @@ mod test { assert!(testuser.is_active); assert!(testuserdisabled.is_active); - sync_all_users_state(&client, &pool).await.unwrap(); + sync_all_users_state(&client, &pool, &wg_tx).await.unwrap(); + + // Check for device disconnection events + let event1 = wg_rx.try_recv(); + if let Ok(GatewayEvent::DeviceDeleted(dev)) = event1 { + assert!( + dev.device.user_id == user1.id + || dev.device.user_id == user3.id + || dev.device.user_id == testuserdisabled.id + ); + } else { + panic!("Expected a DeviceDisconnected event"); + } + + let event2 = wg_rx.try_recv(); + if let Ok(GatewayEvent::DeviceDeleted(dev)) = event2 { + assert!( + dev.device.user_id == user1.id + || dev.device.user_id == user3.id + || dev.device.user_id == testuserdisabled.id + ); + } else { + panic!("Expected a DeviceDisconnected event"); + } let user1 = get_test_user(&pool, "user1").await.unwrap(); let user2 = get_test_user(&pool, "user2").await.unwrap(); @@ -1017,6 +1281,7 @@ mod test { async fn test_users_groups(pool: PgPool) { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); + let (wg_tx, _) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -1027,10 +1292,10 @@ mod test { let mut client = DirectorySyncClient::build(&pool).await.unwrap(); client.prepare().await.unwrap(); - make_test_user("testuser", &pool).await; - make_test_user("testuser2", &pool).await; - make_test_user("testuserdisabled", &pool).await; - sync_all_users_groups(&client, &pool).await.unwrap(); + make_test_user_and_device("testuser", &pool).await; + make_test_user_and_device("testuser2", &pool).await; + make_test_user_and_device("testuserdisabled", &pool).await; + sync_all_users_groups(&client, &pool, &wg_tx).await.unwrap(); let mut groups = Group::all(&pool).await.unwrap(); @@ -1067,6 +1332,7 @@ mod test { async fn test_sync_user_groups(pool: PgPool) { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); + let (wg_tx, _) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -1076,10 +1342,12 @@ mod test { .await; let mut client = DirectorySyncClient::build(&pool).await.unwrap(); client.prepare().await.unwrap(); - let user = make_test_user("testuser", &pool).await; + let user = make_test_user_and_device("testuser", &pool).await; let user_groups = user.member_of(&pool).await.unwrap(); assert_eq!(user_groups.len(), 0); - sync_user_groups_if_configured(&user, &pool).await.unwrap(); + sync_user_groups_if_configured(&user, &pool, &wg_tx) + .await + .unwrap(); let user_groups = user.member_of(&pool).await.unwrap(); assert_eq!(user_groups.len(), 1); let group = Group::find_by_name(&pool, "group1").await.unwrap().unwrap(); @@ -1090,6 +1358,7 @@ mod test { async fn test_sync_target_users(pool: PgPool) { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); + let (wg_tx, _) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -1099,10 +1368,10 @@ mod test { .await; let mut client = DirectorySyncClient::build(&pool).await.unwrap(); client.prepare().await.unwrap(); - let user = make_test_user("testuser", &pool).await; + let user = make_test_user_and_device("testuser", &pool).await; let user_groups = user.member_of(&pool).await.unwrap(); assert_eq!(user_groups.len(), 0); - do_directory_sync(&pool).await.unwrap(); + do_directory_sync(&pool, &wg_tx).await.unwrap(); let user_groups = user.member_of(&pool).await.unwrap(); assert_eq!(user_groups.len(), 0); } @@ -1111,6 +1380,7 @@ mod test { async fn test_sync_target_all(pool: PgPool) { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); + let (wg_tx, mut wg_rx) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -1118,23 +1388,52 @@ mod test { DirectorySyncTarget::All, ) .await; + let network = get_test_network(&pool).await; + let mut transaction = pool.begin().await.unwrap(); + let group = Group::new("group1".to_string()) + .save(&mut *transaction) + .await + .unwrap(); + network + .set_allowed_groups(&mut transaction, vec![group.name]) + .await + .unwrap(); + transaction.commit().await.unwrap(); let mut client = DirectorySyncClient::build(&pool).await.unwrap(); client.prepare().await.unwrap(); - let user = make_test_user("testuser", &pool).await; - make_test_user("user2", &pool).await; + let user = make_test_user_and_device("testuser", &pool).await; + let user2_pre_sync = make_test_user_and_device("user2", &pool).await; let user_groups = user.member_of(&pool).await.unwrap(); assert_eq!(user_groups.len(), 0); - do_directory_sync(&pool).await.unwrap(); + do_directory_sync(&pool, &wg_tx).await.unwrap(); let user_groups = user.member_of(&pool).await.unwrap(); assert_eq!(user_groups.len(), 3); let user2 = get_test_user(&pool, "user2").await; assert!(user2.is_none()); + let mut transaction = pool.begin().await.unwrap(); + user.sync_allowed_devices(&mut transaction, &wg_tx) + .await + .unwrap(); + transaction.commit().await.unwrap(); + let event = wg_rx.try_recv(); + if let Ok(GatewayEvent::DeviceDeleted(dev)) = event { + assert_eq!(dev.device.user_id, user2_pre_sync.id); + } else { + panic!("Expected a DeviceDeleted event"); + } + let event = wg_rx.try_recv(); + if let Ok(GatewayEvent::DeviceCreated(dev)) = event { + assert_eq!(dev.device.user_id, user.id); + } else { + panic!("Expected a DeviceDeleted event"); + } } #[sqlx::test] async fn test_sync_target_groups(pool: PgPool) { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); + let (wg_tx, _) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -1144,11 +1443,11 @@ mod test { .await; let mut client = DirectorySyncClient::build(&pool).await.unwrap(); client.prepare().await.unwrap(); - let user = make_test_user("testuser", &pool).await; - make_test_user("user2", &pool).await; + let user = make_test_user_and_device("testuser", &pool).await; + make_test_user_and_device("user2", &pool).await; let user_groups = user.member_of(&pool).await.unwrap(); assert_eq!(user_groups.len(), 0); - do_directory_sync(&pool).await.unwrap(); + do_directory_sync(&pool, &wg_tx).await.unwrap(); let user_groups = user.member_of(&pool).await.unwrap(); assert_eq!(user_groups.len(), 3); let user2 = get_test_user(&pool, "user2").await; @@ -1159,6 +1458,7 @@ mod test { async fn test_sync_unassign_last_admin_group(pool: PgPool) { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); + let (wg_tx, _) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -1170,38 +1470,39 @@ mod test { client.prepare().await.unwrap(); // Make one admin and check if he's deleted - let user = make_test_user("testuser", &pool).await; + let user = make_test_user_and_device("testuser", &pool).await; let admin_grp = Group::find_by_name(&pool, "admin").await.unwrap().unwrap(); user.add_to_group(&pool, &admin_grp).await.unwrap(); let user_groups = user.member_of(&pool).await.unwrap(); assert_eq!(user_groups.len(), 1); assert!(user.is_admin(&pool).await.unwrap()); - do_directory_sync(&pool).await.unwrap(); + do_directory_sync(&pool, &wg_tx).await.unwrap(); // He should still be an admin as it's the last one assert!(user.is_admin(&pool).await.unwrap()); // Make another admin and check if one of them is deleted - let user2 = make_test_user("testuser2", &pool).await; + let user2 = make_test_user_and_device("testuser2", &pool).await; user2.add_to_group(&pool, &admin_grp).await.unwrap(); - do_directory_sync(&pool).await.unwrap(); + do_directory_sync(&pool, &wg_tx).await.unwrap(); let admins = User::find_admins(&pool).await.unwrap(); // There should be only one admin left assert_eq!(admins.len(), 1); - let defguard_user = make_test_user("defguard", &pool).await; + let defguard_user = make_test_user_and_device("defguard", &pool).await; make_admin(&pool, &defguard_user).await; - do_directory_sync(&pool).await.unwrap(); + do_directory_sync(&pool, &wg_tx).await.unwrap(); } #[sqlx::test] async fn test_sync_delete_last_admin_user(pool: PgPool) { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); + let (wg_tx, _) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -1213,11 +1514,11 @@ mod test { client.prepare().await.unwrap(); // a user that's not in the directory - let defguard_user = make_test_user("defguard", &pool).await; + let defguard_user = make_test_user_and_device("defguard", &pool).await; make_admin(&pool, &defguard_user).await; assert!(defguard_user.is_admin(&pool).await.unwrap()); - do_directory_sync(&pool).await.unwrap(); + do_directory_sync(&pool, &wg_tx).await.unwrap(); // The user should still be an admin assert!(defguard_user.is_admin(&pool).await.unwrap()); @@ -1229,7 +1530,7 @@ mod test { .await .unwrap(); - do_directory_sync(&pool).await.unwrap(); + do_directory_sync(&pool, &wg_tx).await.unwrap(); let user = User::find_by_username(&pool, "defguard").await.unwrap(); assert!(user.is_none()); } diff --git a/src/enterprise/directory_sync/okta.rs b/src/enterprise/directory_sync/okta.rs new file mode 100644 index 000000000..91bc43c8d --- /dev/null +++ b/src/enterprise/directory_sync/okta.rs @@ -0,0 +1,551 @@ +use std::str::FromStr; + +use chrono::{DateTime, TimeDelta, Utc}; +use jsonwebtoken::{encode, Algorithm, EncodingKey, Header}; +use parse_link_header::parse_with_rel; +use tokio::time::sleep; + +use super::{ + parse_response, DirectoryGroup, DirectorySync, DirectorySyncError, DirectoryUser, + REQUEST_PAGINATION_SLOWDOWN, +}; +use crate::enterprise::directory_sync::make_get_request; + +// Okta suggests using the maximum limit of 200 for the number of results per page. +// If this is an issue, we would need to add resource pagination. +const ACCESS_TOKEN_URL: &str = "{BASE_URL}/oauth2/v1/token"; +const GROUPS_URL: &str = "{BASE_URL}/api/v1/groups"; +const GRANT_TYPE: &str = "client_credentials"; +const CLIENT_ASSERTION_TYPE: &str = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"; +const TOKEN_SCOPE: &str = "okta.users.read okta.groups.read"; +const ALL_USERS_URL: &str = "{BASE_URL}/api/v1/users"; +const GROUP_MEMBERS: &str = "{BASE_URL}/api/v1/groups/{GROUP_ID}/users"; +const USER_GROUPS: &str = "{BASE_URL}/api/v1/users/{USER_ID}/groups"; +const MAX_RESULTS: &str = "200"; +const MAX_REQUESTS: usize = 50; + +pub fn extract_next_link( + link_header: Option<&String>, +) -> Result, DirectorySyncError> { + if let Some(header) = link_header { + let mut res = parse_with_rel(header).map_err(|e| { + DirectorySyncError::InvalidUrl(format!("Failed to parse link header: {e:?}")) + })?; + Ok(res.remove("next").map(|x| x.raw_uri)) + } else { + Ok(None) + } +} + +#[derive(Debug, Serialize, Deserialize)] +struct Claims { + iss: String, + aud: String, + sub: String, + exp: i64, + iat: i64, +} + +impl Claims { + #[must_use] + fn new(client_id: &str, base_url: &str) -> Self { + let now = Utc::now(); + let now_timestamp = now.timestamp(); + let exp = now_timestamp + 3600; + Self { + iss: client_id.into(), + aud: ACCESS_TOKEN_URL.replace("{BASE_URL}", base_url), + sub: client_id.into(), + exp, + iat: now_timestamp, + } + } +} + +#[allow(dead_code)] +pub struct OktaDirectorySync { + access_token: Option, + token_expiry: Option>, + jwk_private_key: String, + client_id: String, + base_url: String, +} + +/// Okta Directory API responses + +#[derive(Debug, Deserialize)] +pub struct AccessTokenResponse { + #[serde(rename = "access_token")] + token: String, + expires_in: i64, +} + +#[derive(Debug, Deserialize)] +struct UserProfile { + email: String, +} + +#[derive(Debug, Deserialize)] +struct User { + status: String, + profile: UserProfile, +} + +impl From for DirectoryUser { + fn from(val: User) -> Self { + Self { + email: val.profile.email, + active: ACTIVE_STATUS.contains(&val.status.as_str()), + } + } +} + +#[derive(Debug, Deserialize)] +struct GroupProfile { + name: String, +} + +#[derive(Debug, Deserialize)] +struct Group { + id: String, + profile: GroupProfile, +} + +impl From for DirectoryGroup { + fn from(val: Group) -> Self { + Self { + id: val.id, + name: val.profile.name, + } + } +} + +// The status may be: +// "ACTIVE" "DEPROVISIONED" "LOCKED_OUT" "PASSWORD_EXPIRED" "PROVISIONED" "RECOVERY" "STAGED" "SUSPENDED" +// We currently consider only ACTIVE users as active. Change this if needed. +const ACTIVE_STATUS: [&str; 1] = ["ACTIVE"]; + +impl OktaDirectorySync { + #[must_use] + pub fn new(private_key: &str, client_id: &str, base_url: &str) -> Self { + // Remove the trailing slash just to make sure + let trimmed = base_url.trim_end_matches('/'); + Self { + client_id: client_id.to_string(), + jwk_private_key: private_key.to_string(), + base_url: trimmed.to_string(), + access_token: None, + token_expiry: None, + } + } + + pub async fn refresh_access_token(&mut self) -> Result<(), DirectorySyncError> { + debug!("Refreshing Okta directory sync access token"); + let token_response = self.query_access_token().await?; + let expires_in = TimeDelta::seconds(token_response.expires_in); + debug!( + "Access token refreshed, the new token expires in {} seconds", + token_response.expires_in + ); + self.access_token = Some(token_response.token); + self.token_expiry = Some(Utc::now() + expires_in); + Ok(()) + } + + pub fn is_token_expired(&self) -> bool { + debug!("Checking if Okta directory sync token is expired"); + // No token = expired token + let result = self.token_expiry.map_or(true, |expiry| expiry < Utc::now()); + debug!("Token is expired: {}", result); + result + } + + async fn query_test_connection(&self) -> Result<(), DirectorySyncError> { + let access_token = self + .access_token + .as_ref() + .ok_or(DirectorySyncError::AccessTokenExpired)?; + let response = make_get_request( + &ALL_USERS_URL.replace("{BASE_URL}", &self.base_url), + access_token, + Some(&[("limit", "1")]), + ) + .await?; + let _result: Vec = + parse_response(response, "Failed to test connection to Okta API.").await?; + Ok(()) + } + + async fn query_user_groups(&self, user_id: &str) -> Result, DirectorySyncError> { + if self.is_token_expired() { + return Err(DirectorySyncError::AccessTokenExpired); + } + let access_token = self + .access_token + .as_ref() + .ok_or(DirectorySyncError::AccessTokenExpired)?; + let mut url = USER_GROUPS + .replace("{BASE_URL}", &self.base_url) + .replace("{USER_ID}", user_id); + let mut combined_response: Vec = Vec::new(); + let mut query = Some([("limit", MAX_RESULTS)].as_slice()); + + for _ in 0..MAX_REQUESTS { + let response = make_get_request(&url, access_token, query).await?; + let link_header = { + let links = response + .headers() + .get_all("link") + .iter() + .filter_map(|link| link.to_str().ok()) + .collect::>(); + + (!links.is_empty()).then(|| links.join(", ")) + }; + let result: Vec = + parse_response(response, "Failed to query user groups in the Okta API.").await?; + combined_response.extend(result); + + if let Some(next_link) = extract_next_link(link_header.as_ref())? { + url = next_link; + // Query is already appended to the URL we received from the link header + query = None; + debug!("Found next page of results, querying it: {url}"); + } else { + debug!("No more pages of results found, finishing query."); + break; + } + + sleep(REQUEST_PAGINATION_SLOWDOWN).await; + } + + Ok(combined_response) + } + + async fn query_groups(&self) -> Result, DirectorySyncError> { + if self.is_token_expired() { + return Err(DirectorySyncError::AccessTokenExpired); + } + let access_token = self + .access_token + .as_ref() + .ok_or(DirectorySyncError::AccessTokenExpired)?; + let mut url = GROUPS_URL.replace("{BASE_URL}", &self.base_url); + let mut combined_response: Vec = Vec::new(); + let mut query = Some([("limit", MAX_RESULTS)].as_slice()); + + for _ in 0..MAX_REQUESTS { + let response = make_get_request(&url, access_token, query).await?; + let link_header = { + let links = response + .headers() + .get_all("link") + .iter() + .filter_map(|link| link.to_str().ok()) + .collect::>(); + + (!links.is_empty()).then(|| links.join(", ")) + }; + let result: Vec = + parse_response(response, "Failed to query groups in the Okta API.").await?; + combined_response.extend(result); + + if let Some(next_link) = extract_next_link(link_header.as_ref())? { + url = next_link; + // Query is already appended to the URL we received from the link header + query = None; + debug!("Found next page of results, querying it: {url}"); + } else { + debug!("No more pages of results found, finishing query."); + break; + } + + sleep(REQUEST_PAGINATION_SLOWDOWN).await; + } + + Ok(combined_response) + } + + async fn query_group_members( + &self, + group: &DirectoryGroup, + ) -> Result, DirectorySyncError> { + if self.is_token_expired() { + return Err(DirectorySyncError::AccessTokenExpired); + } + let access_token = self + .access_token + .as_ref() + .ok_or(DirectorySyncError::AccessTokenExpired)?; + let mut url = GROUP_MEMBERS + .replace("{BASE_URL}", &self.base_url) + .replace("{GROUP_ID}", &group.id); + let mut combined_response: Vec = Vec::new(); + let mut query = Some([("limit", MAX_RESULTS)].as_slice()); + + for _ in 0..MAX_REQUESTS { + let response = make_get_request(&url, access_token, query).await?; + let link_header = { + let links = response + .headers() + .get_all("link") + .iter() + .filter_map(|link| link.to_str().ok()) + .collect::>(); + + (!links.is_empty()).then(|| links.join(", ")) + }; + let result: Vec = + parse_response(response, "Failed to query group members in the Okta API.").await?; + combined_response.extend(result); + + if let Some(next_link) = extract_next_link(link_header.as_ref())? { + url = next_link; + // Query is already appended to the URL we received from the link header + query = None; + debug!("Found next page of results, querying it: {url}"); + } else { + debug!("No more pages of results found, finishing query."); + break; + } + + sleep(REQUEST_PAGINATION_SLOWDOWN).await; + } + + Ok(combined_response) + } + + fn build_token(&self) -> Result { + debug!("Building Okta directory sync auth token"); + let claims = Claims::new(&self.client_id, &self.base_url); + debug!("Using the following token claims: {:?}", claims); + // Users provide a JWK format private key. The jsonwebtoken library currently doesn't support + // converting JWK to PEM or encoding key so the jsonwebkey library is used to convert the key + // to a PEM format. + debug!("Building Okta directory sync encoding key"); + let jwk = jsonwebkey::JsonWebKey::from_str(&self.jwk_private_key) + .map_err(|e| DirectorySyncError::InvalidProviderConfiguration(e.to_string()))?; + let kid = jwk + .key_id + .ok_or(DirectorySyncError::InvalidProviderConfiguration( + "Missing key id in the provided JSON key".to_string(), + ))?; + let encoding_key_pem = jwk + .key + .try_to_pem() + .map_err(|e| DirectorySyncError::InvalidProviderConfiguration(e.to_string()))?; + let key = EncodingKey::from_rsa_pem(encoding_key_pem.as_bytes())?; + debug!("Successfully built Okta directory sync encoding key for encoding the auth token"); + let mut header = Header::new(Algorithm::RS256); + header.kid = Some(kid); + let token = encode(&header, &claims, &key)?; + debug!("Successfully built Okta directory sync auth token"); + Ok(token) + } + + async fn query_access_token(&self) -> Result { + let token = self.build_token()?; + let client = reqwest::Client::new(); + let response = client + .post(ACCESS_TOKEN_URL.replace("{BASE_URL}", &self.base_url)) + .form(&[ + ("grant_type", GRANT_TYPE), + ("client_assertion_type", CLIENT_ASSERTION_TYPE), + ("client_assertion", &token), + ("scope", TOKEN_SCOPE), + ]) + .send() + .await?; + parse_response(response, "Failed to get access token from Okta API.").await + } + + async fn query_all_users(&self) -> Result, DirectorySyncError> { + if self.is_token_expired() { + return Err(DirectorySyncError::AccessTokenExpired); + } + let access_token = self + .access_token + .as_ref() + .ok_or(DirectorySyncError::AccessTokenExpired)?; + let mut url = ALL_USERS_URL.replace("{BASE_URL}", &self.base_url); + let mut query = Some([("limit", MAX_RESULTS)].as_slice()); + let mut combined_response: Vec = Vec::new(); + + for _ in 0..MAX_REQUESTS { + let response = make_get_request(&url, access_token, query).await?; + let link_header = { + let links = response + .headers() + .get_all("link") + .iter() + .filter_map(|link| link.to_str().ok()) + .collect::>(); + + (!links.is_empty()).then(|| links.join(", ")) + }; + let result: Vec = + parse_response(response, "Failed to query all users in the Okta API.").await?; + combined_response.extend(result); + if let Some(next_link) = extract_next_link(link_header.as_ref())? { + url = next_link; + // Query is already appended to the URL we received from the link header + query = None; + debug!("Found next page of results, querying it: {url}"); + } else { + debug!("No more pages of results found, finishing query."); + break; + } + + sleep(REQUEST_PAGINATION_SLOWDOWN).await; + } + + Ok(combined_response) + } +} + +impl DirectorySync for OktaDirectorySync { + async fn get_groups(&self) -> Result, DirectorySyncError> { + debug!("Getting all groups"); + let response = self.query_groups().await?; + debug!("Got all groups response"); + Ok(response.into_iter().map(Into::into).collect()) + } + + async fn get_user_groups( + &self, + user_id: &str, + ) -> Result, DirectorySyncError> { + debug!("Getting groups of user {user_id}"); + let response = self.query_user_groups(user_id).await?; + debug!("Got groups response for user {user_id}"); + Ok(response.into_iter().map(Into::into).collect()) + } + + async fn get_group_members( + &self, + group: &DirectoryGroup, + ) -> Result, DirectorySyncError> { + debug!("Getting group members of group {}", group.name); + let response = self.query_group_members(group).await?; + debug!("Got group members response for group {}", group.name); + Ok(response + .into_iter() + .map(|user| user.profile.email) + .collect()) + } + + async fn prepare(&mut self) -> Result<(), DirectorySyncError> { + debug!("Preparing Okta directory sync..."); + if self.is_token_expired() { + debug!("Access token is expired, refreshing."); + self.refresh_access_token().await?; + debug!("Access token refreshed."); + } else { + debug!("Access token is still valid, skipping refresh."); + } + debug!("Okta directory sync prepared."); + Ok(()) + } + + async fn get_all_users(&self) -> Result, DirectorySyncError> { + debug!("Getting all users"); + let response: Vec = self.query_all_users().await?; + debug!("Got all users response"); + Ok(response.into_iter().map(Into::into).collect()) + } + + async fn test_connection(&self) -> Result<(), DirectorySyncError> { + debug!("Testing connection to Okta API."); + self.query_test_connection().await?; + info!("Successfully tested connection to Okta API, connection is working."); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_token() { + let mut dirsync = + OktaDirectorySync::new("private_key", "client_id", "https://trial-0000000.okta.com"); + + // no token + assert!(dirsync.is_token_expired()); + + // expired token + dirsync.access_token = Some("test_token".into()); + dirsync.token_expiry = Some(Utc::now() - TimeDelta::seconds(10000)); + assert!(dirsync.is_token_expired()); + + // valid token + dirsync.access_token = Some("test_token".into()); + dirsync.token_expiry = Some(Utc::now() + TimeDelta::seconds(10000)); + assert!(!dirsync.is_token_expired()); + } + + #[tokio::test] + async fn test_header() { + let link_header = + "; rel=\"next\"" + .to_string(); + let next_link = extract_next_link(Some(&link_header)).unwrap(); + assert_eq!( + next_link, + Some("https://trial-0000000.okta.com/api/v1/users?after=4&limit=200".to_string()) + ); + + let next_link = extract_next_link(None).unwrap(); + assert_eq!(next_link, None); + + let link_header = "invalid".to_string(); + let next_link = extract_next_link(Some(&link_header)); + assert!(next_link.is_err()); + + let link_header = "; rel=\"next\", ; rel=\"prev\"".to_string(); + let next_link = extract_next_link(Some(&link_header)).unwrap(); + assert_eq!( + next_link, + Some("https://trial-0000000.okta.com/api/v1/users?after=4&limit=200".to_string()) + ); + } + + #[tokio::test] + async fn test_group_parse() { + let group = Group { + id: "test_id".to_string(), + profile: GroupProfile { + name: "test_name".to_string(), + }, + }; + let dir_group: DirectoryGroup = group.into(); + assert_eq!(dir_group.id, "test_id"); + assert_eq!(dir_group.name, "test_name"); + } + + #[tokio::test] + async fn test_user_parse() { + let user = User { + status: "ACTIVE".to_string(), + profile: UserProfile { + email: "test_email".to_string(), + }, + }; + + let dir_user: DirectoryUser = user.into(); + assert_eq!(dir_user.email, "test_email"); + assert!(dir_user.active); + + let user = User { + status: "INACTIVE".to_string(), + profile: UserProfile { + email: "test_email".to_string(), + }, + }; + + let dir_user: DirectoryUser = user.into(); + assert_eq!(dir_user.email, "test_email"); + assert!(!dir_user.active); + } +} diff --git a/src/enterprise/directory_sync/testprovider.rs b/src/enterprise/directory_sync/testprovider.rs new file mode 100644 index 000000000..108388958 --- /dev/null +++ b/src/enterprise/directory_sync/testprovider.rs @@ -0,0 +1,69 @@ +use super::{DirectoryGroup, DirectorySync, DirectorySyncError, DirectoryUser}; + +#[allow(dead_code)] +pub(crate) struct TestProviderDirectorySync; + +impl DirectorySync for TestProviderDirectorySync { + async fn get_groups(&self) -> Result, DirectorySyncError> { + Ok(vec![ + DirectoryGroup { + id: "1".into(), + name: "group1".into(), + }, + DirectoryGroup { + id: "2".into(), + name: "group2".into(), + }, + DirectoryGroup { + id: "3".into(), + name: "group3".into(), + }, + ]) + } + + async fn get_user_groups( + &self, + _user_id: &str, + ) -> Result, DirectorySyncError> { + Ok(vec![DirectoryGroup { + id: "1".into(), + name: "group1".into(), + }]) + } + + async fn get_group_members( + &self, + _group: &DirectoryGroup, + ) -> Result, DirectorySyncError> { + Ok(vec![ + "testuser@email.com".into(), + "testuserdisabled@email.com".into(), + "testuser2@email.com".into(), + ]) + } + + async fn prepare(&mut self) -> Result<(), DirectorySyncError> { + Ok(()) + } + + async fn get_all_users(&self) -> Result, DirectorySyncError> { + Ok(vec![ + DirectoryUser { + email: "testuser@email.com".into(), + active: true, + }, + DirectoryUser { + email: "testuserdisabled@email.com".into(), + active: false, + }, + DirectoryUser { + email: "testuser2@email.com".into(), + active: true, + }, + ]) + } + + async fn test_connection(&self) -> Result<(), DirectorySyncError> { + Ok(()) + } +} diff --git a/src/enterprise/handlers/openid_login.rs b/src/enterprise/handlers/openid_login.rs index 0a5e90cd5..ad0a7ef4f 100644 --- a/src/enterprise/handlers/openid_login.rs +++ b/src/enterprise/handlers/openid_login.rs @@ -486,7 +486,9 @@ pub(crate) async fn auth_callback( } if let Some(user_info) = user_info { - if let Err(err) = sync_user_groups_if_configured(&user, &appstate.pool).await { + if let Err(err) = + sync_user_groups_if_configured(&user, &appstate.pool, &appstate.wireguard_tx).await + { error!( "Failed to sync user groups for user {} with the directory while the user was logging in through an external provider: {err:?}", user.username diff --git a/src/enterprise/handlers/openid_providers.rs b/src/enterprise/handlers/openid_providers.rs index b0bee333f..e0071705e 100644 --- a/src/enterprise/handlers/openid_providers.rs +++ b/src/enterprise/handlers/openid_providers.rs @@ -33,6 +33,8 @@ pub struct AddProviderData { pub directory_sync_admin_behavior: String, pub directory_sync_target: String, pub create_account: bool, + pub okta_private_jwk: Option, + pub okta_dirsync_client_id: Option, } #[derive(Debug, Deserialize, Serialize)] @@ -77,6 +79,31 @@ pub async fn add_openid_provider( None => None, }; + let okta_private_jwk = match &provider_data.okta_private_jwk { + Some(key) => { + if serde_json::from_str::(key).is_ok() { + debug!( + "User {} provided a valid JWK private key for provider's Okta directory sync, using it", + session.user.username + ); + provider_data.okta_private_jwk.clone() + } else if let Some(provider) = ¤t_provider { + debug!( + "User {} did not provide a valid JWK private key for provider's Okta directory sync or the key did not change, using the existing key", + session.user.username + ); + provider.okta_private_jwk.clone() + } else { + warn!( + "User {} did not provide a valid JWK private key for provider's Okta directory sync", + session.user.username + ); + None + } + } + None => None, + }; + let mut settings = Settings::get_current_settings(); settings.openid_create_account = provider_data.create_account; update_current_settings(&appstate.pool, settings).await?; @@ -96,6 +123,8 @@ pub async fn add_openid_provider( provider_data.directory_sync_user_behavior.into(), provider_data.directory_sync_admin_behavior.into(), provider_data.directory_sync_target.into(), + okta_private_jwk, + provider_data.okta_dirsync_client_id, ) .upsert(&appstate.pool) .await?; @@ -125,6 +154,7 @@ pub async fn get_current_openid_provider( Some(mut provider) => { // Get rid of it, it should stay on the backend only. provider.google_service_account_key = None; + provider.okta_private_jwk = None; Ok(ApiResponse { json: json!({ "provider": json!(provider), diff --git a/src/grpc/desktop_client_mfa.rs b/src/grpc/desktop_client_mfa.rs index 434f01bc0..fcbdd44c1 100644 --- a/src/grpc/desktop_client_mfa.rs +++ b/src/grpc/desktop_client_mfa.rs @@ -49,6 +49,7 @@ impl ClientMfaServer { sessions: HashMap::new(), } } + fn generate_token(pubkey: &str) -> Result { Claims::new( ClaimsType::DesktopClient, diff --git a/src/grpc/gateway.rs b/src/grpc/gateway.rs index 691634870..d2da48c8d 100644 --- a/src/grpc/gateway.rs +++ b/src/grpc/gateway.rs @@ -27,6 +27,26 @@ use crate::{ tonic::include_proto!("gateway"); +/// Sends given `GatewayEvent` to be handled by gateway GRPC server +/// +/// If you want to use it inside the API context, use [`crate::AppState::send_wireguard_event`] instead +pub fn send_wireguard_event(event: GatewayEvent, wg_tx: &Sender) { + debug!("Sending the following WireGuard event to the gateway: {event:?}"); + if let Err(err) = wg_tx.send(event) { + error!("Error sending WireGuard event {err}"); + } +} + +/// Sends multiple events to be handled by gateway GRPC server +/// +/// If you want to use it inside the API context, use [`crate::AppState::send_multiple_wireguard_events`] instead +pub fn send_multiple_wireguard_events(events: Vec, wg_tx: &Sender) { + debug!("Sending {} wireguard events", events.len()); + for event in events { + send_wireguard_event(event, wg_tx); + } +} + pub struct GatewayServer { pool: PgPool, state: Arc>, diff --git a/src/grpc/mod.rs b/src/grpc/mod.rs index 2ecfc79ce..517914733 100644 --- a/src/grpc/mod.rs +++ b/src/grpc/mod.rs @@ -442,7 +442,7 @@ pub async fn run_grpc_bidi_stream( let enrollment_server = EnrollmentServer::new(pool.clone(), wireguard_tx.clone(), mail_tx.clone()); let password_reset_server = PasswordResetServer::new(pool.clone(), mail_tx.clone()); - let mut client_mfa_server = ClientMfaServer::new(pool.clone(), mail_tx, wireguard_tx); + let mut client_mfa_server = ClientMfaServer::new(pool.clone(), mail_tx, wireguard_tx.clone()); let polling_server = PollingServer::new(pool.clone()); let endpoint = Endpoint::from_shared(config.proxy_url.as_deref().unwrap())?; @@ -677,8 +677,12 @@ pub async fn run_grpc_bidi_stream( { Ok(user) => { user.clear_unused_enrollment_tokens(&pool).await?; - if let Err(err) = - sync_user_groups_if_configured(&user, &pool).await + if let Err(err) = sync_user_groups_if_configured( + &user, + &pool, + &wireguard_tx, + ) + .await { error!( "Failed to sync user groups for user {} with the directory while the user was logging in through an external provider: {err:?}", diff --git a/src/handlers/openid_flow.rs b/src/handlers/openid_flow.rs index 19cad215a..060cdf675 100644 --- a/src/handlers/openid_flow.rs +++ b/src/handlers/openid_flow.rs @@ -144,6 +144,7 @@ struct FieldResponseTypes(Vec); impl Deref for FieldResponseTypes { type Target = Vec; + fn deref(&self) -> &Self::Target { &self.0 } diff --git a/src/handlers/user.rs b/src/handlers/user.rs index 4a2221491..d44c1549a 100644 --- a/src/handlers/user.rs +++ b/src/handlers/user.rs @@ -12,16 +12,12 @@ use crate::{ appstate::AppState, auth::{AdminRole, SessionInfo}, db::{ - models::{ - device::DeviceInfo, - enrollment::{Token, PASSWORD_RESET_TOKEN_TYPE}, - }, - AppEvent, GatewayEvent, OAuth2AuthorizedApp, User, UserDetails, UserInfo, WebAuthn, - WireguardNetwork, + models::enrollment::{Token, PASSWORD_RESET_TOKEN_TYPE}, + AppEvent, OAuth2AuthorizedApp, User, UserDetails, UserInfo, WebAuthn, }, enterprise::{db::models::enterprise_settings::EnterpriseSettings, limits::update_counts}, error::WebError, - ldap::utils::{ldap_add_user, ldap_change_password, ldap_delete_user, ldap_modify_user}, + ldap::utils::{ldap_add_user, ldap_change_password, ldap_modify_user}, mail::Mail, server_config, templates, }; @@ -660,12 +656,8 @@ pub async fn modify_user( "User {} changed {username} groups or status, syncing allowed network devices.", session.user.username ); - let networks = WireguardNetwork::all(&mut *transaction).await?; - for network in networks { - let gateway_events = network.sync_allowed_devices(&mut transaction, None).await?; - appstate.send_multiple_wireguard_events(gateway_events); - } - info!("Allowed network devices of {username} synced"); + user.sync_allowed_devices(&mut transaction, &appstate.wireguard_tx) + .await?; }; user_info.into_user_all_fields(&mut user)?; } else { @@ -726,18 +718,9 @@ pub async fn delete_user( session.user.username ); let mut transaction = appstate.pool.begin().await?; - let devices = user.devices(&mut *transaction).await?; - let mut events = Vec::new(); - for device in devices { - events.push(GatewayEvent::DeviceDeleted( - DeviceInfo::from_device(&mut *transaction, device).await?, - )); - } - appstate.send_multiple_wireguard_events(events); - debug!("Devices of user {username} purged from networks."); + user.delete_and_cleanup(&mut transaction, &appstate.wireguard_tx) + .await?; - user.delete(&mut *transaction).await?; - let _result = ldap_delete_user(&username).await; appstate.trigger_action(AppEvent::UserDeleted(username.clone())); transaction.commit().await?; update_counts(&appstate.pool).await?; diff --git a/src/utility_thread.rs b/src/utility_thread.rs index c136c376f..b2094a388 100644 --- a/src/utility_thread.rs +++ b/src/utility_thread.rs @@ -1,9 +1,13 @@ use std::time::Duration; use sqlx::PgPool; -use tokio::time::{sleep, Instant}; +use tokio::{ + sync::broadcast::Sender, + time::{sleep, Instant}, +}; use crate::{ + db::GatewayEvent, enterprise::{ directory_sync::{do_directory_sync, get_directory_sync_interval}, limits::do_count_update, @@ -15,13 +19,16 @@ const UTILITY_THREAD_MAIN_SLEEP_TIME: u64 = 5; const COUNT_UPDATE_INTERVAL: u64 = 60 * 60; const UPDATES_CHECK_INTERVAL: u64 = 60 * 60 * 6; -pub async fn run_utility_thread(pool: &PgPool) -> Result<(), anyhow::Error> { +pub async fn run_utility_thread( + pool: &PgPool, + wireguard_tx: Sender, +) -> Result<(), anyhow::Error> { let mut last_count_update = Instant::now(); let mut last_directory_sync = Instant::now(); let mut last_updates_check = Instant::now(); let directory_sync_task = || async { - if let Err(e) = do_directory_sync(pool).await { + if let Err(e) = do_directory_sync(pool, &wireguard_tx).await { error!("There was an error while performing directory sync job: {e:?}",); } }; diff --git a/tests/auth.rs b/tests/auth.rs index d31eea478..16f127169 100644 --- a/tests/auth.rs +++ b/tests/auth.rs @@ -1,6 +1,6 @@ pub mod common; -use std::{str::FromStr, time::SystemTime}; +use std::time::SystemTime; use chrono::NaiveDateTime; use claims::{assert_err, assert_ok}; @@ -11,7 +11,6 @@ use defguard::{ models::settings::update_current_settings, MFAInfo, MFAMethod, Settings, User, UserDetails, }, handlers::{Auth, AuthCode, AuthResponse, AuthTotp}, - secret::SecretStringWrapper, }; use reqwest::{header::USER_AGENT, StatusCode}; use serde::Deserialize; @@ -301,8 +300,6 @@ async fn test_email_mfa() { let mut settings = Settings::get_current_settings(); settings.smtp_server = Some("smtp_server".into()); settings.smtp_port = Some(587); - settings.smtp_user = Some("dummy_user".into()); - settings.smtp_password = Some(SecretStringWrapper::from_str("dummy_password").unwrap()); settings.smtp_sender = Some("smtp@sender.pl".into()); update_current_settings(&pool, settings).await.unwrap(); diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 5fed197a6..ce030f9db 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -23,10 +23,12 @@ use reqwest::{header::HeaderName, StatusCode, Url}; use secrecy::ExposeSecret; use serde_json::{json, Value}; use sqlx::{postgres::PgConnectOptions, query, types::Uuid, PgPool}; -use tokio::net::TcpListener; -use tokio::sync::{ - broadcast::{self, Receiver}, - mpsc::{unbounded_channel, UnboundedReceiver}, +use tokio::{ + net::TcpListener, + sync::{ + broadcast::{self, Receiver}, + mpsc::{unbounded_channel, UnboundedReceiver}, + }, }; use self::client::TestClient; diff --git a/tests/openid_login.rs b/tests/openid_login.rs index 141d13c7c..9cc017be1 100644 --- a/tests/openid_login.rs +++ b/tests/openid_login.rs @@ -55,6 +55,8 @@ async fn test_openid_providers() { directory_sync_admin_behavior: DirectorySyncUserBehavior::Keep.to_string(), directory_sync_target: DirectorySyncTarget::All.to_string(), create_account: false, + okta_dirsync_client_id: None, + okta_private_jwk: None, }; let response = client diff --git a/web/src/i18n/en/index.ts b/web/src/i18n/en/index.ts index 6f126b121..c5e16e605 100644 --- a/web/src/i18n/en/index.ts +++ b/web/src/i18n/en/index.ts @@ -1182,6 +1182,15 @@ Licensing information: [https://docs.defguard.net/enterprise/license](https://do uploaded: 'File uploaded', uploadPrompt: 'Upload a service account key file', }, + okta_client_id: { + label: 'Directory Sync Client ID', + helper: 'Client ID for the Okta directory sync application.', + }, + okta_client_key: { + label: 'Directory Sync Client Private Key', + helper: + "Client private key for the Okta directory sync application in the JWK format. It won't be shown again here.", + }, }, }, }, diff --git a/web/src/i18n/i18n-types.ts b/web/src/i18n/i18n-types.ts index 1a591f804..4c1e19f61 100644 --- a/web/src/i18n/i18n-types.ts +++ b/web/src/i18n/i18n-types.ts @@ -2872,6 +2872,26 @@ type RootTranslation = { */ uploadPrompt: string } + okta_client_id: { + /** + * D​i​r​e​c​t​o​r​y​ ​S​y​n​c​ ​C​l​i​e​n​t​ ​I​D + */ + label: string + /** + * C​l​i​e​n​t​ ​I​D​ ​f​o​r​ ​t​h​e​ ​O​k​t​a​ ​d​i​r​e​c​t​o​r​y​ ​s​y​n​c​ ​a​p​p​l​i​c​a​t​i​o​n​. + */ + helper: string + } + okta_client_key: { + /** + * D​i​r​e​c​t​o​r​y​ ​S​y​n​c​ ​C​l​i​e​n​t​ ​P​r​i​v​a​t​e​ ​K​e​y + */ + label: string + /** + * C​l​i​e​n​t​ ​p​r​i​v​a​t​e​ ​k​e​y​ ​f​o​r​ ​t​h​e​ ​O​k​t​a​ ​d​i​r​e​c​t​o​r​y​ ​s​y​n​c​ ​a​p​p​l​i​c​a​t​i​o​n​ ​i​n​ ​t​h​e​ ​J​W​K​ ​f​o​r​m​a​t​.​ ​I​t​ ​w​o​n​'​t​ ​b​e​ ​s​h​o​w​n​ ​a​g​a​i​n​ ​h​e​r​e​. + */ + helper: string + } } } } @@ -7675,6 +7695,26 @@ export type TranslationFunctions = { */ uploadPrompt: () => LocalizedString } + okta_client_id: { + /** + * Directory Sync Client ID + */ + label: () => LocalizedString + /** + * Client ID for the Okta directory sync application. + */ + helper: () => LocalizedString + } + okta_client_key: { + /** + * Directory Sync Client Private Key + */ + label: () => LocalizedString + /** + * Client private key for the Okta directory sync application in the JWK format. It won't be shown again here. + */ + helper: () => LocalizedString + } } } } diff --git a/web/src/i18n/pl/index.ts b/web/src/i18n/pl/index.ts index 3f3106702..21b2a62f1 100644 --- a/web/src/i18n/pl/index.ts +++ b/web/src/i18n/pl/index.ts @@ -1147,6 +1147,15 @@ Uwaga, podane tutaj konfiguracje nie posiadają klucza prywatnego. Musisz uzupe uploaded: 'Przesłany plik', uploadPrompt: 'Prześlij plik klucza konta usługi', }, + okta_client_id: { + label: 'ID klienta synchronizacji Okta', + helper: 'ID klienta dla aplikacji synchronizacji Okta.', + }, + okta_client_key: { + label: 'Klucz prywatny klienta synchronizacji Okta', + helper: + 'Klucz prywatny dla aplikacji synchronizacji Okta w formacie JWK. Klucz nie jest wyświetlany ponownie po wgraniu.', + }, }, }, }, diff --git a/web/src/pages/settings/components/OpenIdSettings/components/DirectorySyncSettings.tsx b/web/src/pages/settings/components/OpenIdSettings/components/DirectorySyncSettings.tsx index a4a79b36a..063941fa5 100644 --- a/web/src/pages/settings/components/OpenIdSettings/components/DirectorySyncSettings.tsx +++ b/web/src/pages/settings/components/OpenIdSettings/components/DirectorySyncSettings.tsx @@ -17,7 +17,7 @@ import useApi from '../../../../../shared/hooks/useApi'; import { useToaster } from '../../../../../shared/hooks/useToaster'; import { titleCase } from '../../../../../shared/utils/titleCase'; -const SUPPORTED_SYNC_PROVIDERS = ['Google', 'Microsoft']; +const SUPPORTED_SYNC_PROVIDERS = ['Google', 'Microsoft', 'Okta']; export const DirsyncSettings = ({ isLoading }: { isLoading: boolean }) => { const { LL } = useI18nContext(); @@ -150,6 +150,28 @@ export const DirsyncSettings = ({ isLoading }: { isLoading: boolean }) => { } disabled={isLoading} /> + {providerName === 'Okta' ? ( + <> + {parse(localLL.form.labels.okta_client_id.helper())} + } + required={dirsyncEnabled} + /> + {parse(localLL.form.labels.okta_client_key.helper())} + } + required={dirsyncEnabled} + /> + + ) : null} {providerName === 'Google' ? ( <> { label: 'Microsoft', key: 2, }, + { + value: 'Okta', + label: 'Okta', + key: 3, + }, { value: 'Custom', label: localLL.form.custom(), - key: 3, + key: 4, }, ], [localLL.form], @@ -65,6 +70,8 @@ export const OpenIdSettingsForm = ({ isLoading }: { isLoading: boolean }) => { return 'https://accounts.google.com'; case 'Microsoft': return `https://login.microsoftonline.com//v2.0`; + case 'Okta': + return ``; default: return null; } @@ -77,6 +84,8 @@ export const OpenIdSettingsForm = ({ isLoading }: { isLoading: boolean }) => { return 'Google'; case 'Microsoft': return 'Microsoft'; + case 'Okta': + return 'Okta'; default: return null; } diff --git a/web/src/pages/settings/components/OpenIdSettings/components/OpenIdSettingsRootForm.tsx b/web/src/pages/settings/components/OpenIdSettings/components/OpenIdSettingsRootForm.tsx index 394e05e27..b9998c232 100644 --- a/web/src/pages/settings/components/OpenIdSettings/components/OpenIdSettingsRootForm.tsx +++ b/web/src/pages/settings/components/OpenIdSettings/components/OpenIdSettingsRootForm.tsx @@ -96,6 +96,8 @@ export const OpenIdSettingsRootForm = () => { directory_sync_admin_behavior: z.string(), directory_sync_target: z.string(), create_account: z.boolean(), + okta_private_jwk: z.string(), + okta_dirsync_client_id: z.string(), }) .superRefine((val, ctx) => { if (val.name === '') { @@ -106,6 +108,16 @@ export const OpenIdSettingsRootForm = () => { }); } + if (val.directory_sync_enabled && val.base_url.includes('okta')) { + if (val.okta_dirsync_client_id.length === 0) { + ctx.addIssue({ + code: z.ZodIssueCode.custom, + message: LL.form.error.required(), + path: ['okta_dirsync_client_id'], + }); + } + } + if (val.directory_sync_enabled && val.name === 'Google') { if (val.admin_email.length === 0) { ctx.addIssue({ @@ -144,6 +156,8 @@ export const OpenIdSettingsRootForm = () => { directory_sync_admin_behavior: 'keep', directory_sync_target: 'all', create_account: false, + okta_private_jwk: '', + okta_dirsync_client_id: '', }; if (openidData) { diff --git a/web/src/shared/types.ts b/web/src/shared/types.ts index 0201dd664..131d51ac8 100644 --- a/web/src/shared/types.ts +++ b/web/src/shared/types.ts @@ -909,6 +909,8 @@ export interface OpenIdProvider { directory_sync_user_behavior: 'keep' | 'disable' | 'delete'; directory_sync_admin_behavior: 'keep' | 'disable' | 'delete'; directory_sync_target: 'all' | 'users' | 'groups'; + okta_private_jwk?: string; + okta_dirsync_client_id?: string; } export interface EditOpenidClientRequest {