Middleware

Kstreams allows you to include middlewares to add behavior to streams.

A middleware is a callable that works with every ConsumerRecord (CR) before and after it is processed by a specific stream. Middlewares also have access to the stream and the send function.

  • It takes each CR that arrives at a Kafka topic.
  • Then it can act on the CR or run any other needed code.
  • Then it passes the CR on to the next callable (another middleware or the stream).
  • Once the CR has been processed by the stream, the chain is completed.
  • If there is code after self.next_call(cr), it is executed at that point.

Kstreams middlewares follow this protocol:
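The mechanics of the chain can be sketched with plain-Python stand-ins (these are hypothetical classes for illustration, not the kstreams API; in kstreams the engine wires middlewares up for you):

```python
import asyncio

# Self-contained sketch: each middleware can run code before and after it
# hands the CR down the chain, and the value returned by `next_call`
# flows back up to the caller.

calls = []


class LoggingMiddleware:
    def __init__(self, next_call):
        self.next_call = next_call

    async def __call__(self, cr):
        calls.append("before")             # runs before the CR is processed
        result = await self.next_call(cr)  # next middleware or the stream
        calls.append("after")              # runs once the chain is completed
        return result


async def stream(cr):                      # the final callable in the chain
    calls.append("stream")
    return cr.upper()


mw = LoggingMiddleware(next_call=stream)
print(asyncio.run(mw("event")))  # → EVENT
print(calls)                     # → ['before', 'stream', 'after']
```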

Bases: Protocol

Source code in kstreams/middleware/middleware.py
class MiddlewareProtocol(typing.Protocol):
    def __init__(
        self,
        *,
        next_call: types.NextMiddlewareCall,
        send: types.Send,
        stream: "Stream",
        **kwargs: typing.Any,
    ) -> None: ...  #  pragma: no cover

    async def __call__(
        self, cr: types.ConsumerRecord
    ) -> typing.Any: ...  #  pragma: no cover

Note

The __call__ method can return anything, so previous calls can use the returned value. Make sure that your method includes the line return await self.next_call(cr).

Warning

Middlewares only work with the new Dependency Injection approach.

Creating a middleware

To create a middleware, create a class that inherits from BaseMiddleware and define the async def __call__ method. Let's consider that we want to save the CR to elastic before it is processed:

import typing

from kstreams import ConsumerRecord, middleware

async def save_to_elastic(cr: ConsumerRecord) -> None:
    ...


class ElasticMiddleware(middleware.BaseMiddleware):
    async def __call__(self, cr: ConsumerRecord) -> typing.Any:
        # save to elastic before calling the next
        await save_to_elastic(cr)

        # the next call could be another middleware
        return await self.next_call(cr)

Then, we have to include the middleware:

from kstreams import ConsumerRecord, middleware

from .engine import stream_engine


middlewares = [middleware.Middleware(ElasticMiddleware)]

@stream_engine.stream("kstreams-topic", middlewares=middlewares)
async def processor(cr: ConsumerRecord):
    ...

Note

The Middleware concept also applies for async generators (yield from a stream)
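A self-contained sketch of this (plain-Python stand-ins, not the kstreams API): a middleware still sees every record first, even when the final handler is an async generator that yields values on.

```python
import asyncio

# Records the CRs the middleware observed before the generator ran.
seen_by_middleware = []


async def handler(cr):                 # generator-style "stream": yields values
    yield cr.upper()


class TrackingMiddleware:
    def __init__(self, next_call):
        self.next_call = next_call

    def __call__(self, cr):
        seen_by_middleware.append(cr)  # middleware code runs per record
        return self.next_call(cr)      # returns the async generator


async def consume(mw, cr):
    # Drain the async generator produced further down the chain.
    return [value async for value in mw(cr)]


mw = TrackingMiddleware(next_call=handler)
print(asyncio.run(consume(mw, "event")))  # → ['EVENT']
```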

Adding extra configuration to middlewares

If you want to provide extra configuration to a middleware, you should override the __init__ method with the extra options as keyword arguments and then call super().__init__(**kwargs).

Let's consider that we want to send an event to a specific topic when a ValueError is raised inside a stream (Dead Letter Queue):

from kstreams import ConsumerRecord, types, Stream, middleware


class DLQMiddleware(middleware.BaseMiddleware):
    def __init__(self, *, topic: str, **kwargs) -> None:
        super().__init__(**kwargs)
        self.topic = topic

    async def __call__(self, cr: ConsumerRecord):
        try:
            return await self.next_call(cr)
        except ValueError:
            await self.send(self.topic, key=cr.key, value=cr.value)


# Create the middlewares
middlewares = [
    middleware.Middleware(
        DLQMiddleware, topic="kstreams-dlq-topic"
    )
]

@stream_engine.stream("kstreams-topic", middlewares=middlewares)
async def processor(cr: ConsumerRecord):
    if cr.value == b"joker":
        raise ValueError("Joker received...")

Default Middleware

This is always the first middleware in the middleware stack, so it can catch any exception that might occur. Any exception raised when consuming events that is not handled by the end user will be handled by this ExceptionMiddleware, executing the error_policy that was established.

Source code in kstreams/middleware/middleware.py
class ExceptionMiddleware(BaseMiddleware):
    """
    This is always the first Middleware in the middleware stack
    to catch any exception that might occur. Any exception raised
    when consuming events that is not handled by the end user
    will be handled by this ExceptionMiddleware executing the
    policy_error that was stablished.
    """

    def __init__(
        self, *, engine: "StreamEngine", error_policy: StreamErrorPolicy, **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.engine = engine
        self.error_policy = error_policy

    async def __call__(self, cr: types.ConsumerRecord) -> typing.Any:
        try:
            return await self.next_call(cr)
        except errors.ConsumerStoppedError as exc:
            await self.cleanup_policy(exc)
        except Exception as exc:
            logger.exception(
                "Unhandled error occurred while listening to the stream. "
                f"Stream consuming from topics {self.stream.topics} CRASHED!!! \n\n "
            )
            if sys.version_info >= (3, 11):
                exc.add_note(f"Handler: {self.stream.func}")
                exc.add_note(f"Topics: {self.stream.topics}")

            await self.cleanup_policy(exc)

    async def cleanup_policy(self, exc: Exception) -> None:
        # always release the asyncio.Lock `is_processing` to
        # stop or restart properly the `stream`
        self.stream.is_processing.release()

        if self.error_policy == StreamErrorPolicy.RESTART:
            await self.stream.stop()
            logger.info(f"Restarting stream {self.stream}")
            await self.stream.start()
        elif self.error_policy == StreamErrorPolicy.STOP:
            await self.stream.stop()
            raise exc
        elif self.error_policy == StreamErrorPolicy.STOP_ENGINE:
            await self.engine.stop()
            raise exc
        else:
            # STOP_APPLICATION
            await self.engine.stop()
            signal.raise_signal(signal.SIGTERM)

        # acquire the asyncio.Lock `is_processing` again to resume the processing
        # and avoid `RuntimeError: Lock is not acquired.`
        await self.stream.is_processing.acquire()

Middleware chain

It is possible to add as many middlewares as you want in order to split and reuse business logic; however, the downsides are extra complexity and potentially slower code. The middleware order is important, as middlewares are evaluated in the order in which they were placed in the stream.

In the following example we are adding three middlewares in the following order: DLQMiddleware, ElasticMiddleware, and S3Middleware. The chain execution will be:

sequenceDiagram
    autonumber
    ExceptionMiddleware->>DLQMiddleware: 
    Note left of ExceptionMiddleware: Event received
    alt No Processing Error
    DLQMiddleware->>ElasticMiddleware: 
    Note right of ElasticMiddleware: Store CR on Elastic
    ElasticMiddleware->>S3Middleware: 
    Note right of S3Middleware: Store CR on S3
    S3Middleware->>Stream: 
    Note right of Stream: CR processed
    Stream-->>S3Middleware: 
    S3Middleware-->>ElasticMiddleware: 
    ElasticMiddleware-->>DLQMiddleware: 
    DLQMiddleware-->>ExceptionMiddleware: 
    end
Multiple middlewares example
from kstreams import ConsumerRecord, Stream, middleware


class DLQMiddleware(middleware.BaseMiddleware):
    async def __call__(self, cr: ConsumerRecord):
        try:
            return await self.next_call(cr)
        except ValueError:
            await dlq(cr.value)


class ElasticMiddleware(middleware.BaseMiddleware):
    async def __call__(self, cr: ConsumerRecord):
        await save_to_elastic(cr.value)
        return await self.next_call(cr)


class S3Middleware(middleware.BaseMiddleware):
    async def __call__(self, cr: ConsumerRecord):
        await backup_to_s3(cr.value)
        return await self.next_call(cr)


middlewares = [
    middleware.Middleware(DLQMiddleware),
    middleware.Middleware(ElasticMiddleware),
    middleware.Middleware(S3Middleware),
]


@stream_engine.stream("kstreams-topic", middlewares=middlewares)
async def processor(cr: ConsumerRecord):
    if cr.value == event_2:
        raise ValueError("Error from stream...")
    await save_to_db(cr.value)

Note

In the example we can see that the CR will always be saved to elastic and s3, regardless of whether an error occurs.
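The evaluation order shown in the diagram can be demonstrated with plain-Python stand-ins (hypothetical names for illustration, not the kstreams classes):

```python
import asyncio

# Self-contained sketch: chain three stand-in middlewares and trace
# the order in which each one runs around the stream.

trace = []


def make_middleware(name, next_call):
    async def mw(cr):
        trace.append(f"{name}:in")      # on the way down the chain
        result = await next_call(cr)
        trace.append(f"{name}:out")     # on the way back up
        return result
    return mw


async def stream(cr):
    trace.append("stream")
    return cr


# Outermost first, matching the order middlewares are placed in the stream.
chain = make_middleware("DLQ",
        make_middleware("Elastic",
        make_middleware("S3", stream)))

asyncio.run(chain("event"))
print(trace)
# → ['DLQ:in', 'Elastic:in', 'S3:in', 'stream', 'S3:out', 'Elastic:out', 'DLQ:out']
```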

Executing Code after the CR was processed

As mentioned in the introduction, it is possible to execute code after the CR has been handled. To do this, place the code after next_call is called:

Execute code after CR is handled
from kstreams import ConsumerRecord, Stream, middleware


class DLQMiddleware(middleware.BaseMiddleware):
    async def __call__(self, cr: ConsumerRecord):
        try:
            return await self.next_call(cr)
        except ValueError:
            await dlq(cr.value)


class ElasticMiddleware(middleware.BaseMiddleware):
    async def __call__(self, cr: ConsumerRecord):
        result = await self.next_call(cr)
        # This will be called after the whole chain has finished
        await save_to_elastic(cr.value)
        return result


middlewares = [
    middleware.Middleware(DLQMiddleware),
    middleware.Middleware(ElasticMiddleware),
]


@stream_engine.stream("kstreams-topic", middlewares=middlewares)
async def processor(cr: ConsumerRecord):
    if cr.value == event_2:
        raise ValueError("Error from stream...")
    await save_to_db(cr.value)

Note

In the example we can see that the event is saved to elastic only if no error occurred.

Deserialization

Middlewares are the preferred way to deserialize bytes into a different structure, like a dict. Examples:

Source code in examples/dataclasses-avroschema-example/dataclasses_avroschema_example/middlewares.py
class AvroDeserializerMiddleware(middleware.BaseMiddleware):
    def __init__(self, *, model: AvroModel, **kwargs) -> None:
        super().__init__(**kwargs)
        self.model = model

    async def __call__(self, cr: ConsumerRecord):
        """
        Deserialize a payload to an AvroModel
        """
        if cr.value is not None:
            data = self.model.deserialize(cr.value)
            cr.value = data
        return await self.next_call(cr)
Source code in examples/confluent-example/confluent_example/middlewares.py
class ConfluentMiddlewareDeserializer(
    middleware.BaseMiddleware, AsyncAvroMessageSerializer
):
    def __init__(
        self,
        *,
        schema_registry_client: AsyncSchemaRegistryClient,
        reader_schema: Optional[schema.AvroSchema] = None,
        return_record_name: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.schemaregistry_client = schema_registry_client
        self.reader_schema = reader_schema
        self.return_record_name = return_record_name
        self.id_to_decoder_func: Dict = {}
        self.id_to_writers: Dict = {}

    async def __call__(self, cr: ConsumerRecord):
        """
        Deserialize the event to a dict
        """
        data = await self.decode_message(cr.value)
        cr.value = data
        return await self.next_call(cr)