Skip to content

Commit

Permalink
Vermadhr/keyless access work (#23554)
Browse files Browse the repository at this point in the history
## Description

This PR adds a fix for a bug introduced in
#23407.
#23407 removed internal
`getKey` API calls. However, the alternative to sign tokens by Riddler
was not working for cases where we spoof user tokens (Scribe -> Alfred,
Scribe -> Historian). This is because the token after expiry was never
refreshed.

This PR fixes this by adding a new callback to the `BasicRestWrapper` -
`refreshTokenIfNeeded`. This callback can handle refreshing auth tokens
if provided.

## Breaking Changes

This change should not have breaking changes.

## Reviewer Guidance

1) Unit testing
2) Code changes that refresh auth tokens in the new callback.

---------

Co-authored-by: Alex Villarreal <716334+alexvy86@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
  • Loading branch information
3 people authored Jan 15, 2025
1 parent fb47218 commit b947c9c
Show file tree
Hide file tree
Showing 12 changed files with 302 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,11 @@ export class TenantManager {
const lumberProperties = {
[BaseTelemetryProperties.tenantId]: tenantId,
includeDisabledTenant,
documentId,
scopes,
lifetime,
ver,
jti,
};
const tenantDocument = await this.getTenantDocument(tenantId, includeDisabledTenant);
if (tenantDocument === undefined) {
Expand Down Expand Up @@ -200,6 +205,10 @@ export class TenantManager {
).key1
: keys.key1;

Lumberjack.info("Signing token with key1", {
...lumberProperties,
isTenantPrivateKeyAccessEnabled,
});
const token = generateToken(
tenantId,
documentId,
Expand All @@ -211,6 +220,10 @@ export class TenantManager {
jti,
isTenantPrivateKeyAccessEnabled,
);
Lumberjack.info("Token signed with key1", {
...lumberProperties,
isTenantPrivateKeyAccessEnabled,
});

return {
fluidAccessToken: token,
Expand All @@ -236,11 +249,12 @@ export class TenantManager {
const lumberProperties = {
[BaseTelemetryProperties.tenantId]: tenantId,
includeDisabledTenant,
isKeylessAccessValidation,
};

// Try validating with Key 1
try {
await this.validateTokenWithKey(tenantKeys.key1, KeyName.key1, token);
await this.validateTokenWithKey(tenantKeys.key1, KeyName.key1, token, lumberProperties);
return;
} catch (error) {
if (isNetworkError(error)) {
Expand Down Expand Up @@ -272,7 +286,7 @@ export class TenantManager {
}
// If Key 1 validation fails, try with Key 2
try {
await this.validateTokenWithKey(tenantKeys.key2, KeyName.key2, token);
await this.validateTokenWithKey(tenantKeys.key2, KeyName.key2, token, lumberProperties);
} catch (error) {
if (isNetworkError(error)) {
if (error.code === 403) {
Expand Down Expand Up @@ -313,6 +327,7 @@ export class TenantManager {
key: string,
keyName: string,
token: string,
lumberjackProperties: any = {},
): Promise<boolean> {
return new Promise<boolean>((resolve, reject) => {
jwt.verify(token, key, (error) => {
Expand All @@ -324,8 +339,16 @@ export class TenantManager {
// When `exp` claim exists in token claims, jsonwebtoken verifies token expiration.

if (error instanceof jwt.TokenExpiredError) {
Lumberjack.error(
`Token expired validated with ${keyName}.`,
lumberjackProperties,
);
reject(new NetworkError(401, `Token expired validated with ${keyName}.`));
} else {
Lumberjack.error(
`Invalid token validated with ${keyName}.`,
lumberjackProperties,
);
reject(new NetworkError(403, `Invalid token validated with ${keyName}.`));
}
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import { SummaryObject } from '@fluidframework/protocol-definitions';

// @internal (undocumented)
export class BasicRestWrapper extends RestWrapper {
constructor(baseurl?: string, defaultQueryString?: Record<string, string | number | boolean>, maxBodyLength?: number, maxContentLength?: number, defaultHeaders?: RawAxiosRequestHeaders, axios?: AxiosInstance, refreshDefaultQueryString?: (() => Record<string, string | number | boolean>) | undefined, refreshDefaultHeaders?: (() => RawAxiosRequestHeaders) | undefined, getCorrelationId?: (() => string | undefined) | undefined, getTelemetryContextProperties?: (() => Record<string, string | number | boolean> | undefined) | undefined);
constructor(baseurl?: string, defaultQueryString?: Record<string, string | number | boolean>, maxBodyLength?: number, maxContentLength?: number, defaultHeaders?: RawAxiosRequestHeaders, axios?: AxiosInstance, refreshDefaultQueryString?: (() => Record<string, string | number | boolean>) | undefined, refreshDefaultHeaders?: (() => RawAxiosRequestHeaders) | undefined, getCorrelationId?: (() => string | undefined) | undefined, getTelemetryContextProperties?: (() => Record<string, string | number | boolean> | undefined) | undefined, refreshTokenIfNeeded?: ((authorizationHeader: RawAxiosRequestHeaders) => Promise<RawAxiosRequestHeaders | undefined>) | undefined);
// (undocumented)
protected request<T>(requestConfig: AxiosRequestConfig, statusCode: number, canRetry?: boolean): Promise<T>;
}
Expand Down Expand Up @@ -595,6 +595,9 @@ export class NetworkError extends Error {
};
}

// @internal (undocumented)
export function parseToken(tenantId: string, authorization: string | undefined): string | undefined;

// @internal (undocumented)
export function promiseTimeout(mSec: number, promise: Promise<any>): Promise<any>;

Expand Down
30 changes: 30 additions & 0 deletions server/routerlicious/packages/services-client/src/historian.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import * as git from "@fluidframework/gitresources";
import { RestWrapper, BasicRestWrapper } from "./restWrapper";
import { IHistorian } from "./storage";
import { IWholeFlatSummary, IWholeSummaryPayload, IWriteSummaryResponse } from "./storageContracts";
import { NetworkError } from "./error";
import { debug } from "./debug";

function endsWith(value: string, endings: string[]): boolean {
for (const ending of endings) {
Expand All @@ -33,6 +35,34 @@ export interface ICredentials {
export const getAuthorizationTokenFromCredentials = (credentials: ICredentials): string =>
`Basic ${fromUtf8ToBase64(`${credentials.user}:${credentials.password}`)}`;

/**
* @internal
*/
export function parseToken(
tenantId: string,
authorization: string | undefined,
): string | undefined {
let token: string | undefined;
if (authorization) {
const base64TokenMatch = authorization.match(/Basic (.+)/);
if (!base64TokenMatch) {
debug("Invalid base64 token", { tenantId });
throw new NetworkError(403, "Malformed authorization token");
}
const encoded = Buffer.from(base64TokenMatch[1], "base64").toString();

const tokenMatch = encoded.match(/(.+):(.+)/);
if (!tokenMatch || tenantId !== tokenMatch[1]) {
debug("Tenant mismatch or invalid token format", { tenantId });
throw new NetworkError(403, "Malformed authorization token");
}

token = tokenMatch[2];
}

return token;
}

/**
* Implementation of the IHistorian interface that calls out to a REST interface
* @internal
Expand Down
7 changes: 6 additions & 1 deletion server/routerlicious/packages/services-client/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@ export {
export { choose, getRandomName } from "./generateNames";
export { GitManager } from "./gitManager";
export { Heap, IHeapComparator } from "./heap";
export { getAuthorizationTokenFromCredentials, Historian, ICredentials } from "./historian";
export {
getAuthorizationTokenFromCredentials,
Historian,
ICredentials,
parseToken,
} from "./historian";
export { IAlfredTenant, ISession } from "./interfaces";
export { promiseTimeout } from "./promiseTimeout";
export { RestLessClient, RestLessFieldNames } from "./restLessClient";
Expand Down
18 changes: 18 additions & 0 deletions server/routerlicious/packages/services-client/src/restWrapper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ export class BasicRestWrapper extends RestWrapper {
private readonly getTelemetryContextProperties?: () =>
| Record<string, string | number | boolean>
| undefined,
private readonly refreshTokenIfNeeded?: (
authorizationHeader: RawAxiosRequestHeaders,
) => Promise<RawAxiosRequestHeaders | undefined>,
) {
super(baseurl, defaultQueryString, maxBodyLength, maxContentLength);
}
Expand All @@ -183,6 +186,21 @@ export class BasicRestWrapper extends RestWrapper {
this.getTelemetryContextProperties?.(),
);

// If the request has an Authorization header and a refresh token function is provided, try to refresh the token if needed
if (options.headers?.Authorization && this.refreshTokenIfNeeded) {
const refreshedToken = await this.refreshTokenIfNeeded(options.headers).catch(
(error) => {
debug(`request to ${options.url} failed ${error ? error.message : ""}`);
throw error;
},
);
if (refreshedToken) {
options.headers.Authorization = refreshedToken.Authorization;
// Update the default headers to use the refreshed token
this.defaultHeaders.Authorization = refreshedToken.Authorization;
}
}

return new Promise<T>((resolve, reject) => {
this.axios
.request<T>(options)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import { AxiosError, AxiosInstance, AxiosRequestConfig, AxiosResponse } from "ax
import AxiosMockAdapter from "axios-mock-adapter";
import { CorrelationIdHeaderName } from "../constants";
import { BasicRestWrapper } from "../restWrapper";
import { KJUR as jsrsasign } from "jsrsasign";
import { jwtDecode } from "jwt-decode";

describe("BasicRestWrapper", () => {
const baseurl = "https://fake.microsoft.com";
Expand Down Expand Up @@ -893,4 +895,61 @@ describe("BasicRestWrapper", () => {
);
});
});

describe("Token refresh", () => {
it("Token should be refreshed if callback is provided", async () => {
const key = "1234";
const expiredToken = jsrsasign.jws.JWS.sign(
null,
JSON.stringify({ alg: "HS256", typ: "JWT" }),
{ exp: Math.round(new Date().getTime() / 1000) - 100 },
key,
);
const getDefaultHeaders = () => {
return {
Authorization: `Basic ${expiredToken}`,
};
};
const newToken = jsrsasign.jws.JWS.sign(
null,
JSON.stringify({ alg: "HS256", typ: "JWT" }),
{ exp: Math.round(new Date().getTime() / 1000) + 10000 },
key,
);

const refreshTokenIfNeeded = async () => {
const tokenClaims = jwtDecode(expiredToken);
if (tokenClaims.exp < new Date().getTime() / 1000) {
return {
Authorization: `Basic ${newToken}`,
};
} else {
return undefined;
}
};

const rw = new BasicRestWrapper(
baseurl,
{},
maxBodyLength,
maxContentLength,
getDefaultHeaders(),
axiosMock as AxiosInstance,
undefined,
undefined,
undefined,
undefined,
refreshTokenIfNeeded,
);

//act
await rw.get(requestUrl).then(
// tslint:disable-next-line:no-void-expression
() => assert.ok(true),
);

assert.notEqual(rw["defaultHeaders"].Authorization, `Basic ${expiredToken}`);
assert.strictEqual(rw["defaultHeaders"].Authorization, `Basic ${newToken}`);
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"rootDir": "./src",
"outDir": "./dist",
"composite": true,
"types": ["node"],
},
"include": ["src/**/*"],
}
43 changes: 39 additions & 4 deletions server/routerlicious/packages/services-utils/src/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@ import {
DocDeleteScopeType,
getGlobalTimeoutContext,
} from "@fluidframework/server-services-client";
import type {
ICache,
IRevokedTokenChecker,
ITenantManager,
import {
requestWithRetry,
type ICache,
type IRevokedTokenChecker,
type ITenantManager,
} from "@fluidframework/server-services-core";
import type { RequestHandler, Request, Response } from "express";
import type { Provider } from "nconf";
Expand Down Expand Up @@ -180,6 +181,10 @@ function getTokenFromRequest(request: Request): string {
if (!authorizationHeader) {
throw new NetworkError(403, "Missing Authorization header.");
}
return extractTokenFromHeader(authorizationHeader);
}

export function extractTokenFromHeader(authorizationHeader: string): string {
const tokenRegex = /Basic (.+)/;
const tokenMatch = tokenRegex.exec(authorizationHeader);
if (!tokenMatch?.[1]) {
Expand All @@ -188,6 +193,36 @@ function getTokenFromRequest(request: Request): string {
return tokenMatch[1];
}

// Returns true if the token is valid for at least 5 minutes.
export function isTokenValid(token: string): boolean {
const tokenClaims = decode(token) as ITokenClaims;
const lifeTimeMSec = tokenClaims.exp * 1000 - new Date().getTime();
return lifeTimeMSec > 5 * 60 * 1000; // 5 minutes
}

export async function getValidAccessToken(
currentAccessToken: string,
tenantManager: ITenantManager,
tenantId: string,
documentId: string,
scopes: ScopeType[],
lumberProperties: Record<string, any>,
): Promise<string | undefined> {
// If the current token is still valid, return undefined
if (isTokenValid(currentAccessToken)) {
Lumberjack.verbose(`Token is still valid`, lumberProperties);
return undefined;
}
Lumberjack.info(`Refreshing token`, lumberProperties);

const newToken = await requestWithRetry(
async () => tenantManager.signToken(tenantId, documentId, scopes),
`getValidAccessToken_signToken` /* callName */,
lumberProperties /* telemetryProperties */,
);
return newToken;
}

const defaultMaxTokenLifetimeSec = 60 * 60; // 1 hour

/**
Expand Down
3 changes: 3 additions & 0 deletions server/routerlicious/packages/services-utils/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ export {
verifyStorageToken,
validateTokenScopeClaims,
verifyToken,
isTokenValid,
extractTokenFromHeader,
getValidAccessToken,
} from "./auth";
export { getBooleanFromConfig, getNumberFromConfig } from "./configUtils";
export { parseBoolean } from "./conversion";
Expand Down
Loading

0 comments on commit b947c9c

Please sign in to comment.