# goheung/app/AI_modules/water_body_segmentation/train.py
# (snapshot metadata: 2026-02-02 19:07:53 +09:00, 211 lines, 6.7 KiB, Python)

"""
-------------------------------------------------------------------------
File: train.py
Description: DeepLabV3+ 모델 학습 스크립트 (4밴드 GeoTIFF)
사용법:
python -m app.AI_modules.water_body_segmentation.train --dataset_dir /path/to/dataset/ --epochs 50
Author: 소지안 프로
Created: 2026-02-02
Last Modified: 2026-02-02
-------------------------------------------------------------------------
"""
import os
import argparse
import tensorflow as tf
from keras.callbacks import Callback, ModelCheckpoint
from .config import (
TIF_DATASET_DIR, IMAGE_SIZE, BATCH_SIZE,
LEARNING_RATE, EPOCHS, MODEL_SAVE_PATH, IN_CHANNELS
)
from .data_loader_tif import load_dataset, create_dataset
from .model import create_model
from .pyplot_tif import show_predictions
class ShowProgress(Callback):
    """Keras callback that periodically renders validation predictions.

    Every ``interval`` epochs the current model is run over the held-out
    images and the resulting comparison figures are produced via
    ``show_predictions`` (saved to ``save_dir`` when one is given,
    otherwise displayed).
    """

    def __init__(self, val_images, val_masks, save_dir=None, interval=5):
        super().__init__()
        self.val_images = val_images
        self.val_masks = val_masks
        self.save_dir = save_dir
        self.interval = interval

    def on_epoch_end(self, epoch, logs=None):
        # Keras epochs are 0-based; report on each `interval`-th completed epoch.
        epoch_no = epoch + 1
        if epoch_no % self.interval != 0:
            return
        print(f"\n[Epoch {epoch_no}] Saving validation predictions...")
        target = (
            os.path.join(self.save_dir, f"epoch_{epoch_no}.png")
            if self.save_dir
            else None
        )
        show_predictions(
            self.val_images,
            self.val_masks,
            model=self.model,
            n_images=5,
            save_path=target,
        )
def train(
    dataset_dir=TIF_DATASET_DIR,
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    epochs=EPOCHS,
    model_save_path=MODEL_SAVE_PATH,
    in_channels=IN_CHANNELS,
    max_train_samples=None,
    max_val_samples=None,
    val_ratio=0.1,
    show_progress_interval=5
):
    """Train a DeepLabV3+ model on a TIF dataset.

    Args:
        dataset_dir: Path to the TIF dataset (contains image/ and mask/ folders).
        image_size: Square image size fed to the model.
        batch_size: Training batch size.
        learning_rate: Optimizer learning rate.
        epochs: Number of training epochs.
        model_save_path: File path for the best checkpoint.
        in_channels: Number of input channels (4 = R, G, B, MNDWI).
        max_train_samples: Maximum number of samples to load (None = all).
        max_val_samples: Maximum number of validation samples; when given it
            is clamped so at least one training sample remains.
        val_ratio: Validation fraction used when max_val_samples is not set.
        show_progress_interval: Epoch interval for progress visualization.

    Returns:
        Tuple of (trained model, Keras training history).

    Raises:
        ValueError: If the dataset yields fewer than 2 samples (cannot split).
    """
    print("=" * 50)
    print("DeepLabV3+ Training (TIF)")
    print("=" * 50)
    print(f"Dataset Dir: {dataset_dir}")
    print(f"Image Size: {image_size}")
    print(f"Input Channels: {in_channels}")
    print(f"Batch Size: {batch_size}")
    print(f"Learning Rate: {learning_rate}")
    print(f"Epochs: {epochs}")
    print("=" * 50)

    # Load the full dataset into memory (arrays of images and masks).
    print("\n[1/4] Loading dataset...")
    images, masks = load_dataset(
        dataset_dir,
        image_size=image_size,
        max_samples=max_train_samples
    )
    total = len(images)
    if total < 2:
        # A train/val split needs at least one sample on each side.
        raise ValueError(f"Need at least 2 samples to split train/val, got {total}")

    # Train/val split: tail of the loaded arrays becomes validation.
    if max_val_samples:
        # Clamp so at least one training sample remains (fixes a crash when
        # max_val_samples >= total left an empty training split).
        n_val = min(max_val_samples, total - 1)
    else:
        n_val = max(1, int(total * val_ratio))
    n_train = total - n_val
    train_images, train_masks = images[:n_train], masks[:n_train]
    val_images, val_masks = images[n_train:], masks[n_train:]
    print(f" Train: {train_images.shape[0]}개, Val: {val_images.shape[0]}개")

    # Build tf.data pipelines; drop_remainder keeps batch shapes static.
    train_ds = tf.data.Dataset.from_tensor_slices((train_images, train_masks))
    train_ds = train_ds.shuffle(n_train).batch(batch_size, drop_remainder=True)
    train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
    val_ds = tf.data.Dataset.from_tensor_slices((val_images, val_masks))
    val_ds = val_ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

    # Build and compile the model.
    print("\n[2/4] Creating model...")
    model = create_model(
        image_size=image_size,
        learning_rate=learning_rate,
        compile_model=True,
        in_channels=in_channels
    )
    print(f" Input: {model.input_shape}, Output: {model.output_shape}")

    # Ensure the checkpoint directory exists. dirname() is "" for a bare
    # filename, and os.makedirs("") raises — fall back to the CWD.
    save_dir = os.path.dirname(model_save_path) or '.'
    os.makedirs(save_dir, exist_ok=True)
    progress_dir = os.path.join(save_dir, 'training_progress')
    os.makedirs(progress_dir, exist_ok=True)

    # Callbacks: keep the best checkpoint by val_loss, plus periodic
    # prediction snapshots for visual inspection.
    callbacks = [
        ModelCheckpoint(
            model_save_path,
            save_best_only=True,
            monitor='val_loss',
            verbose=1
        ),
        ShowProgress(val_images, val_masks,
                     save_dir=progress_dir,
                     interval=show_progress_interval)
    ]

    print("\n[3/4] Training...")
    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epochs,
        callbacks=callbacks
    )

    # Save a final side-by-side prediction figure for the validation set.
    print("\n[4/4] Saving final predictions...")
    show_predictions(
        val_images, val_masks, model,
        n_images=min(5, len(val_images)),
        save_path=os.path.join(progress_dir, "final_predictions.png")
    )

    print("\n" + "=" * 50)
    print("Training completed!")
    print(f"Model saved to: {model_save_path}")
    print(f"Progress images: {progress_dir}")
    print("=" * 50)
    return model, history
def main():
    """CLI entry point: parse arguments and launch training."""
    parser = argparse.ArgumentParser(description='Train DeepLabV3+ model (TIF)')
    # (flag, type, default, help) — one row per CLI option.
    option_specs = [
        ('--dataset_dir', str, TIF_DATASET_DIR, 'Path to TIF dataset directory'),
        ('--image_size', int, IMAGE_SIZE, 'Image size'),
        ('--batch_size', int, BATCH_SIZE, 'Batch size'),
        ('--learning_rate', float, LEARNING_RATE, 'Learning rate'),
        ('--epochs', int, EPOCHS, 'Number of epochs'),
        ('--model_save_path', str, MODEL_SAVE_PATH, 'Path to save model'),
        ('--max_samples', int, None, 'Max training samples to load'),
    ]
    for flag, value_type, default, help_text in option_specs:
        parser.add_argument(flag, type=value_type, default=default, help=help_text)

    opts = parser.parse_args()
    train(
        dataset_dir=opts.dataset_dir,
        image_size=opts.image_size,
        batch_size=opts.batch_size,
        learning_rate=opts.learning_rate,
        epochs=opts.epochs,
        model_save_path=opts.model_save_path,
        max_train_samples=opts.max_samples
    )


if __name__ == '__main__':
    main()