diff --git a/notifications_consumer/megabus/client_individual_consumer.py b/notifications_consumer/megabus/client_individual_consumer.py
index 2b4b9ee30ae30c81b1de92736ec521ebf1d114b8..a306fc13615ee8891ecc0193f722813dc2051618 100644
--- a/notifications_consumer/megabus/client_individual_consumer.py
+++ b/notifications_consumer/megabus/client_individual_consumer.py
@@ -1,5 +1,6 @@
 """Client Individual Consumer definition."""
 import logging
+import threading
 import time
 
 import stomp
@@ -23,13 +24,18 @@ class _ClientIndividualStompListener(_StompListener):
     in order for on_message implementations be capable of ack/nack in a specific connection.
     """
 
-    def __init__(self, listener: ClientIndividualListener, connection):
+    def __init__(self, listener: ClientIndividualListener, connection, on_disconnect_evt: threading.Event):
         _StompListener.__init__(self, listener)
         listener.set_connection(connection)
+        self._on_disconnect_evt = on_disconnect_evt
 
     def on_error(self, headers, body):
         self._listener.on_error(headers, body)
 
+    def on_disconnected(self):
+        logger.warning("_ClientIndividualStompListener on_disconnected")
+        self._on_disconnect_evt.set()
+
 
 class ClientIndividualConsumer(Consumer):
     """
@@ -57,6 +63,7 @@ class ClientIndividualConsumer(Consumer):
         self._listener_kwargs = listener_kwargs
         self._listeners = []  # type : List[ClientIndividualListener]
         self._hosts = []  # type : List[str]
+        self._on_disconnect_evt = threading.Event()
 
         # Auto connect must be called until _listener_kwargs are set
         if auto_connect:
@@ -71,23 +78,30 @@ class ClientIndividualConsumer(Consumer):
         """Create a new instance of the Listener class."""
         return self._external_listener(**self._listener_kwargs)
 
+    def __reconnect(self):
+        self._disconnect()
+        try:
+            self._simple_connect()
+        except ConnectFailedException as e:
+            self._external_listener.on_exception("CONNECTION", e)
+
     def _reconnection_loop(self):
         """Reconnection loop.
 
-        Reconnects only if host list changes.
+        Reconnects only if host list changes, or a disconnect was received.
         """
         while self._is_listening:
+            if self._on_disconnect_evt.is_set():
+                logger.warning("Disconnected event received - starting reconnection")
+                self.__reconnect()
+
             time.sleep(self._reconnection_interval)
             hosts = get_hosts(self._server, self._port, self._use_multiple_brokers)
             if set(hosts) == set(self._hosts):
                 continue
 
-            logger.debug("Host list change - reconnecting in order to do rebalancing")
-            self._disconnect()
-            try:
-                self._simple_connect()
-            except ConnectFailedException as e:
-                self._external_listener.on_exception("CONNECTION", e)
+            logger.warning("Host list change - reconnecting")
+            self.__reconnect()
 
     def _disconnect(self):
         """Disconnect and cleanup.
@@ -115,8 +129,10 @@ class ClientIndividualConsumer(Consumer):
             self._listeners.append(external_listener_instance)
             try:
                 try:
-                    connection = stomp.Connection([(host, self._port)])
-                    listener = _ClientIndividualStompListener(external_listener_instance, connection)
+                    connection = stomp.Connection([(host, self._port)], keepalive=True)
+                    listener = _ClientIndividualStompListener(
+                        external_listener_instance, connection, self._on_disconnect_evt
+                    )
 
                     if self._auth_method == "password":
                         connection.set_listener("", listener)
@@ -132,6 +148,7 @@ class ClientIndividualConsumer(Consumer):
                         connection.connect(wait=True)
                     connection.subscribe(destination=self._full_destination, headers=self._headers, ack=self._ack, id=1)
                     self._connections.append(connection)
+                    self._on_disconnect_evt.clear()
 
                 except stomp.exception.ConnectFailedException:
                     raise ConnectFailedException("stomp.exception.ConnectFailedException")