goheung/app/AI_modules/water_body_segmentation/data_loader_tif.py
2026-02-02 19:07:53 +09:00

166 lines
5.5 KiB
Python

"""
-------------------------------------------------------------------------
File: data_loader_tif.py
Description:
GeoTIFF 데이터 로더 (DeepLabV3+ 학습용)
4-band GeoTIFF (R, G, B, MNDWI) + 단일밴드 마스크를 로드하여
정규화, 리사이즈, augmentation, tf.data.Dataset 파이프라인 생성
Author: 원캉린, 소지안 프로
Created: 2026-02-02
Last Modified: 2026-02-02
-------------------------------------------------------------------------
"""
import os
import glob
import numpy as np
import rasterio
import tensorflow as tf
from tqdm import tqdm
from .config import IMAGE_SIZE, BATCH_SIZE, SHUFFLE_BUFFER
def _load_image(path: str, image_size: int = IMAGE_SIZE) -> np.ndarray:
"""4밴드 GeoTIFF 로드 → (H, W, 4) float32, 정규화 및 리사이즈 적용"""
with rasterio.open(path) as src:
image = src.read().transpose(1, 2, 0).astype(np.float32)
# 밴드별 min-max 정규화 (0~1)
for c in range(image.shape[2]):
band = image[:, :, c]
b_min, b_max = band.min(), band.max()
if b_max > b_min:
image[:, :, c] = (band - b_min) / (b_max - b_min)
else:
image[:, :, c] = 0.0
# 리사이즈
image = tf.image.resize(image, (image_size, image_size)).numpy()
return image
def _load_mask(path: str, image_size: int = IMAGE_SIZE) -> np.ndarray:
"""단일밴드 마스크 GeoTIFF 로드 → (H, W, 1) float32, 이진화 및 리사이즈 적용"""
with rasterio.open(path) as src:
mask = src.read(1).astype(np.float32)
mask = mask[..., np.newaxis] # (H, W, 1)
# 리사이즈
mask = tf.image.resize(mask, (image_size, image_size), method='nearest').numpy()
# 이진화 (0 또는 1)
mask = (mask > 0.5).astype(np.float32)
return mask
def load_dataset(dataset_dir: str, image_size: int = IMAGE_SIZE, max_samples: int = None):
"""
데이터셋 디렉토리에서 이미지/마스크 로드
Args:
dataset_dir: 'dataset/{scene_id}/' (image/, mask/ 폴더 포함)
image_size: 리사이즈 크기
max_samples: 최대 로드 개수 (None이면 전체)
Returns:
images: (N, H, W, 4) float32, 정규화됨
masks: (N, H, W, 1) float32, 이진 마스크
"""
image_paths = sorted(glob.glob(os.path.join(dataset_dir, "image", "*.tif")))
mask_paths = sorted(glob.glob(os.path.join(dataset_dir, "mask", "*.tif")))
assert len(image_paths) == len(mask_paths), \
f"이미지({len(image_paths)})와 마스크({len(mask_paths)}) 개수 불일치"
assert len(image_paths) > 0, f"이미지를 찾을 수 없음: {dataset_dir}/image/"
if max_samples is not None:
image_paths = image_paths[:max_samples]
mask_paths = mask_paths[:max_samples]
n = len(image_paths)
images = np.zeros((n, image_size, image_size, 4), dtype=np.float32)
masks = np.zeros((n, image_size, image_size, 1), dtype=np.float32)
for i, (img_path, msk_path) in tqdm(
enumerate(zip(image_paths, mask_paths)), total=n, desc="Loading TIF"
):
images[i] = _load_image(img_path, image_size)
masks[i] = _load_mask(msk_path, image_size)
return images, masks
def augment(image, mask):
"""Data augmentation (소규모 데이터셋용)"""
# 좌우 반전
if tf.random.uniform(()) > 0.5:
image = tf.image.flip_left_right(image)
mask = tf.image.flip_left_right(mask)
# 상하 반전
if tf.random.uniform(()) > 0.5:
image = tf.image.flip_up_down(image)
mask = tf.image.flip_up_down(mask)
# 90도 회전 (0~3회)
k = tf.random.uniform((), minval=0, maxval=4, dtype=tf.int32)
image = tf.image.rot90(image, k)
mask = tf.image.rot90(mask, k)
# 밝기 조절 (이미지만)
image = tf.image.random_brightness(image, max_delta=0.1)
image = tf.clip_by_value(image, 0.0, 1.0)
return image, mask
def create_dataset(dataset_dir: str, image_size: int = IMAGE_SIZE,
batch_size: int = BATCH_SIZE, use_augmentation: bool = True,
max_samples: int = None, return_arrays: bool = False):
"""
TIF 데이터셋을 로드하여 tf.data.Dataset 파이프라인 생성
Args:
dataset_dir: 데이터셋 경로
image_size: 이미지 크기
batch_size: 배치 크기
use_augmentation: augmentation 적용 여부
return_arrays: True면 (dataset, images, masks) 반환
Returns:
tf.data.Dataset 또는 (dataset, images, masks) 튜플
"""
images, masks = load_dataset(dataset_dir, image_size, max_samples=max_samples)
dataset = tf.data.Dataset.from_tensor_slices((images, masks))
if use_augmentation:
dataset = dataset.map(augment, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.shuffle(SHUFFLE_BUFFER)
dataset = dataset.batch(batch_size, drop_remainder=True)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
if return_arrays:
return dataset, images, masks
return dataset
if __name__ == "__main__":
dataset_dir = "dataset_256x256/S2B_MSIL2A_20231001T020659_R103_T52SCD_20241020T075116_tif"
images, masks = load_dataset(dataset_dir)
print(f"Images: {images.shape}, dtype={images.dtype}")
print(f"Masks: {masks.shape}, dtype={masks.dtype}")
print(f"Image value range: [{images.min():.3f}, {images.max():.3f}]")
print(f"Mask unique values: {np.unique(masks)}")
dataset = create_dataset(dataset_dir)
for batch_img, batch_msk in dataset.take(1):
print(f"Batch images: {batch_img.shape}")
print(f"Batch masks: {batch_msk.shape}")