# Source code for ctlearn.core.attention

"""
This module defines the squeeze-excite blocks for channel-wise and/or spatial-wise attention mechanisms.
"""

import keras

# Public API of this module: the three squeeze-excite block builders defined below.
__all__ = [
    "dual_squeeze_excite_block",
    "channel_squeeze_excite_block",
    "spatial_squeeze_excite_block",
]

def dual_squeeze_excite_block(inputs, ratio=16, name=None):
    """
    A channel & spatial (dual) squeeze-excite block.

    This function creates a dual squeeze-excite block that combines both
    channel-wise and spatial-wise squeeze-excite mechanisms. The channel
    squeeze-excite block focuses on recalibrating the importance of each
    channel, while the spatial squeeze-excite block focuses on recalibrating
    the importance of each spatial location. Both branches are applied to the
    same ``inputs`` and their outputs are summed element-wise.

    Parameters
    ----------
    inputs : tensor
        Input tensor to the squeeze-excite block.
    ratio : int
        Reduction ratio for the channel squeeze-excite block. Default is 16.
    name : str, optional
        Name prefix for the layers of the squeeze-excite block. The channel
        branch is named ``name + "_cse"``, the spatial branch
        ``name + "_sse"`` and the final addition ``name + "_add"``.
        Default is None, in which case a unique prefix of the form
        ``"dual_se_<n>"`` is generated.

    Returns
    -------
    tensor
        Output tensor for the squeeze-excite block.

    References
    ----------
    - [Squeeze and Excitation Networks](https://arxiv.org/abs/1709.01507)
    - [Concurrent Spatial and Channel Squeeze & Excitation in Fully
      Convolutional Networks](https://arxiv.org/abs/1803.02579)
    """
    if name is None:
        # The documented default ``name=None`` used to fail with a TypeError
        # on ``None + "_cse"``. Generate a unique prefix instead so that
        # several unnamed blocks can coexist in one model without layer-name
        # collisions.
        name = f"dual_se_{keras.backend.get_uid('dual_se')}"

    cse = channel_squeeze_excite_block(
        inputs=inputs,
        ratio=ratio,
        name=name + "_cse",
    )
    sse = spatial_squeeze_excite_block(inputs=inputs, name=name + "_sse")
    return keras.layers.Add(name=name + "_add")([cse, sse])


def channel_squeeze_excite_block(inputs, ratio=4, name=None):
    """
    A channel-wise squeeze-excite block.

    This function creates a channel-wise squeeze-excite block that
    recalibrates the importance of each channel by using global average
    pooling followed by two dense layers. The resulting per-channel weights
    (sigmoid output) are multiplied with ``inputs``.

    Parameters
    ----------
    inputs : tensor
        Input tensor to the squeeze-excite block. The number of channels is
        read from the last axis of its shape.
    ratio : int
        Reduction ratio for the squeeze-excite block. The first dense layer
        has ``filters // ratio`` units, clamped to at least 1. Default is 4.
    name : str, optional
        Name prefix for the layers of the squeeze-excite block. Default is
        None, in which case Keras assigns automatic layer names.

    Returns
    -------
    tensor
        Output tensor for the channel squeeze-excite block.
    """

    def _layer_name(suffix):
        # With the documented default ``name=None`` the original code failed
        # on ``None + suffix``; passing None lets Keras auto-name the layer.
        return None if name is None else name + suffix

    # Temp fix for supporting keras2 & keras3
    if int(keras.__version__.split(".")[0]) >= 3:
        filters = inputs.shape[-1]
    else:
        filters = inputs.get_shape().as_list()[-1]

    # Guard against a zero-width bottleneck when the input has fewer
    # channels than the reduction ratio.
    bottleneck_units = max(1, filters // ratio)

    cse = keras.layers.GlobalAveragePooling2D(
        keepdims=True, name=_layer_name("_avgpool")
    )(inputs)
    cse = keras.layers.Dense(
        units=bottleneck_units,
        activation="relu",
        name=_layer_name("_1_dense"),
    )(cse)
    cse = keras.layers.Dense(
        units=filters, activation="sigmoid", name=_layer_name("_2_dense")
    )(cse)
    return keras.layers.Multiply(name=_layer_name("_mult"))([inputs, cse])


def spatial_squeeze_excite_block(inputs, name=None):
    """
    A spatial squeeze-excite block.

    This function creates a spatial squeeze-excite block that recalibrates
    the importance of each spatial location by using a 1x1 convolutional
    layer with a single filter and a sigmoid activation. The resulting
    single-channel map is multiplied with ``inputs``.

    Parameters
    ----------
    inputs : tensor
        Input tensor to the squeeze-excite block.
    name : str, optional
        Name prefix for the layers of the squeeze-excite block. Default is
        None, in which case Keras assigns automatic layer names.

    Returns
    -------
    tensor
        Output tensor for the spatial squeeze-excite block.
    """

    def _layer_name(suffix):
        # With the documented default ``name=None`` the original code failed
        # on ``None + suffix``; passing None lets Keras auto-name the layer.
        return None if name is None else name + suffix

    sse = keras.layers.Conv2D(
        filters=1,
        kernel_size=1,
        activation="sigmoid",
        name=_layer_name("_spatial_conv"),
    )(inputs)
    return keras.layers.Multiply(name=_layer_name("_mult"))([inputs, sse])