92 lines
2.9 KiB
Python
92 lines
2.9 KiB
Python
# DeepLabV3+ Custom Layers
|
|
from keras.layers import ReLU
|
|
from keras.layers import Layer
|
|
from keras.layers import Conv2D
|
|
from keras.layers import Concatenate
|
|
from keras.layers import UpSampling2D
|
|
from keras.layers import AveragePooling2D
|
|
from keras.models import Sequential
|
|
from keras.layers import BatchNormalization
|
|
|
|
from .config import ASPP_FILTERS
|
|
|
|
|
|
class ConvBlock(Layer):
    """Convolution block: Conv2D -> BatchNormalization -> ReLU.

    The convolution uses 'same' padding, no bias and He-normal kernel
    initialization.

    Args:
        filters: Number of output filters. Defaults to ``ASPP_FILTERS``.
        kernel_size: Convolution kernel size.
        dilation_rate: Dilation rate (for atrous convolution).
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, filters=ASPP_FILTERS, kernel_size=3, dilation_rate=1, **kwargs):
        super().__init__(**kwargs)

        # Kept as attributes so get_config() can rebuild the layer.
        self.filters = filters
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate

        self.net = Sequential([
            Conv2D(
                filters,
                kernel_size=kernel_size,
                padding='same',
                dilation_rate=dilation_rate,
                # A bias is redundant here: the following BatchNormalization
                # applies its own learned offset.
                use_bias=False,
                kernel_initializer='he_normal',
            ),
            BatchNormalization(),
            ReLU(),
        ])

    def call(self, X, training=None):
        """Apply Conv2D -> BatchNormalization -> ReLU to ``X``.

        ``training`` is forwarded explicitly so the inner BatchNormalization
        uses batch statistics in training mode and moving averages otherwise.
        """
        return self.net(X, training=training)

    def get_config(self):
        """Return the constructor arguments needed to re-create this layer."""
        base_config = super().get_config()
        return {
            **base_config,
            "filters": self.filters,
            "kernel_size": self.kernel_size,
            "dilation_rate": self.dilation_rate,
        }
|
|
|
|
|
|
def AtrousSpatialPyramidPooling(X, filters=ASPP_FILTERS, dilation_rates=(6, 12, 18)):
    """Atrous Spatial Pyramid Pooling (ASPP) module.

    Applies atrous convolutions with several dilation rates in parallel,
    together with an image-level pooling branch, to capture multi-scale
    context.

    Args:
        X: Input tensor, unpacked as ``(batch, height, width, channels)``.
            It must be rank 4 with statically known height and width.
        filters: Number of output filters for each branch and for the
            final 1x1 projection.
        dilation_rates: Dilation rates of the 3x3 atrous branches. The
            default ``(6, 12, 18)`` reproduces the original branches and
            layer names (``ASPP-CB-6``, ``ASPP-CB-12``, ``ASPP-CB-18``).
            Layer names are derived from each rate, so the rates must be
            distinct and must not be 1, which is the name of the 1x1 branch.

    Returns:
        The ASPP-processed tensor. It has ``filters`` channels and the same
        spatial size as ``X``.

    Raises:
        ValueError: If the height or width of ``X`` is not statically known.
            Both the pooling window and the upsampling factor are derived
            from them.
    """
    _, H, W, _ = X.shape
    if H is None or W is None:
        raise ValueError(
            "AtrousSpatialPyramidPooling requires static spatial dimensions; "
            f"got input shape {tuple(X.shape)}."
        )

    # Image pooling branch: average over the whole feature map to capture
    # global context, then upsample back to (H, W).
    image_pool = AveragePooling2D(pool_size=(H, W), name="ASPP-AvgPool")(X)
    image_pool = ConvBlock(filters=filters, kernel_size=1, name="ASPP-ImagePool-CB")(image_pool)
    image_pool = UpSampling2D(
        size=(H // image_pool.shape[1], W // image_pool.shape[2]),
        name="ASPP-ImagePool-UpSample",
    )(image_pool)

    # 1x1 convolution branch.
    conv_1 = ConvBlock(filters=filters, kernel_size=1, dilation_rate=1, name="ASPP-CB-1")(X)

    # 3x3 atrous convolution branches, one per dilation rate.
    atrous_branches = [
        ConvBlock(
            filters=filters,
            kernel_size=3,
            dilation_rate=rate,
            name=f"ASPP-CB-{rate}",
        )(X)
        for rate in dilation_rates
    ]

    # Concatenate all branches along channels, then project with a 1x1 block.
    combined = Concatenate(name="ASPP-Combine")([image_pool, conv_1, *atrous_branches])
    processed = ConvBlock(filters=filters, kernel_size=1, name="ASPP-Net")(combined)

    return processed
|