Kotlinx.coroutines: Accessing throw-site MDC context in "global" exception handler

Created on 3 Nov 2020 · 10 comments · Source: Kotlin/kotlinx.coroutines

I am trying to enrich an application with some MDC context in order to make debugging issues easier.
My goal is to be able to log the MDC context from the point where an exception was thrown alongside the error.

In a non-coroutine world, that is easily possible by only popping the MDC entries when your block of code finishes successfully, but not on error. Your catch on the outermost level or a global exception handler would then still be able to "see" the MDC context from the throw site because it is left intact.
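A minimal sketch of that non-coroutine pattern. To keep it dependency-free, a plain ThreadLocal map stands in for slf4j's MDC; `fakeMdc`, `withMdcEntry`, and `handleRequest` are hypothetical names, not part of any library. Each level removes its entry only when its block returns normally, so the outermost catch still sees the full throw-site context:

```kotlin
// Hypothetical stand-in for slf4j's MDC: one mutable map per thread.
val fakeMdc: ThreadLocal<MutableMap<String, String>> =
    ThreadLocal.withInitial { mutableMapOf<String, String>() }

// Adds an entry for the duration of [block] and removes it only on success.
// If [block] throws, the removal is skipped and the entry stays visible.
fun <T> withMdcEntry(key: String, value: String, block: () -> T): T {
    fakeMdc.get()[key] = value
    val result = block()
    fakeMdc.get().remove(key)
    return result
}

// Simulates one request; returns the context seen by the top-level handler,
// or null when nothing failed.
fun handleRequest(fail: Boolean): Map<String, String>? =
    try {
        withMdcEntry("additional", "info") {
            withMdcEntry("nesting", "deeper") {
                check(!fail) { "Oh noes, a bug!" } // throws IllegalStateException
            }
        }
        null
    } catch (e: IllegalStateException) {
        // The "global" handler still sees both entries from the throw site;
        // it then clears the map so the next request starts clean.
        fakeMdc.get().toMap().also { fakeMdc.get().clear() }
    }
```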

In a coroutine world, we have to wrap our suspend functions with withContext(MDCContext()) {} in order to bridge the ThreadLocal gap. But that means that whenever the coroutine resumes, the MDCContext will be restored to what it was when the MDCContext() was instantiated.

I tried using CoroutineExceptionHandler and supervisorScope to achieve the desired behaviour because I thought the coroutineContext passed to the CoroutineExceptionHandler would refer to the coroutine that failed (so I could get the MDCContext element in there), but it seems to be the supervisor job's coroutineContext.

I _could_ wrap every body inside withContext(MDCContext()) with a try/catch, log the error there and then rethrow, but that would mean I log the exception N times for N levels of nested MDCContexts, which is unsatisfying.

Any help on how I could achieve this will be greatly appreciated; I am out of ideas.

Please take a look at the following sample code to get an idea of what I am trying to achieve:

import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.slf4j.MDCContext
import kotlinx.coroutines.withContext
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import org.slf4j.MDC

private val logger: Logger = LoggerFactory.getLogger("test")

suspend fun deepInSomeOtherCode() {
    MDC.put("nesting", "deeper")
    withContext(MDCContext()) {
        logger.info("deep in the code")
        throw Exception("Oh noes, a bug!")
    }
}

fun main() {
    runBlocking {
        try {
            MDC.put("additional", "info")
            withContext(MDCContext()) {
                // do something
                logger.info("I swear I am useful")
                deepInSomeOtherCode()
            }
        } catch (e: Exception) {
            // I want to be able to log the MDC Context from
            // where the exception occurred (the innermost coroutine context)
            // alongside this error message. So the MDC context should be
            // "additional": "info"
            // "nesting": "deeper"
            // but with this code, MDC will only contain "additional": "info"
            // because MDCContext() will always install the MDC the coroutine was created with
            // upon resuming, so you will always end up with the MDC from the outermost MDCContext()
            logger.error("An error occurred", e)
        }
    }
}


I'd suggest enriching the exception with the context. You can do this by implementing a function like the following and using it every time you need to adjust the context:

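import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.slf4j.MDCContext
import kotlinx.coroutines.withContext
import org.slf4j.MDC

suspend fun <T> withMDC(key: String, value: String, block: suspend CoroutineScope.() -> T): T {
    // getCopyOfContextMap() returns null when the MDC is empty
    val newMap = (MDC.getCopyOfContextMap() ?: emptyMap()) + (key to value)
    return try {
        withContext(MDCContext(newMap), block)
    } catch (e: Throwable) {
        // enrichExceptionWithContext is not a library function; see the options below
        throw enrichExceptionWithContext(e, "$key = $value") // enrich & rethrow
    }
}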

There are different approaches to implementing the enrichExceptionWithContext function; you can choose one based on your project's needs and preferences:

  • You can copy and recreate the exception similarly to how coroutines do it in debug mode, adding context information to the exception's message.
  • You can keep some side-channel WeakHashMap to associate exceptions with the additional context and use this map in your logging code.
  • You can create a separate exception to record the context in its message and use e.addSuppressed(...) to attach this context to the exception before rethrowing it.
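As a rough sketch of the third option: the hypothetical `MdcContextInfo` type below (not part of kotlinx.coroutines or slf4j) only carries the context string, and `enrichExceptionWithContext` attaches it as a suppressed exception and returns the original instance, so the exception type seen by callers does not change. Skipping the marker's stack trace is a design choice of this sketch, since only its message matters:

```kotlin
// Hypothetical carrier for one level of context. Uses the protected
// RuntimeException(message, cause, enableSuppression, writableStackTrace)
// constructor to skip filling in a stack trace that would never be used.
class MdcContextInfo(val context: String) :
    RuntimeException("MDC context: $context", null, true, false)

// Attaches [context] to [e] without wrapping it; returns [e] for rethrowing.
fun enrichExceptionWithContext(e: Throwable, context: String): Throwable {
    e.addSuppressed(MdcContextInfo(context))
    return e
}

// Reads the attached contexts back in the order they were added
// (innermost level first when each level enriches on the way out).
fun Throwable.attachedContexts(): List<String> =
    suppressed.filterIsInstance<MdcContextInfo>().map { it.context }
```

Because the same instance is rethrown, existing catch clauses keep matching, and `printStackTrace()` lists each attached context as a `Suppressed:` entry.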

Does it help?

Thanks for your suggestions!

  • I want this block to be transparent to the caller, so I can't wrap the exception in some other exception. I also want to log the context in a structured way, which means I cannot (or don't really want to) transport it in the exception message as a string. This makes this approach not viable for me.
  • Where would you put this WeakHashMap? Up the chain in the root scope context or something? Do you maybe have a link to an example where something similar is being done? I would be interested how this would work.
  • This sounds to me like a relatively simple (albeit a bit hacky) solution. Looking at the JVM implementation, I am not sure I can rely on the Throwable to always support adding suppressed exceptions though (because of the constructor overload with the enableSuppression param). Do you know if this is going to be a problem in reality?
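The concern can be checked directly. `Throwable` (and `Exception`) has a protected constructor whose `enableSuppression` flag, when `false`, turns `addSuppressed` into a silent no-op. `NoSuppressionException` and `suppressedCountAfterAdd` are hypothetical names used only for this check:

```kotlin
// Hypothetical exception that opts out of suppression via the protected
// Exception(message, cause, enableSuppression, writableStackTrace) constructor.
class NoSuppressionException(message: String) : Exception(message, null, false, false)

// Returns how many suppressed exceptions [e] reports after one addSuppressed call.
fun suppressedCountAfterAdd(e: Throwable): Int {
    e.addSuppressed(RuntimeException("context"))
    return e.suppressed.size
}
```

Exceptions created through the public constructors keep suppression enabled, so only types that explicitly opt out are affected.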

Thanks again for your feedback, I appreciate it a lot.

I was not familiar with WeakHashMap, but after looking it up, I think I understand the concept.
This is what I came up with. I'd be interested in your opinion. I think it could be good to provide a solution to this problem in the docs of kotlinx.coroutines.slf4j because I can imagine others have the same need.

A couple of points:

  • I used a custom MDCContext instead of the one from kotlinx.coroutines.slf4j because I think the library's typealias MDCContextMap = Map<String, String>? should be typealias MDCContextMap = Map<String, String?>?, since you can put null values into the Java MDC.
  • When recovering the captured context, I traverse the exception chain so that I can return the context captured for the exception closest to the root cause. I did this so that wrapping the exception still makes it possible to get to the context. I _think_ this should also make it compatible with exception copying in the stack trace recovery mechanism from kotlinx.coroutines.debug, right?
  • I also wrap the body in a coroutineScope in order to preserve the MDCContext() for async and launch children.

Does this look conceptually sound to you?

import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.ThreadContextElement
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.withContext
import org.slf4j.MDC
import java.util.Collections
import java.util.WeakHashMap
import kotlin.coroutines.AbstractCoroutineContextElement
import kotlin.coroutines.CoroutineContext

suspend fun <T> withMDCContext(vararg context: Pair<String, String?>, body: suspend CoroutineScope.() -> T): T {
    val mdc = appendToCurrentMDCContext(context.toMap())
    // Capture inside withContext: by the time withContext rethrows, the caller's MDC is already restored
    return withContext(MDCContext(mdc)) {
        captureExceptionMDCContext { coroutineScope(body) }
    }
}

suspend fun <T> captureExceptionMDCContext(body: suspend () -> T): T =
    try {
        body()
    } catch (e: Throwable) {
        MDC.getCopyOfContextMap()?.let { mdcContextHolder.putIfAbsent(e, it) }
        throw e
    }

val Throwable.capturedMDCContext: MDCContextMap
    get() {
        return capturedContext(this, null)
    }

private fun appendToCurrentMDCContext(context: Map<String, String?>): Map<String, String?> {
    return when (val current = MDC.getCopyOfContextMap()) {
        null -> context
        else -> current + context
    }
}

private val mdcContextHolder =
    Collections.synchronizedMap(WeakHashMap<Throwable, Map<String, String?>>())

private tailrec fun capturedContext(e: Throwable?, contextMap: MDCContextMap): MDCContextMap {
    if (e == null)
        return contextMap

    return capturedContext(e.cause, mdcContextHolder[e] ?: contextMap)
}

fun CoroutineContext.addMDCContext(vararg context: Pair<String, String?>): CoroutineContext {
    return this + MDCContext(appendToCurrentMDCContext(context.toMap()))
}

/**
 * The value of [MDC] context map.
 * See [MDC.getCopyOfContextMap].
 */
typealias MDCContextMap = Map<String, String?>?

/**
 * [MDC] context element for [CoroutineContext].
 *
 * Example:
 *
 * ```
 * MDC.put("kotlin", "rocks") // Put a value into the MDC context
 *
 * launch(MDCContext()) {
 *     logger.info { "..." }   // The MDC context contains the mapping here
 * }
 * ```
 *
 * Note that you cannot update MDC context from inside of the coroutine simply
 * using [MDC.put]. These updates are going to be lost on the next suspension and
 * reinstalled to the MDC context that was captured or explicitly specified in
 * [contextMap] when this object was created on the next resumption.
 * Use `withContext(MDCContext()) { ... }` to capture updated map of MDC keys and values
 * for the specified block of code.
 *
 * @param contextMap the value of [MDC] context map.
 * Default value is the copy of the current thread's context map that is acquired via
 * [MDC.getCopyOfContextMap].
 */
class MDCContext(
    /**
     * The value of [MDC] context map.
     */
    val contextMap: MDCContextMap = MDC.getCopyOfContextMap()
) : ThreadContextElement<MDCContextMap>, AbstractCoroutineContextElement(Key) {
    /**
     * Key of [MDCContext] in [CoroutineContext].
     */
    companion object Key : CoroutineContext.Key<MDCContext>

    /** @suppress */
    override fun updateThreadContext(context: CoroutineContext): MDCContextMap {
        val oldState = MDC.getCopyOfContextMap()
        setCurrent(contextMap)
        return oldState
    }

    /** @suppress */
    override fun restoreThreadContext(context: CoroutineContext, oldState: MDCContextMap) {
        setCurrent(oldState)
    }

    private fun setCurrent(contextMap: MDCContextMap) {
        if (contextMap == null) {
            MDC.clear()
        } else {
            MDC.setContextMap(contextMap)
        }
    }
}

You can just create a global, top-level private val context = WeakHashMap<Throwable, AdditionalContext> and attach whatever additional context you need to your exception. However, please note that the coroutines debug infrastructure copies exceptions but keeps the originals as their causes, so at the point of logging you'll need to walk the chain of exception causes to retrieve all the context from the point where the exception was thrown.
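A minimal sketch of that side channel. It assumes the additional context is a plain string map, and it imitates the debug infrastructure's copy with a new exception that has the original as its cause; `contextByException`, `recordContext`, and `lookUpContext` are hypothetical names:

```kotlin
import java.util.Collections
import java.util.WeakHashMap

// Weak keys let exceptions be garbage-collected; synchronization makes the map
// safe to use from coroutines running on different threads.
val contextByException: MutableMap<Throwable, Map<String, String>> =
    Collections.synchronizedMap(WeakHashMap<Throwable, Map<String, String>>())

fun recordContext(e: Throwable, context: Map<String, String>) {
    contextByException.putIfAbsent(e, context)
}

// Walks the cause chain (stopping at self-referencing causes) and returns the
// context recorded closest to the root cause, or null if none was recorded.
fun lookUpContext(e: Throwable): Map<String, String>? =
    generateSequence(e) { t -> t.cause?.takeIf { it !== t } }
        .mapNotNull { contextByException[it] }
        .lastOrNull()
```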

As for suppressed exceptions, in practice, all JVM exceptions support suppression.

It turns out I not only have to walk the chain of causes of the exception that I want to look up, but also the chain of the exception I store in the WeakHashMap, because both get adjusted by kotlinx.coroutines.debug. I walk one chain at recording time and the other at lookup time: I had trouble coming up with something better than O(n^2) for correlating the two chains only at lookup time, and this approach was easier to write and read.

For reference, this is what I went with:

import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.ThreadContextElement
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.withContext
import org.slf4j.MDC
import java.util.Collections
import java.util.WeakHashMap
import kotlin.coroutines.AbstractCoroutineContextElement
import kotlin.coroutines.CoroutineContext

/**
 * Run the [body] with the given [context] elements on the [MDC].
 * Upon returning, the MDC context will be reset to the original state.
 * This is also the case if the [body] throws. In that case, you can access the [MDC]
 * from the throw site via the [Throwable.capturedMDCContext] extension property.
 */
suspend fun <T> withMDCContext(vararg context: Pair<String, String?>, body: suspend CoroutineScope.() -> T): T {
    val mdc = appendToCurrentMDCContext(context.toMap())
    return withContext(MDCContext(mdc)) {
        captureExceptionMDCContext { coroutineScope(body) }
    }
}

/**
 * Run the [body] wrapped in a try/catch that captures the [MDC] context at the time of exception catching so
 * that it can later be retrieved via [Throwable.capturedMDCContext]
 */
suspend fun <T> captureExceptionMDCContext(body: suspend () -> T): T =
    try {
        body()
    } catch (e: Throwable) {
        MDC.getCopyOfContextMap()?.let { mdc ->
            e.exceptionChain.forEach { mdcContextHolder.putIfAbsent(it, mdc) }
        }
        throw e
    }

/**
 * The [MDC] context at the point in time when the exception was thrown (or null).
 * Traverses the exception chain and returns the first context captured closest to the root exception.
 */
val Throwable.capturedMDCContext: MDCContextMap
    get() {
        return exceptionChain
            .asSequence()
            .mapNotNull { mdcContextHolder[it] }
            .lastOrNull()
    }

/**
 * Append the [context] to the current [CoroutineContext]. In case of an exception, the context is not captured.
 * Pair with [captureExceptionMDCContext] if you want to capture context at point of exception.
 */
fun CoroutineContext.addMDCContext(vararg context: Pair<String, String?>): CoroutineContext {
    return this + MDCContext(appendToCurrentMDCContext(context.toMap()))
}

private fun appendToCurrentMDCContext(context: Map<String, String?>): Map<String, String?> {
    return when (val current = MDC.getCopyOfContextMap()) {
        null -> context
        else -> current + context
    }
}

private val mdcContextHolder =
    Collections.synchronizedMap(WeakHashMap<Throwable, Map<String, String?>>())

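/**
 * Iterate over the exception chain by following the cause trail, starting with the exception itself
 */
private class ExceptionChainIterator(start: Throwable) : Iterator<Throwable> {
    private var nextItem: Throwable? = start

    override fun hasNext(): Boolean = nextItem != null

    override fun next(): Throwable {
        val current = nextItem ?: throw NoSuchElementException()
        // Yield the current exception, then move to its cause (stopping at self-causes)
        nextItem = current.cause?.takeIf { it !== current }
        return current
    }
}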

private val Throwable.exceptionChain: Iterator<Throwable>
    get() {
        return ExceptionChainIterator(this)
    }

/**
 * Custom version of the MDCContext from kotlinx.coroutines.slf4j because we want to support nullable values (String! from Java) in the map.
 */

/**
 * The value of [MDC] context map.
 * See [MDC.getCopyOfContextMap].
 */
typealias MDCContextMap = Map<String, String?>?

/**
 * [MDC] context element for [CoroutineContext].
 *
 * Example:
 *
 * ```
 * MDC.put("kotlin", "rocks") // Put a value into the MDC context
 *
 * launch(MDCContext()) {
 *     logger.info { "..." }   // The MDC context contains the mapping here
 * }
 * ```
 *
 * Note that you cannot update MDC context from inside of the coroutine simply
 * using [MDC.put]. These updates are going to be lost on the next suspension and
 * reinstalled to the MDC context that was captured or explicitly specified in
 * [contextMap] when this object was created on the next resumption.
 * Use `withContext(MDCContext()) { ... }` to capture updated map of MDC keys and values
 * for the specified block of code.
 *
 * @param contextMap the value of [MDC] context map.
 * Default value is the copy of the current thread's context map that is acquired via
 * [MDC.getCopyOfContextMap].
 */
class MDCContext(
    /**
     * The value of [MDC] context map.
     */
    val contextMap: MDCContextMap = MDC.getCopyOfContextMap()
) : ThreadContextElement<MDCContextMap>, AbstractCoroutineContextElement(Key) {
    /**
     * Key of [MDCContext] in [CoroutineContext].
     */
    companion object Key : CoroutineContext.Key<MDCContext>

    /** @suppress */
    override fun updateThreadContext(context: CoroutineContext): MDCContextMap {
        val oldState = MDC.getCopyOfContextMap()
        setCurrent(contextMap)
        return oldState
    }

    /** @suppress */
    override fun restoreThreadContext(context: CoroutineContext, oldState: MDCContextMap) {
        setCurrent(oldState)
    }

    private fun setCurrent(contextMap: MDCContextMap) {
        if (contextMap == null) {
            MDC.clear()
        } else {
            MDC.setContextMap(contextMap)
        }
    }
}

Do you think it makes sense to add this to kotlinx.coroutines.slf4j in some way? Code or documentation?

I've looked at it cursorily and it looks good in general. A few things I've noticed:

  1. You don't need all the ExceptionChainIterator. To get a sequence of all causes you can just write:

     private val Throwable.exceptionChain: Sequence<Throwable>
         get() = generateSequence(this) { e -> e.cause?.takeIf { it !== e } }

  2. You don't need a copy of MDCContext class. You can just do an unchecked cast of your context map to the map that the standard MDCContext class takes.

  3. I did not really understand why you need to attach context to the whole chain in captureExceptionMDCContext function, since you only retrieve the first one anyway. It looks like attaching it only to the first exception should work just as well.
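The unchecked cast works because generic type arguments are erased on the JVM, so the map instance is passed through unchanged. A stdlib-only sketch; `asNonNullValueMap` is a hypothetical helper, and the hand-off to the library's MDCContext constructor (whose MDCContextMap is `Map<String, String>?`) is not shown:

```kotlin
// Hypothetical helper: reinterprets a map with nullable values as a map with
// non-null values. Generics are erased on the JVM, so no copy or check happens.
@Suppress("UNCHECKED_CAST")
fun asNonNullValueMap(map: Map<String, String?>?): Map<String, String>? =
    map as Map<String, String>?
```

The trade-off is that the compiler no longer tracks the null values, so code that reads them as non-null `String` can fail at runtime.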

> You don't need all the ExceptionChainIterator. To get a sequence of all causes you can just write:

Neat! Didn't know that.

> You don't need a copy of MDCContext class. You can just do an unchecked cast of your context map to the map that the standard MDCContext class takes.

I see. I will keep it this way for now. MDCContext is simple enough that I can include my own instead of using the library (ktor does the same).

> I did not really understand why you need to attach context to the whole chain in captureExceptionMDCContext function, since you only retrieve the first one anyway. It looks like attaching it only to the first exception should work just as well.

I thought so too, but I think the stack trace recovery mechanism thwarts this. What I saw was that exception B with root cause A was stored in the WeakHashMap, but the unhandled exception I then looked up was exception C with root cause A. So I think rethrowing makes the stack trace recovery mechanism kick in and alter the chain.
Storing the complete chain is still not necessary, though: I can just store the root cause instead of the head, and that works as well.
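The root-cause variant can be sketched with the standard library alone. Here a new `IllegalStateException` with the original as its cause stands in for the copies produced by stack trace recovery (an assumption of this sketch), and `rootCause`, `recordAtRoot`, and `contextAtRoot` are hypothetical names:

```kotlin
import java.util.Collections
import java.util.WeakHashMap

val contextByRootCause: MutableMap<Throwable, Map<String, String?>> =
    Collections.synchronizedMap(WeakHashMap<Throwable, Map<String, String?>>())

// Follows causes to the end of the chain, stopping at self-referencing causes.
tailrec fun rootCause(e: Throwable): Throwable {
    val cause = e.cause
    return if (cause == null || cause === e) e else rootCause(cause)
}

// putIfAbsent keeps the first context recorded for a root cause, which is the
// innermost one because inner scopes catch the exception first.
fun recordAtRoot(e: Throwable, context: Map<String, String?>) {
    contextByRootCause.putIfAbsent(rootCause(e), context)
}

fun contextAtRoot(e: Throwable): Map<String, String?>? = contextByRootCause[rootCause(e)]
```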

Thanks again so much for your input, it was really helpful. What about my other question?
Do you think it makes sense to add this to kotlinx.coroutines.slf4j in some way? Code or documentation?

Sadly, it all looks too complicated and fragile (and thus domain-specific) to me. It is OK when it is your project, you control all the code, and you are aware of all the limitations, but that's not something a public library should do. I wish we could find a nicer solution, but I don't have any immediate ideas.

Yeah, I see that. I talked with a colleague, and we agreed that MDC context is not that valuable if you can't access it in the exceptional case, so I thought others might at least benefit from a pointer on how to approach this (e.g. in the docs).
But for me, this issue is solved, thanks again for your help!

There's one idea that came to my mind. Saving it for the record here for further discussion: https://github.com/Kotlin/kotlinx.coroutines/issues/2426
