goheung/app/AI_modules/water_body_segmentation/layers.py
2026-02-02 19:07:53 +09:00

92 lines
2.9 KiB
Python

# DeepLabV3+ Custom Layers
from keras.layers import ReLU
from keras.layers import Layer
from keras.layers import Conv2D
from keras.layers import Concatenate
from keras.layers import UpSampling2D
from keras.layers import AveragePooling2D
from keras.models import Sequential
from keras.layers import BatchNormalization
from .config import ASPP_FILTERS
class ConvBlock(Layer):
    """Convolution block: Conv2D -> BatchNormalization -> ReLU.

    Args:
        filters: Number of output filters.
        kernel_size: Size of the convolution kernel.
        dilation_rate: Dilation rate (used for atrous convolution).
        **kwargs: Forwarded unchanged to the base ``Layer`` constructor.
    """

    def __init__(self, filters=ASPP_FILTERS, kernel_size=3, dilation_rate=1, **kwargs):
        super().__init__(**kwargs)

        # Keep the constructor arguments so get_config() can report them.
        self.filters = filters
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate

        # The convolution is created without a bias term and is followed
        # directly by batch normalization and a ReLU activation.
        convolution = Conv2D(
            filters,
            kernel_size=kernel_size,
            padding='same',
            dilation_rate=dilation_rate,
            use_bias=False,
            kernel_initializer='he_normal',
        )
        normalization = BatchNormalization()
        activation = ReLU()
        self.net = Sequential([convolution, normalization, activation])

    def call(self, X):
        """Run ``X`` through the Conv2D -> BatchNormalization -> ReLU stack."""
        return self.net(X)

    def get_config(self):
        """Return the base layer config extended with this block's arguments."""
        config = super().get_config()
        config.update({
            "filters": self.filters,
            "kernel_size": self.kernel_size,
            "dilation_rate": self.dilation_rate,
        })
        return config
def AtrousSpatialPyramidPooling(X, filters=ASPP_FILTERS):
    """Atrous Spatial Pyramid Pooling (ASPP) module.

    Applies atrous convolutions with several dilation rates in parallel,
    together with an image-level pooling branch, to capture multi-scale
    context, then fuses all branches with a 1x1 convolution block.

    Args:
        X: Input tensor. Must be rank 4, and the sizes of axes 1 and 2
            (treated here as height and width) must be statically known,
            because they are used as the pooling window and to compute the
            up-sampling factor.
        filters: Number of output filters for each branch.

    Returns:
        The ASPP-processed tensor.

    Raises:
        ValueError: If ``X`` is not rank 4, or if its height/width are not
            statically known (``None``).
    """
    shape = tuple(X.shape)
    if len(shape) != 4:
        raise ValueError(
            "AtrousSpatialPyramidPooling expects a rank-4 input tensor, "
            f"got shape {shape}."
        )
    _, height, width, _ = shape
    if height is None or width is None:
        raise ValueError(
            "AtrousSpatialPyramidPooling requires statically known spatial "
            f"dimensions, got shape {shape}."
        )

    # Image pooling - captures global context by averaging over the whole
    # feature map, then broadcasting the result back to the input size.
    image_pool = AveragePooling2D(pool_size=(height, width), name="ASPP-AvgPool")(X)
    image_pool = ConvBlock(filters=filters, kernel_size=1, name="ASPP-ImagePool-CB")(image_pool)
    image_pool = UpSampling2D(
        size=(height // image_pool.shape[1], width // image_pool.shape[2]),
        name="ASPP-ImagePool-UpSample",
    )(image_pool)

    # Atrous convolutions with different dilation rates.
    conv_1 = ConvBlock(filters=filters, kernel_size=1, dilation_rate=1, name="ASPP-CB-1")(X)
    conv_6 = ConvBlock(filters=filters, kernel_size=3, dilation_rate=6, name="ASPP-CB-6")(X)
    conv_12 = ConvBlock(filters=filters, kernel_size=3, dilation_rate=12, name="ASPP-CB-12")(X)
    conv_18 = ConvBlock(filters=filters, kernel_size=3, dilation_rate=18, name="ASPP-CB-18")(X)

    # Combine all branches and fuse them with a 1x1 convolution block.
    combined = Concatenate(name="ASPP-Combine")([image_pool, conv_1, conv_6, conv_12, conv_18])
    processed = ConvBlock(filters=filters, kernel_size=1, name="ASPP-Net")(combined)
    return processed