# DeepLabV3+ Custom Layers
from keras.layers import ReLU
from keras.layers import Layer
from keras.layers import Conv2D
from keras.layers import Concatenate
from keras.layers import UpSampling2D
from keras.layers import AveragePooling2D
from keras.models import Sequential
from keras.layers import BatchNormalization

from .config import ASPP_FILTERS


class ConvBlock(Layer):
    """Convolution block: Conv2D -> BatchNormalization -> ReLU.

    The convolution uses 'same' padding, no bias term, and 'he_normal'
    kernel initialization.

    Args:
        filters: Number of output filters.
        kernel_size: Convolution kernel size.
        dilation_rate: Dilation rate (used for atrous convolution).
        **kwargs: Forwarded unchanged to the base ``Layer`` constructor
            (for example ``name``).
    """

    def __init__(self, filters=ASPP_FILTERS, kernel_size=3, dilation_rate=1, **kwargs):
        super().__init__(**kwargs)

        # Stored so that get_config() can report the constructor arguments.
        self.filters = filters
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate

        self.net = Sequential([
            Conv2D(
                filters,
                kernel_size=kernel_size,
                padding='same',
                dilation_rate=dilation_rate,
                use_bias=False,
                kernel_initializer='he_normal',
            ),
            BatchNormalization(),
            ReLU(),
        ])

    def call(self, X):
        """Apply the Conv2D -> BatchNormalization -> ReLU stack to ``X``."""
        return self.net(X)

    def get_config(self):
        """Return the base layer config extended with this block's arguments."""
        base_config = super().get_config()
        return {
            **base_config,
            "filters": self.filters,
            "kernel_size": self.kernel_size,
            "dilation_rate": self.dilation_rate,
        }


def AtrousSpatialPyramidPooling(X, filters=ASPP_FILTERS, dilation_rates=(6, 12, 18)):
    """Atrous Spatial Pyramid Pooling (ASPP) module.

    Applies atrous convolutions with several dilation rates in parallel,
    together with an image-level pooling branch and a 1x1 convolution
    branch, to capture multi-scale context. All branches are concatenated
    and fused by a final 1x1 ``ConvBlock``.

    Args:
        X: Input tensor. Its shape must unpack into four entries, of which
            the second and third (used as height and width) must be
            statically known.
        filters: Number of output filters for every branch and for the
            final fused output.
        dilation_rates: Iterable of integer dilation rates, one 3x3 atrous
            branch per rate. Each rate is embedded in the branch's layer
            name (``ASPP-CB-<rate>``), so rates should be distinct and
            should not be 1 (``ASPP-CB-1`` is already used by the 1x1
            branch).

    Returns:
        The ASPP-processed tensor.

    Raises:
        ValueError: If the height or width of ``X`` is ``None``.
    """
    _, height, width, _ = X.shape
    if height is None or width is None:
        # The pooling window and the upsampling factors below are built
        # from these values, so they have to be concrete integers.
        raise ValueError(
            "AtrousSpatialPyramidPooling requires statically known spatial "
            f"dimensions, but received input shape {X.shape}."
        )

    # Image pooling - captures global context. The pooling window spans the
    # whole (height, width) extent of the input.
    image_pool = AveragePooling2D(pool_size=(height, width), name="ASPP-AvgPool")(X)
    image_pool = ConvBlock(filters=filters, kernel_size=1, name="ASPP-ImagePool-CB")(image_pool)
    # Upsampling factors are derived from the pooled tensor's actual spatial
    # shape rather than assumed.
    image_pool = UpSampling2D(
        size=(height // image_pool.shape[1], width // image_pool.shape[2]),
        name="ASPP-ImagePool-UpSample",
    )(image_pool)

    # 1x1 convolution branch (no dilation).
    conv_1 = ConvBlock(filters=filters, kernel_size=1, dilation_rate=1, name="ASPP-CB-1")(X)

    # 3x3 atrous convolution branches, one per requested dilation rate.
    atrous_branches = [
        ConvBlock(
            filters=filters,
            kernel_size=3,
            dilation_rate=rate,
            name=f"ASPP-CB-{rate}",
        )(X)
        for rate in dilation_rates
    ]

    # Combine all branches and fuse them with a final 1x1 ConvBlock.
    combined = Concatenate(name="ASPP-Combine")([image_pool, conv_1, *atrous_branches])
    processed = ConvBlock(filters=filters, kernel_size=1, name="ASPP-Net")(combined)
    return processed