goheung/app/AI_modules/water_body_segmentation/pyplot_tif.py
2026-02-02 19:07:53 +09:00

112 lines
3.3 KiB
Python

"""
GeoTIFF 4밴드 수체 세그멘테이션 시각적 테스트
원본 RGB | 정답 마스크 | 예측 마스크 | 오버레이 비교
"""
import numpy as np
import matplotlib.pyplot as plt
def _to_rgb(image):
"""4밴드 (R,G,B,MNDWI) 이미지에서 RGB 3채널 추출"""
return image[:, :, :3]
def show_predictions(images, masks, model, n_images=6, save_path=None):
"""
모델 예측 결과를 시각적으로 비교
Args:
images: (N, H, W, 4) 이미지 배열
masks: (N, H, W, 1) 정답 마스크 배열
model: 학습된 모델
n_images: 표시할 이미지 수
save_path: 저장 경로 (None이면 화면 표시)
"""
n = min(n_images, len(images))
indices = np.random.choice(len(images), n, replace=False)
fig, axes = plt.subplots(n, 4, figsize=(20, 5 * n))
if n == 1:
axes = axes[np.newaxis, :]
for row, idx in enumerate(indices):
image = images[idx]
mask = masks[idx]
pred = model.predict(image[np.newaxis, ...], verbose=0)[0]
rgb = _to_rgb(image)
mask_2d = mask[:, :, 0]
pred_2d = (pred[:, :, 0] > 0.5).astype(np.float32)
# 원본 RGB
axes[row, 0].imshow(np.clip(rgb, 0, 1))
axes[row, 0].set_title('Original RGB')
axes[row, 0].axis('off')
# 정답 마스크
axes[row, 1].imshow(mask_2d, cmap='Blues', vmin=0, vmax=1)
axes[row, 1].set_title('Ground Truth')
axes[row, 1].axis('off')
# 예측 마스크
axes[row, 2].imshow(pred_2d, cmap='Blues', vmin=0, vmax=1)
axes[row, 2].set_title('Prediction')
axes[row, 2].axis('off')
# 오버레이
axes[row, 3].imshow(np.clip(rgb, 0, 1))
axes[row, 3].imshow(pred_2d, cmap='Blues', alpha=0.4, vmin=0, vmax=1)
axes[row, 3].set_title('Overlay')
axes[row, 3].axis('off')
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=150, bbox_inches='tight')
print(f"결과 저장: {save_path}")
else:
plt.show()
plt.close()
def show_dataset(images, masks, n_images=6, save_path=None):
"""
데이터셋 미리보기 (모델 없이)
Args:
images: (N, H, W, 4) 이미지 배열
masks: (N, H, W, 1) 마스크 배열
n_images: 표시할 이미지 수
save_path: 저장 경로 (None이면 화면 표시)
"""
n = min(n_images, len(images))
indices = np.random.choice(len(images), n, replace=False)
fig, axes = plt.subplots(n, 3, figsize=(15, 5 * n))
if n == 1:
axes = axes[np.newaxis, :]
for row, idx in enumerate(indices):
rgb = _to_rgb(images[idx])
mask_2d = masks[idx][:, :, 0]
axes[row, 0].imshow(np.clip(rgb, 0, 1))
axes[row, 0].set_title('Original RGB')
axes[row, 0].axis('off')
axes[row, 1].imshow(mask_2d, cmap='Blues', vmin=0, vmax=1)
axes[row, 1].set_title('Mask')
axes[row, 1].axis('off')
axes[row, 2].imshow(np.clip(rgb, 0, 1))
axes[row, 2].imshow(mask_2d, cmap='Blues', alpha=0.4, vmin=0, vmax=1)
axes[row, 2].set_title('Overlay')
axes[row, 2].axis('off')
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=150, bbox_inches='tight')
print(f"결과 저장: {save_path}")
else:
plt.show()
plt.close()