Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix auto head response feature (#1835). #1838

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions ktor-http/api/ktor-http.api
Original file line number Diff line number Diff line change
Expand Up @@ -579,13 +579,17 @@ public final class io/ktor/http/HttpMessagePropertiesKt {

public final class io/ktor/http/HttpMethod {
public static final field Companion Lio/ktor/http/HttpMethod$Companion;
public fun <init> (Ljava/lang/String;)V
public fun <init> (Ljava/lang/String;Ljava/util/List;)V
public synthetic fun <init> (Ljava/lang/String;Ljava/util/List;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public final fun component1 ()Ljava/lang/String;
public final fun copy (Ljava/lang/String;)Lio/ktor/http/HttpMethod;
public static synthetic fun copy$default (Lio/ktor/http/HttpMethod;Ljava/lang/String;ILjava/lang/Object;)Lio/ktor/http/HttpMethod;
public final fun component2 ()Ljava/util/List;
public final fun copy (Ljava/lang/String;Ljava/util/List;)Lio/ktor/http/HttpMethod;
public static synthetic fun copy$default (Lio/ktor/http/HttpMethod;Ljava/lang/String;Ljava/util/List;ILjava/lang/Object;)Lio/ktor/http/HttpMethod;
public fun equals (Ljava/lang/Object;)Z
public final fun getAggregate ()Ljava/util/List;
public final fun getValue ()Ljava/lang/String;
public fun hashCode ()I
public final fun match (Lio/ktor/http/HttpMethod;)Z
public fun toString ()Ljava/lang/String;
}

Expand Down
16 changes: 15 additions & 1 deletion ktor-http/common/src/io/ktor/http/HttpMethod.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,25 @@

package io.ktor.http

import io.ktor.util.*

/**
* Represents an HTTP method (verb)
* @property value contains method name
*/
data class HttpMethod(val value: String) {
data class HttpMethod(val value: String, @InternalAPI val aggregate: List<HttpMethod> = listOf()) {
/**
* Checks if the specified HTTP [method] matches this instance of the HTTP method. Specified method matches if it's
* equal to this HTTP method or at least one of methods this HTTP method aggregates.
*/
fun match(method: HttpMethod): Boolean {
if (this == method) {
return true
}

return aggregate.contains(method) || method.aggregate.contains(this)
}

@Suppress("KDocMissingDocumentation", "PublicApiImplicitType")
companion object {
val Get = HttpMethod("GET")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ object AutoHeadResponse : ApplicationFeature<ApplicationCallPipeline, Unit, Unit

// Pretend the request was with GET method so that all normal routes and interceptors work
// but in the end we will drop the content
call.mutableOriginConnectionPoint.method = HttpMethod.Get
call.mutableOriginConnectionPoint.method = HttpMethod(
"GET_OR_HEAD",
listOf(HttpMethod.Get, HttpMethod.Head)
)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ data class AndRouteSelector(val first: RouteSelector, val second: RouteSelector)
*/
data class HttpMethodRouteSelector(val method: HttpMethod) : RouteSelector(RouteSelectorEvaluation.qualityParameter) {
override fun evaluate(context: RoutingResolveContext, segmentIndex: Int): RouteSelectorEvaluation {
if (context.call.request.httpMethod == method)
if (context.call.request.httpMethod.match(method))
return RouteSelectorEvaluation.Constant
return RouteSelectorEvaluation.Failed
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ class HeadTest {
call.response.header("M", "1")
call.respond("Hello")
}

head("/head") {
call.respond(HttpStatusCode.OK)
}
}

handleRequest(HttpMethod.Get, "/").let { call ->
Expand All @@ -39,6 +43,11 @@ class HeadTest {
assertNull(call.response.content)
assertEquals("1", call.response.headers["M"])
}

handleRequest(HttpMethod.Head, "/head").let { call ->
assertEquals(HttpStatusCode.OK, call.response.status())
assertNull(call.response.content)
}
}
}

Expand Down