173 lines
4.2 KiB
Python
173 lines
4.2 KiB
Python
# DeepLabV3+ Utilities
|
|
import io
|
|
import base64
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
def show_images(data, model=None, explain=False, n_images=5, SIZE=(15, 5)):
    """Visualize images and their masks from the first batch of a dataset.

    One row is drawn per sample. The columns are the image and its mask,
    plus the model prediction when ``explain`` is True and ``model`` is given.

    Args:
        data: tf.data.Dataset yielding ``(images, masks)`` batches. Only the
            first batch (``data.take(1)``) is used; both elements must
            support ``.numpy()``.
        model: Model used for predictions (optional). ``model.predict`` is
            called on the selected images.
        explain: If True (and ``model`` is not None), also show predictions.
        n_images: Number of rows to show; capped at the batch size.
        SIZE: Figure size ``(width, height)``, passed as ``figsize``.

    Raises:
        ValueError: If ``data`` yields no batch, or there are no rows to draw.
    """
    first = next(iter(data.take(1)), None)
    if first is None:
        raise ValueError("show_images: `data` yielded no batches")
    images, masks = first
    images = images.numpy()
    masks = masks.numpy()

    # Never draw more rows than the batch actually contains.
    n_images = min(n_images, len(images))
    if n_images < 1:
        raise ValueError(f"show_images: nothing to draw (n_images={n_images})")

    with_prediction = explain and model is not None

    # squeeze=False keeps `axes` 2-D even when n_images == 1; otherwise
    # plt.subplots returns a 1-D array and axes[i, j] raises IndexError.
    _, axes = plt.subplots(
        n_images, 3 if with_prediction else 2, figsize=SIZE, squeeze=False
    )
    predictions = model.predict(images[:n_images]) if with_prediction else None
    mask_title = "Mask (GT)" if with_prediction else "Mask"

    for i, row in enumerate(axes):
        panels = [(images[i], "Image"), (masks[i], mask_title)]
        if with_prediction:
            panels.append((predictions[i], "Prediction"))
        for ax, (picture, title) in zip(row, panels):
            ax.imshow(picture)
            ax.set_title(title)
            ax.axis('off')

    plt.tight_layout()
    plt.show()
|
|
|
|
|
|
def show_single_prediction(model, image):
    """Plot one input image next to the model's prediction for it.

    Args:
        model: Trained DeepLabV3+ model; its ``predict`` method is called on
            a batch and the first result is displayed.
        image: Input image (numpy array). A 3-D array is treated as a single
            unbatched image and gets a leading batch axis; any other rank is
            passed to ``model.predict`` unchanged.
    """
    batch = np.expand_dims(image, axis=0) if len(image.shape) == 3 else image
    predicted = model.predict(batch)[0]

    _, (left, right) = plt.subplots(1, 2, figsize=(10, 5))
    panels = (
        (left, batch[0], "Input Image"),
        (right, predicted, "Prediction"),
    )
    for ax, picture, title in panels:
        ax.imshow(picture)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()
|
|
|
|
|
|
def prediction_to_base64(model, image):
    """Render an input image and its prediction side by side as a base64 PNG.

    Args:
        model: Trained DeepLabV3+ model; its ``predict`` method is called on
            a batch and the first result is drawn.
        image: Input image (numpy array). A 3-D array is treated as a single
            unbatched image and gets a leading batch axis; any other rank is
            passed to ``model.predict`` unchanged.

    Returns:
        str: The two-panel figure encoded as PNG (dpi=100, tight bbox), then
        base64-encoded and decoded to a UTF-8 string.
    """
    if len(image.shape) == 3:
        image = np.expand_dims(image, axis=0)

    prediction = model.predict(image)[0]

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    try:
        axes[0].imshow(image[0])
        axes[0].set_title("Input Image")
        axes[0].axis('off')

        axes[1].imshow(prediction)
        axes[1].set_title("Water Body Prediction")
        axes[1].axis('off')

        fig.tight_layout()

        # Save this specific figure rather than whatever pyplot considers
        # "current" at the time of the call.
        with io.BytesIO() as buf:
            fig.savefig(buf, format='png', bbox_inches='tight', dpi=100)
            img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    finally:
        # pyplot keeps every open figure alive until it is closed; closing in
        # `finally` prevents a leak when drawing or saving raises.
        plt.close(fig)

    return img_base64
|
|
|
|
|
|
def mask_to_binary(mask, threshold=0.5):
    """Binarize a mask by thresholding.

    Args:
        mask: Predicted mask. Any input ``np.asarray`` accepts works: numpy
            array, nested list, scalar, or an object exposing ``__array__``.
        threshold: Values strictly greater than this become 1; all others
            (including NaN) become 0.

    Returns:
        numpy uint8 array with the same shape as ``mask`` (a numpy uint8
        scalar for scalar input).
    """
    # np.asarray returns ndarrays unchanged, and lets plain lists and other
    # array-likes work (a list has no elementwise `>` and no `.astype`).
    return (np.asarray(mask) > threshold).astype(np.uint8)
|
|
|
|
|
|
def calculate_iou(pred_mask, true_mask, threshold=0.5):
    """Compute binary IoU (Intersection over Union) between two masks.

    Both masks are binarized with ``mask_to_binary`` at the same threshold.
    The IoU is then computed over all pixels at once.

    Args:
        pred_mask: Predicted mask.
        true_mask: Ground-truth mask.
        threshold: Binarization threshold applied to both masks.

    Returns:
        float: IoU in [0, 1]. Returns 1.0 when neither mask has any positive
        pixel (empty union), treating two empty masks as a perfect match.

    Raises:
        ValueError: If the masks' shapes differ after removing size-1 axes.
    """
    # Drop size-1 axes (e.g. a trailing channel or leading batch axis) so that
    # (H, W, 1) and (H, W) compare pixel-for-pixel. Without this, NumPy would
    # broadcast them to (H, W, W) when H == W and silently return a wrong IoU.
    pred_binary = np.squeeze(mask_to_binary(pred_mask, threshold))
    true_binary = np.squeeze(mask_to_binary(true_mask, threshold))
    if pred_binary.shape != true_binary.shape:
        raise ValueError(
            "calculate_iou: mask shapes differ after squeezing size-1 axes: "
            f"pred {pred_binary.shape} vs true {true_binary.shape}"
        )

    intersection = np.logical_and(pred_binary, true_binary).sum()
    union = np.logical_or(pred_binary, true_binary).sum()

    if union == 0:
        return 1.0

    return float(intersection / union)
|