""" GeoTIFF 4밴드 수체 세그멘테이션 시각적 테스트 원본 RGB | 정답 마스크 | 예측 마스크 | 오버레이 비교 """ import numpy as np import matplotlib.pyplot as plt def _to_rgb(image): """4밴드 (R,G,B,MNDWI) 이미지에서 RGB 3채널 추출""" return image[:, :, :3] def show_predictions(images, masks, model, n_images=6, save_path=None): """ 모델 예측 결과를 시각적으로 비교 Args: images: (N, H, W, 4) 이미지 배열 masks: (N, H, W, 1) 정답 마스크 배열 model: 학습된 모델 n_images: 표시할 이미지 수 save_path: 저장 경로 (None이면 화면 표시) """ n = min(n_images, len(images)) indices = np.random.choice(len(images), n, replace=False) fig, axes = plt.subplots(n, 4, figsize=(20, 5 * n)) if n == 1: axes = axes[np.newaxis, :] for row, idx in enumerate(indices): image = images[idx] mask = masks[idx] pred = model.predict(image[np.newaxis, ...], verbose=0)[0] rgb = _to_rgb(image) mask_2d = mask[:, :, 0] pred_2d = (pred[:, :, 0] > 0.5).astype(np.float32) # 원본 RGB axes[row, 0].imshow(np.clip(rgb, 0, 1)) axes[row, 0].set_title('Original RGB') axes[row, 0].axis('off') # 정답 마스크 axes[row, 1].imshow(mask_2d, cmap='Blues', vmin=0, vmax=1) axes[row, 1].set_title('Ground Truth') axes[row, 1].axis('off') # 예측 마스크 axes[row, 2].imshow(pred_2d, cmap='Blues', vmin=0, vmax=1) axes[row, 2].set_title('Prediction') axes[row, 2].axis('off') # 오버레이 axes[row, 3].imshow(np.clip(rgb, 0, 1)) axes[row, 3].imshow(pred_2d, cmap='Blues', alpha=0.4, vmin=0, vmax=1) axes[row, 3].set_title('Overlay') axes[row, 3].axis('off') plt.tight_layout() if save_path: plt.savefig(save_path, dpi=150, bbox_inches='tight') print(f"결과 저장: {save_path}") else: plt.show() plt.close() def show_dataset(images, masks, n_images=6, save_path=None): """ 데이터셋 미리보기 (모델 없이) Args: images: (N, H, W, 4) 이미지 배열 masks: (N, H, W, 1) 마스크 배열 n_images: 표시할 이미지 수 save_path: 저장 경로 (None이면 화면 표시) """ n = min(n_images, len(images)) indices = np.random.choice(len(images), n, replace=False) fig, axes = plt.subplots(n, 3, figsize=(15, 5 * n)) if n == 1: axes = axes[np.newaxis, :] for row, idx in enumerate(indices): rgb = _to_rgb(images[idx]) mask_2d = masks[idx][:, :, 0] axes[row, 0].imshow(np.clip(rgb, 0, 1)) axes[row, 0].set_title('Original RGB') axes[row, 0].axis('off') axes[row, 1].imshow(mask_2d, cmap='Blues', vmin=0, vmax=1) axes[row, 1].set_title('Mask') axes[row, 1].axis('off') axes[row, 2].imshow(np.clip(rgb, 0, 1)) axes[row, 2].imshow(mask_2d, cmap='Blues', alpha=0.4, vmin=0, vmax=1) axes[row, 2].set_title('Overlay') axes[row, 2].axis('off') plt.tight_layout() if save_path: plt.savefig(save_path, dpi=150, bbox_inches='tight') print(f"결과 저장: {save_path}") else: plt.show() plt.close()