112 lines
3.3 KiB
Python
112 lines
3.3 KiB
Python
"""
|
|
GeoTIFF 4밴드 수체 세그멘테이션 시각적 테스트
|
|
원본 RGB | 정답 마스크 | 예측 마스크 | 오버레이 비교
|
|
"""
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
def _to_rgb(image):
|
|
"""4밴드 (R,G,B,MNDWI) 이미지에서 RGB 3채널 추출"""
|
|
return image[:, :, :3]
|
|
|
|
|
|
def show_predictions(images, masks, model, n_images=6, save_path=None):
|
|
"""
|
|
모델 예측 결과를 시각적으로 비교
|
|
|
|
Args:
|
|
images: (N, H, W, 4) 이미지 배열
|
|
masks: (N, H, W, 1) 정답 마스크 배열
|
|
model: 학습된 모델
|
|
n_images: 표시할 이미지 수
|
|
save_path: 저장 경로 (None이면 화면 표시)
|
|
"""
|
|
n = min(n_images, len(images))
|
|
indices = np.random.choice(len(images), n, replace=False)
|
|
|
|
fig, axes = plt.subplots(n, 4, figsize=(20, 5 * n))
|
|
if n == 1:
|
|
axes = axes[np.newaxis, :]
|
|
|
|
for row, idx in enumerate(indices):
|
|
image = images[idx]
|
|
mask = masks[idx]
|
|
pred = model.predict(image[np.newaxis, ...], verbose=0)[0]
|
|
|
|
rgb = _to_rgb(image)
|
|
mask_2d = mask[:, :, 0]
|
|
pred_2d = (pred[:, :, 0] > 0.5).astype(np.float32)
|
|
|
|
# 원본 RGB
|
|
axes[row, 0].imshow(np.clip(rgb, 0, 1))
|
|
axes[row, 0].set_title('Original RGB')
|
|
axes[row, 0].axis('off')
|
|
|
|
# 정답 마스크
|
|
axes[row, 1].imshow(mask_2d, cmap='Blues', vmin=0, vmax=1)
|
|
axes[row, 1].set_title('Ground Truth')
|
|
axes[row, 1].axis('off')
|
|
|
|
# 예측 마스크
|
|
axes[row, 2].imshow(pred_2d, cmap='Blues', vmin=0, vmax=1)
|
|
axes[row, 2].set_title('Prediction')
|
|
axes[row, 2].axis('off')
|
|
|
|
# 오버레이
|
|
axes[row, 3].imshow(np.clip(rgb, 0, 1))
|
|
axes[row, 3].imshow(pred_2d, cmap='Blues', alpha=0.4, vmin=0, vmax=1)
|
|
axes[row, 3].set_title('Overlay')
|
|
axes[row, 3].axis('off')
|
|
|
|
plt.tight_layout()
|
|
if save_path:
|
|
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
|
print(f"결과 저장: {save_path}")
|
|
else:
|
|
plt.show()
|
|
plt.close()
|
|
|
|
|
|
def show_dataset(images, masks, n_images=6, save_path=None):
|
|
"""
|
|
데이터셋 미리보기 (모델 없이)
|
|
|
|
Args:
|
|
images: (N, H, W, 4) 이미지 배열
|
|
masks: (N, H, W, 1) 마스크 배열
|
|
n_images: 표시할 이미지 수
|
|
save_path: 저장 경로 (None이면 화면 표시)
|
|
"""
|
|
n = min(n_images, len(images))
|
|
indices = np.random.choice(len(images), n, replace=False)
|
|
|
|
fig, axes = plt.subplots(n, 3, figsize=(15, 5 * n))
|
|
if n == 1:
|
|
axes = axes[np.newaxis, :]
|
|
|
|
for row, idx in enumerate(indices):
|
|
rgb = _to_rgb(images[idx])
|
|
mask_2d = masks[idx][:, :, 0]
|
|
|
|
axes[row, 0].imshow(np.clip(rgb, 0, 1))
|
|
axes[row, 0].set_title('Original RGB')
|
|
axes[row, 0].axis('off')
|
|
|
|
axes[row, 1].imshow(mask_2d, cmap='Blues', vmin=0, vmax=1)
|
|
axes[row, 1].set_title('Mask')
|
|
axes[row, 1].axis('off')
|
|
|
|
axes[row, 2].imshow(np.clip(rgb, 0, 1))
|
|
axes[row, 2].imshow(mask_2d, cmap='Blues', alpha=0.4, vmin=0, vmax=1)
|
|
axes[row, 2].set_title('Overlay')
|
|
axes[row, 2].axis('off')
|
|
|
|
plt.tight_layout()
|
|
if save_path:
|
|
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
|
print(f"결과 저장: {save_path}")
|
|
else:
|
|
plt.show()
|
|
plt.close()
|