Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ val authorName = "ccbluex"
val projectUrl = "https://github.com/ccbluex/netty-httpserver"

group = "net.ccbluex"
version = "2.4.2"
version = "2.4.3-alpha.1"

repositories {
mavenCentral()
Expand All @@ -31,8 +31,10 @@ dependencies {
api(libs.bundles.netty)
api(libs.gson)
api(libs.tika.core)
api(libs.coroutines.core)

testImplementation(kotlin("test"))
testImplementation(libs.coroutines.test)
testImplementation("com.squareup.retrofit2:retrofit:2.9.0")
testImplementation("com.squareup.retrofit2:converter-gson:2.9.0")
}
Expand Down
2 changes: 1 addition & 1 deletion src/main/kotlin/net/ccbluex/netty/http/HttpConductor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import net.ccbluex.netty.http.util.httpNoContent
* @param context The request context to process.
* @return The response to the request.
*/
internal fun HttpServer.processRequestContext(context: RequestContext) = runCatching {
internal suspend fun HttpServer.processRequestContext(context: RequestContext) = runCatching {
val content = context.contentBuffer.toByteArray()
val method = context.httpMethod

Expand Down
54 changes: 44 additions & 10 deletions src/main/kotlin/net/ccbluex/netty/http/HttpServerHandler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,16 @@ import io.netty.handler.codec.http.HttpHeaderNames
import io.netty.handler.codec.http.HttpRequest
import io.netty.handler.codec.http.LastHttpContent
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory
import kotlinx.coroutines.CoroutineExceptionHandler
import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.cancel
import kotlinx.coroutines.launch
import net.ccbluex.netty.http.HttpServer.Companion.logger
import net.ccbluex.netty.http.middleware.Middleware
import net.ccbluex.netty.http.model.RequestContext
import net.ccbluex.netty.http.util.forEachIsInstance
import net.ccbluex.netty.http.websocket.WebSocketHandler
import java.net.URLDecoder

Expand All @@ -40,13 +47,37 @@ import java.net.URLDecoder
internal class HttpServerHandler(private val server: HttpServer) : ChannelInboundHandlerAdapter() {

private val localRequestContext = ThreadLocal<RequestContext>()
private lateinit var channelScope: CoroutineScope

/**
* Extension property to get the WebSocket URL from an HttpRequest.
*/
private val HttpRequest.webSocketUrl: String
get() = "ws://${headers().get("Host")}${uri()}"

/**
* Adds the [CoroutineScope] of current [io.netty.channel.Channel].
*/
override fun handlerAdded(ctx: ChannelHandlerContext) {
super.handlerAdded(ctx)

val exceptionHandler = CoroutineExceptionHandler { _, throwable ->
val ctxName = ctx.name()
val channelId = ctx.channel().id().asLongText()
logger.error(
"Uncaught coroutine error in [ctx: $ctxName, channel: $channelId]",
throwable
)
}

channelScope = CoroutineScope(
ctx.channel().eventLoop().asCoroutineDispatcher()
+ CoroutineName("${ctx.name()}#${ctx.channel().id().asShortText()}")
+ exceptionHandler
)
ctx.channel().closeFuture().addListener { channelScope.cancel() }
}

/**
* Reads the incoming messages and processes HTTP requests.
*
Expand All @@ -68,11 +99,11 @@ internal class HttpServerHandler(private val server: HttpServer) : ChannelInboun
if (connection.equals("Upgrade", ignoreCase = true) &&
upgrade.equals("WebSocket", ignoreCase = true)) {

server.middlewares.filterIsInstance<Middleware.OnWebSocketUpgrade>().forEach { middleware ->
server.middlewares.forEachIsInstance<Middleware.OnWebSocketUpgrade> { middleware ->
val response = middleware.invoke(ctx, msg)
if (response != null) {
ctx.writeAndFlush(response)
return
return super.channelRead(ctx, msg)
}
}

Expand All @@ -99,15 +130,15 @@ internal class HttpServerHandler(private val server: HttpServer) : ChannelInboun
URLDecoder.decode(msg.uri(), Charsets.UTF_8),
msg.headers(),
)

localRequestContext.set(requestContext)
}
}

is HttpContent -> {
val requestContext = localRequestContext.get() ?: run {
logger.warn("Received HttpContent without HttpRequest")
return
return super.channelRead(ctx, msg)
}

// Append content to the buffer
Expand All @@ -117,18 +148,21 @@ internal class HttpServerHandler(private val server: HttpServer) : ChannelInboun
if (msg is LastHttpContent) {
localRequestContext.remove()

server.middlewares.filterIsInstance<Middleware.OnRequest>().forEach { middleware ->
server.middlewares.forEachIsInstance<Middleware.OnRequest> { middleware ->
val response = middleware.invoke(requestContext)
if (response != null) {
ctx.writeAndFlush(response)
return
return super.channelRead(ctx, msg)
}
}
var response = server.processRequestContext(requestContext)
server.middlewares.filterIsInstance<Middleware.OnResponse>().forEach { middleware ->
response = middleware.invoke(requestContext, response)

channelScope.launch {
var response = server.processRequestContext(requestContext)
server.middlewares.forEachIsInstance<Middleware.OnResponse> { middleware ->
response = middleware.invoke(requestContext, response)
}
ctx.writeAndFlush(response)
}
ctx.writeAndFlush(response)
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package net.ccbluex.netty.http.coroutines

import io.netty.util.concurrent.Future
import io.netty.util.concurrent.GenericFutureListener
import kotlinx.coroutines.CancellableContinuation
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.suspendCancellableCoroutine

/**
* Suspend until this Netty Future completes.
*
* Returns the Future result. Throws on failure or cancellation.
*/
suspend fun <V, F : Future<V>> F.suspend(): V {
if (isDone) return unwrapDone().getOrThrow()

return suspendCancellableCoroutine { cont ->
addListener(futureContinuationListener(cont))

cont.invokeOnCancellation {
this.cancel(false)
}
}
}

private fun <V, F : Future<V>> futureContinuationListener(
cont: CancellableContinuation<V>
): GenericFutureListener<F> = GenericFutureListener { future ->
if (cont.isActive) {
cont.resumeWith(future.unwrapDone())
}
}

private fun <V, F : Future<V>> F.unwrapDone(): Result<V> =
when {
isSuccess -> Result.success(this.now)
isCancelled -> Result.failure(CancellationException("Netty Future was cancelled"))
else -> Result.failure(
this.cause() ?: IllegalStateException("Future failed without cause")
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ package net.ccbluex.netty.http.model
import io.netty.handler.codec.http.FullHttpResponse

fun interface RequestHandler {
fun handle(request: RequestObject): FullHttpResponse
suspend fun handle(request: RequestObject): FullHttpResponse
}
2 changes: 1 addition & 1 deletion src/main/kotlin/net/ccbluex/netty/http/rest/FileServant.kt
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class FileServant(part: String, private val baseFolder: File) : Node(part) {

override val isExecutable = true

override fun handle(request: RequestObject): FullHttpResponse {
override suspend fun handle(request: RequestObject): FullHttpResponse {
val path = request.remainingPath
val sanitizedPath = path.replace("..", "")
val file = baseFolder.resolve(sanitizedPath)
Expand Down
2 changes: 1 addition & 1 deletion src/main/kotlin/net/ccbluex/netty/http/rest/Node.kt
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ open class Node(val part: String) : RequestHandler {
* @param request The request object.
* @return The HTTP response.
*/
override fun handle(request: RequestObject): FullHttpResponse = throw NotImplementedError()
override suspend fun handle(request: RequestObject): FullHttpResponse = throw NotImplementedError()

/**
* Checks if the node matches a part of the path and HTTP method.
Expand Down
2 changes: 1 addition & 1 deletion src/main/kotlin/net/ccbluex/netty/http/rest/Route.kt
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import net.ccbluex.netty.http.model.RequestObject
open class Route(name: String, private val method: HttpMethod, val handler: RequestHandler)
: Node(name) {
override val isExecutable = true
override fun handle(request: RequestObject) = handler.handle(request)
override suspend fun handle(request: RequestObject) = handler.handle(request)
override fun matchesMethod(method: HttpMethod) =
this.method == method && super.matchesMethod(method)

Expand Down
2 changes: 1 addition & 1 deletion src/main/kotlin/net/ccbluex/netty/http/rest/ZipServant.kt
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class ZipServant(part: String, zipInputStream: InputStream) : Node(part) {
return files
}

override fun handle(request: RequestObject): FullHttpResponse {
override suspend fun handle(request: RequestObject): FullHttpResponse {
val path = request.remainingPath.removePrefix("/")
val cleanPath = path.substringBefore("?")
val sanitizedPath = cleanPath.replace("..", "")
Expand Down
9 changes: 9 additions & 0 deletions src/main/kotlin/net/ccbluex/netty/http/util/IterableUtils.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package net.ccbluex.netty.http.util

inline fun <reified E> Iterable<*>.forEachIsInstance(action: (E) -> Unit) {
for (it in this) {
if (it is E) {
action(it)
}
}
}
31 changes: 16 additions & 15 deletions src/test/kotlin/ZipServantTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import io.netty.handler.codec.http.EmptyHttpHeaders
import io.netty.handler.codec.http.HttpHeaders
import io.netty.handler.codec.http.HttpMethod
import io.netty.handler.codec.http.HttpResponseStatus
import kotlinx.coroutines.test.runTest
import net.ccbluex.netty.http.model.RequestObject
import net.ccbluex.netty.http.rest.ZipServant
import org.junit.jupiter.api.Test
Expand Down Expand Up @@ -109,7 +110,7 @@ class ZipServantTest {
}

@Test
fun `should serve index html for root path`() {
fun `should serve index html for root path`() = runTest {
val zipData = createTestZip()
val zipServant = ZipServant("static", zipData.inputStream())

Expand All @@ -123,7 +124,7 @@ class ZipServantTest {
}

@Test
fun `should serve index html for slash path`() {
fun `should serve index html for slash path`() = runTest {
val zipData = createTestZip()
val zipServant = ZipServant("static", zipData.inputStream())

Expand All @@ -134,7 +135,7 @@ class ZipServantTest {
}

@Test
fun `should serve specific files with correct content types`() {
fun `should serve specific files with correct content types`() = runTest {
val zipData = createTestZip()
val zipServant = ZipServant("static", zipData.inputStream())

Expand Down Expand Up @@ -162,7 +163,7 @@ class ZipServantTest {
}

@Test
fun `should handle files with dot-slash prefix`() {
fun `should handle files with dot-slash prefix`() = runTest {
val zipData = createTestZip()
val zipServant = ZipServant("static", zipData.inputStream())

Expand All @@ -175,7 +176,7 @@ class ZipServantTest {
}

@Test
fun `should return 404 for non-existent files`() {
fun `should return 404 for non-existent files`() = runTest {
val zipData = createTestZip()
val zipServant = ZipServant("static", zipData.inputStream())

Expand All @@ -185,7 +186,7 @@ class ZipServantTest {
}

@Test
fun `should sanitize path traversal attempts`() {
fun `should sanitize path traversal attempts`() = runTest {
val zipData = createTestZip()
val zipServant = ZipServant("static", zipData.inputStream())

Expand All @@ -195,7 +196,7 @@ class ZipServantTest {
}

@Test
fun `should handle paths without leading slash`() {
fun `should handle paths without leading slash`() = runTest {
val zipData = createTestZip()
val zipServant = ZipServant("static", zipData.inputStream())

Expand All @@ -206,7 +207,7 @@ class ZipServantTest {
}

@Test
fun `should serve index html for SPA routes with hash fragments`() {
fun `should serve index html for SPA routes with hash fragments`() = runTest {
val zipData = createTestZip()
val zipServant = ZipServant("static", zipData.inputStream())

Expand All @@ -219,7 +220,7 @@ class ZipServantTest {
}

@Test
fun `should handle unknown file extensions with default content type`() {
fun `should handle unknown file extensions with default content type`() = runTest {
val baos = ByteArrayOutputStream()
ZipOutputStream(baos).use { zos ->
zos.putNextEntry(ZipEntry("test.unknown"))
Expand All @@ -235,7 +236,7 @@ class ZipServantTest {
}

@Test
fun `should handle various content types correctly`() {
fun `should handle various content types correctly`() = runTest {
val baos = ByteArrayOutputStream()
ZipOutputStream(baos).use { zos ->
// Test various file types - using Tika's expected content types
Expand Down Expand Up @@ -281,7 +282,7 @@ class ZipServantTest {
}

@Test
fun `should handle empty zip file gracefully`() {
fun `should handle empty zip file gracefully`() = runTest {
val baos = ByteArrayOutputStream()
ZipOutputStream(baos).use { /* empty zip */ }

Expand All @@ -292,7 +293,7 @@ class ZipServantTest {
}

@Test
fun `should serve index html for directory paths with trailing slash`() {
fun `should serve index html for directory paths with trailing slash`() = runTest {
val zipData = createTestZip()
val zipServant = ZipServant("static", zipData.inputStream())

Expand All @@ -312,7 +313,7 @@ class ZipServantTest {
}

@Test
fun `should serve index html for SPA routes with fragments`() {
fun `should serve index html for SPA routes with fragments`() = runTest {
val zipData = createTestZip()
val zipServant = ZipServant("static", zipData.inputStream())

Expand All @@ -339,7 +340,7 @@ class ZipServantTest {
}

@Test
fun `should serve index html for implicit directory access`() {
fun `should serve index html for implicit directory access`() = runTest {
val zipData = createTestZip()
val zipServant = ZipServant("static", zipData.inputStream())

Expand All @@ -359,7 +360,7 @@ class ZipServantTest {
}

@Test
fun `should return 404 for directory without index html`() {
fun `should return 404 for directory without index html`() = runTest {
val zipData = createTestZip()
val zipServant = ZipServant("static", zipData.inputStream())

Expand Down