""" ------------------------------------------------------------------------- File: train.py Description: DeepLabV3+ 모델 학습 스크립트 (4밴드 GeoTIFF) 사용법: python -m app.AI_modules.water_body_segmentation.train --dataset_dir /path/to/dataset/ --epochs 50 Author: 소지안 프로 Created: 2026-02-02 Last Modified: 2026-02-02 ------------------------------------------------------------------------- """ import os import argparse import tensorflow as tf from keras.callbacks import Callback, ModelCheckpoint from .config import ( TIF_DATASET_DIR, IMAGE_SIZE, BATCH_SIZE, LEARNING_RATE, EPOCHS, MODEL_SAVE_PATH, IN_CHANNELS ) from .data_loader_tif import load_dataset, create_dataset from .model import create_model from .pyplot_tif import show_predictions class ShowProgress(Callback): """학습 중 진행 상황을 시각화하는 콜백""" def __init__(self, val_images, val_masks, save_dir=None, interval=5): super().__init__() self.val_images = val_images self.val_masks = val_masks self.save_dir = save_dir self.interval = interval def on_epoch_end(self, epoch, logs=None): if (epoch + 1) % self.interval == 0: print(f"\n[Epoch {epoch + 1}] Saving validation predictions...") save_path = None if self.save_dir: save_path = os.path.join(self.save_dir, f"epoch_{epoch + 1}.png") show_predictions( self.val_images, self.val_masks, model=self.model, n_images=5, save_path=save_path ) def train( dataset_dir=TIF_DATASET_DIR, image_size=IMAGE_SIZE, batch_size=BATCH_SIZE, learning_rate=LEARNING_RATE, epochs=EPOCHS, model_save_path=MODEL_SAVE_PATH, in_channels=IN_CHANNELS, max_train_samples=None, max_val_samples=None, val_ratio=0.1, show_progress_interval=5 ): """ DeepLabV3+ 모델 학습 (TIF 데이터셋) Args: dataset_dir: TIF 데이터셋 경로 (image/, mask/ 폴더 포함) image_size: 이미지 크기 batch_size: 배치 크기 learning_rate: 학습률 epochs: 에폭 수 model_save_path: 모델 저장 경로 in_channels: 입력 채널 수 (4 = R,G,B,MNDWI) max_train_samples: 학습 데이터 최대 개수 (None이면 전체) max_val_samples: 검증 데이터 최대 개수 val_ratio: 검증 데이터 비율 (max_val_samples 미지정 시) show_progress_interval: 진행 상황 표시 간격 (에폭) Returns: 학습된 모델과 학습 히스토리 """ print("=" * 50) print("DeepLabV3+ Training (TIF)") print("=" * 50) print(f"Dataset Dir: {dataset_dir}") print(f"Image Size: {image_size}") print(f"Input Channels: {in_channels}") print(f"Batch Size: {batch_size}") print(f"Learning Rate: {learning_rate}") print(f"Epochs: {epochs}") print("=" * 50) # 전체 데이터 로드 print("\n[1/4] Loading dataset...") images, masks = load_dataset( dataset_dir, image_size=image_size, max_samples=max_train_samples ) total = len(images) # train/val 분리 if max_val_samples: n_val = max_val_samples else: n_val = max(1, int(total * val_ratio)) n_train = total - n_val train_images, train_masks = images[:n_train], masks[:n_train] val_images, val_masks = images[n_train:], masks[n_train:] print(f" Train: {train_images.shape[0]}개, Val: {val_images.shape[0]}개") # tf.data.Dataset 생성 train_ds = tf.data.Dataset.from_tensor_slices((train_images, train_masks)) train_ds = train_ds.shuffle(n_train).batch(batch_size, drop_remainder=True) train_ds = train_ds.prefetch(tf.data.AUTOTUNE) val_ds = tf.data.Dataset.from_tensor_slices((val_images, val_masks)) val_ds = val_ds.batch(batch_size).prefetch(tf.data.AUTOTUNE) # 모델 생성 print("\n[2/4] Creating model...") model = create_model( image_size=image_size, learning_rate=learning_rate, compile_model=True, in_channels=in_channels ) print(f" Input: {model.input_shape}, Output: {model.output_shape}") # 저장 경로 디렉토리 생성 os.makedirs(os.path.dirname(model_save_path), exist_ok=True) # 결과 저장 디렉토리 progress_dir = os.path.join(os.path.dirname(model_save_path), 'training_progress') os.makedirs(progress_dir, exist_ok=True) # 콜백 설정 callbacks = [ ModelCheckpoint( model_save_path, save_best_only=True, monitor='val_loss', verbose=1 ), ShowProgress(val_images, val_masks, save_dir=progress_dir, interval=show_progress_interval) ] # 학습 print("\n[3/4] Training...") history = model.fit( train_ds, validation_data=val_ds, epochs=epochs, callbacks=callbacks ) # 최종 예측 결과 저장 print("\n[4/4] Saving final predictions...") show_predictions( val_images, val_masks, model, n_images=min(5, len(val_images)), save_path=os.path.join(progress_dir, "final_predictions.png") ) print("\n" + "=" * 50) print("Training completed!") print(f"Model saved to: {model_save_path}") print(f"Progress images: {progress_dir}") print("=" * 50) return model, history def main(): """CLI 엔트리포인트""" parser = argparse.ArgumentParser(description='Train DeepLabV3+ model (TIF)') parser.add_argument('--dataset_dir', type=str, default=TIF_DATASET_DIR, help='Path to TIF dataset directory') parser.add_argument('--image_size', type=int, default=IMAGE_SIZE, help='Image size') parser.add_argument('--batch_size', type=int, default=BATCH_SIZE, help='Batch size') parser.add_argument('--learning_rate', type=float, default=LEARNING_RATE, help='Learning rate') parser.add_argument('--epochs', type=int, default=EPOCHS, help='Number of epochs') parser.add_argument('--model_save_path', type=str, default=MODEL_SAVE_PATH, help='Path to save model') parser.add_argument('--max_samples', type=int, default=None, help='Max training samples to load') args = parser.parse_args() train( dataset_dir=args.dataset_dir, image_size=args.image_size, batch_size=args.batch_size, learning_rate=args.learning_rate, epochs=args.epochs, model_save_path=args.model_save_path, max_train_samples=args.max_samples ) if __name__ == '__main__': main()