211 lines
6.7 KiB
Python
211 lines
6.7 KiB
Python
"""
-------------------------------------------------------------------------
File: train.py
Description: DeepLabV3+ 모델 학습 스크립트 (4밴드 GeoTIFF)
사용법:
    python -m app.AI_modules.water_body_segmentation.train --dataset_dir /path/to/dataset/ --epochs 50
Author: 소지안 프로
Created: 2026-02-02
Last Modified: 2026-02-02
-------------------------------------------------------------------------
"""
|
|
import os
|
|
import argparse
|
|
import tensorflow as tf
|
|
from keras.callbacks import Callback, ModelCheckpoint
|
|
|
|
from .config import (
|
|
TIF_DATASET_DIR, IMAGE_SIZE, BATCH_SIZE,
|
|
LEARNING_RATE, EPOCHS, MODEL_SAVE_PATH, IN_CHANNELS
|
|
)
|
|
from .data_loader_tif import load_dataset, create_dataset
|
|
from .model import create_model
|
|
from .pyplot_tif import show_predictions
|
|
|
|
|
|
class ShowProgress(Callback):
    """Keras callback that periodically visualizes validation predictions.

    Every ``interval`` epochs the current model is run on the held-out
    validation images and the resulting figure is optionally written to
    ``save_dir`` as ``epoch_<N>.png``.
    """

    def __init__(self, val_images, val_masks, save_dir=None, interval=5):
        super().__init__()
        self.val_images = val_images
        self.val_masks = val_masks
        self.save_dir = save_dir
        self.interval = interval

    def on_epoch_end(self, epoch, logs=None):
        # Keras epochs are 0-based; act on every `interval`-th completed epoch.
        completed = epoch + 1
        if completed % self.interval != 0:
            return
        print(f"\n[Epoch {completed}] Saving validation predictions...")
        out_path = None
        if self.save_dir:
            out_path = os.path.join(self.save_dir, f"epoch_{completed}.png")
        show_predictions(
            self.val_images,
            self.val_masks,
            model=self.model,
            n_images=5,
            save_path=out_path
        )
|
|
|
|
|
|
def train(
    dataset_dir=TIF_DATASET_DIR,
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    epochs=EPOCHS,
    model_save_path=MODEL_SAVE_PATH,
    in_channels=IN_CHANNELS,
    max_train_samples=None,
    max_val_samples=None,
    val_ratio=0.1,
    show_progress_interval=5
):
    """Train a DeepLabV3+ model on a 4-band GeoTIFF dataset.

    Args:
        dataset_dir: Path to the TIF dataset (contains image/ and mask/ folders).
        image_size: Side length of the square input images.
        batch_size: Training batch size.
        learning_rate: Optimizer learning rate.
        epochs: Number of training epochs.
        model_save_path: File path for the best-model checkpoint.
        in_channels: Number of input channels (4 = R, G, B, MNDWI).
        max_train_samples: Maximum number of samples to load (None = all).
        max_val_samples: Maximum number of validation samples; clamped so
            at least one training sample remains.
        val_ratio: Validation fraction used when max_val_samples is not set.
        show_progress_interval: Epoch interval for saving prediction previews.

    Returns:
        Tuple of (trained model, training history).

    Raises:
        ValueError: If fewer than two samples are loaded (cannot split
            into train and validation sets).
    """
    banner = "=" * 50
    print(banner)
    print("DeepLabV3+ Training (TIF)")
    print(banner)
    print(f"Dataset Dir: {dataset_dir}")
    print(f"Image Size: {image_size}")
    print(f"Input Channels: {in_channels}")
    print(f"Batch Size: {batch_size}")
    print(f"Learning Rate: {learning_rate}")
    print(f"Epochs: {epochs}")
    print(banner)

    # Load the full dataset into memory.
    print("\n[1/4] Loading dataset...")
    images, masks = load_dataset(
        dataset_dir,
        image_size=image_size,
        max_samples=max_train_samples
    )
    total = len(images)
    if total < 2:
        raise ValueError(
            f"Need at least 2 samples to split into train/val, got {total}"
        )

    # Train/validation split. Clamp n_val so the training set is never empty:
    # an oversized max_val_samples would otherwise make n_train <= 0 and
    # break Dataset.shuffle() below.
    if max_val_samples:
        n_val = max(1, min(max_val_samples, total - 1))
    else:
        n_val = max(1, int(total * val_ratio))
    n_train = total - n_val

    train_images, train_masks = images[:n_train], masks[:n_train]
    val_images, val_masks = images[n_train:], masks[n_train:]
    print(f" Train: {train_images.shape[0]}개, Val: {val_images.shape[0]}개")

    # Build tf.data pipelines (shuffled training set, plain batched val set).
    train_ds = tf.data.Dataset.from_tensor_slices((train_images, train_masks))
    train_ds = train_ds.shuffle(n_train).batch(batch_size, drop_remainder=True)
    train_ds = train_ds.prefetch(tf.data.AUTOTUNE)

    val_ds = tf.data.Dataset.from_tensor_slices((val_images, val_masks))
    val_ds = val_ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

    # Build and compile the model.
    print("\n[2/4] Creating model...")
    model = create_model(
        image_size=image_size,
        learning_rate=learning_rate,
        compile_model=True,
        in_channels=in_channels
    )
    print(f" Input: {model.input_shape}, Output: {model.output_shape}")

    # Create output directories. os.makedirs('') raises FileNotFoundError,
    # so only create the checkpoint directory when the path actually has one.
    save_dir = os.path.dirname(model_save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Directory for per-epoch prediction previews.
    progress_dir = os.path.join(save_dir, 'training_progress')
    os.makedirs(progress_dir, exist_ok=True)

    # Callbacks: best-checkpoint saving + periodic prediction previews.
    callbacks = [
        ModelCheckpoint(
            model_save_path,
            save_best_only=True,
            monitor='val_loss',
            verbose=1
        ),
        ShowProgress(val_images, val_masks,
                     save_dir=progress_dir,
                     interval=show_progress_interval)
    ]

    # Train.
    print("\n[3/4] Training...")
    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epochs,
        callbacks=callbacks
    )

    # Save final predictions on the validation set.
    print("\n[4/4] Saving final predictions...")
    show_predictions(
        val_images, val_masks, model,
        n_images=min(5, len(val_images)),
        save_path=os.path.join(progress_dir, "final_predictions.png")
    )

    print("\n" + banner)
    print("Training completed!")
    print(f"Model saved to: {model_save_path}")
    print(f"Progress images: {progress_dir}")
    print(banner)

    return model, history
|
|
|
|
|
|
def main():
    """CLI entry point: parse arguments and launch training."""
    parser = argparse.ArgumentParser(description='Train DeepLabV3+ model (TIF)')

    # (flag, type, default, help) for every CLI option — table-driven
    # registration keeps the flag list easy to scan and extend.
    cli_options = (
        ('--dataset_dir', str, TIF_DATASET_DIR, 'Path to TIF dataset directory'),
        ('--image_size', int, IMAGE_SIZE, 'Image size'),
        ('--batch_size', int, BATCH_SIZE, 'Batch size'),
        ('--learning_rate', float, LEARNING_RATE, 'Learning rate'),
        ('--epochs', int, EPOCHS, 'Number of epochs'),
        ('--model_save_path', str, MODEL_SAVE_PATH, 'Path to save model'),
        ('--max_samples', int, None, 'Max training samples to load'),
    )
    for flag, arg_type, default, help_text in cli_options:
        parser.add_argument(flag, type=arg_type, default=default, help=help_text)

    args = parser.parse_args()

    train(
        dataset_dir=args.dataset_dir,
        image_size=args.image_size,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        epochs=args.epochs,
        model_save_path=args.model_save_path,
        max_train_samples=args.max_samples
    )
|
|
|
|
|
|
# Script entry point: run training with CLI arguments when executed directly.
if __name__ == '__main__':
    main()
|