Skip to content

Commit

Permalink
fix: tos_uri validation (#3945)
Browse files Browse the repository at this point in the history
Contributes to ory-corp/cloud#7395

---------

Co-authored-by: Arne Luenser <arne.luenser@ory.sh>
  • Loading branch information
hperl and alnr authored Feb 20, 2025
1 parent 6ae0552 commit 007e224
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 8 deletions.
9 changes: 6 additions & 3 deletions client/sdk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func createTestClient(prefix string) hydra.OAuth2Client {
Owner: pointerx.Ptr(prefix + "an-owner"),
PolicyUri: pointerx.Ptr(prefix + "policy-uri"),
Scope: pointerx.Ptr(prefix + "foo bar baz"),
TosUri: pointerx.Ptr(prefix + "tos-uri"),
TosUri: pointerx.Ptr("https://example.org/" + prefix + "tos"),
ResponseTypes: []string{prefix + "id_token", prefix + "code"},
RedirectUris: []string{"https://" + prefix + "redirect-url", "https://" + prefix + "redirect-uri"},
ClientSecretExpiresAt: pointerx.Ptr[int64](0),
Expand Down Expand Up @@ -95,12 +95,15 @@ func TestClientSDK(t *testing.T) {
// createClient.SecretExpiresAt = 10

// returned client is correct on Create
result, _, err := c.OAuth2API.CreateOAuth2Client(ctx).OAuth2Client(createClient).Execute()
require.NoError(t, err)
result, res, err := c.OAuth2API.CreateOAuth2Client(ctx).OAuth2Client(createClient).Execute()
if !assert.NoError(t, err) {
t.Fatalf("error: %s", ioutilx.MustReadAll(res.Body))
}
assert.NotEmpty(t, result.UpdatedAt)
assert.NotEmpty(t, result.CreatedAt)
assert.NotEmpty(t, result.RegistrationAccessToken)
assert.NotEmpty(t, result.RegistrationClientUri)
assert.NotEmpty(t, *result.TosUri)
assert.NotEmpty(t, result.ClientId)
createClient.ClientId = result.ClientId

Expand Down
12 changes: 12 additions & 0 deletions client/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,18 @@ func (v *Validator) Validate(ctx context.Context, c *Client) error {
}
}

if c.TermsOfServiceURI != "" {
u, err := url.ParseRequestURI(c.TermsOfServiceURI)
if err != nil {
return errorsx.WithStack(ErrInvalidClientMetadata.WithHint("Field tos_uri must be a valid URI."))
}

if u.Scheme != "https" && u.Scheme != "http" {
return errorsx.WithStack(ErrInvalidClientMetadata.WithHintf("tos_uri %s must use https:// or http:// as HTTP scheme.", c.TermsOfServiceURI))
}

}

if len(c.Secret) > 0 && len(c.Secret) < 6 {
return errorsx.WithStack(ErrInvalidClientMetadata.WithHint("Field client_secret must contain a secret that is at least 6 characters long."))
}
Expand Down
16 changes: 11 additions & 5 deletions client/validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ func TestValidate(t *testing.T) {
reg := testhelpers.NewRegistryMemory(t, c, &contextx.Static{C: c.Source(ctx)})
v := NewValidator(reg)

testCtx := context.TODO()

dec := json.NewDecoder(strings.NewReader(validJWKS))
dec.DisallowUnknownFields()
var goodJWKS jose.JSONWebKeySet
Expand Down Expand Up @@ -130,6 +128,14 @@ func TestValidate(t *testing.T) {
assert.Equal(t, []string{"https://foo/"}, []string(c.PostLogoutRedirectURIs))
},
},
{
in: &Client{ID: "foo", TermsOfServiceURI: "https://example.org"},
assertErr: assert.NoError,
},
{
in: &Client{ID: "foo", TermsOfServiceURI: "javascript:alert('XSS')"},
assertErr: assert.Error,
},
{
in: &Client{ID: "foo"},
check: func(t *testing.T, c *Client) {
Expand Down Expand Up @@ -164,7 +170,7 @@ func TestValidate(t *testing.T) {
return v
}
}
err := tc.v(t).Validate(testCtx, tc.in)
err := tc.v(t).Validate(ctx, tc.in)
if tc.assertErr != nil {
tc.assertErr(t, err)
} else {
Expand All @@ -180,7 +186,7 @@ type fakeHTTP struct {
c *http.Client
}

func (f *fakeHTTP) HTTPClient(ctx context.Context, opts ...httpx.ResilientOptions) *retryablehttp.Client {
func (f *fakeHTTP) HTTPClient(_ context.Context, opts ...httpx.ResilientOptions) *retryablehttp.Client {
c := httpx.NewResilientClient(opts...)
c.HTTPClient = f.c
return c
Expand All @@ -191,7 +197,7 @@ func TestValidateSectorIdentifierURL(t *testing.T) {
var payload string

var h http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(payload))
_, _ = w.Write([]byte(payload))
}
ts := httptest.NewTLSServer(h)
defer ts.Close()
Expand Down

0 comments on commit 007e224

Please sign in to comment.