Skip to content

PSNR 和 SSIM 指标 #22

@zyz-code

Description

@zyz-code

我用预训练好的权重做了指标计算,但是 SSIM 指标的结果不对,是我的计算方式不对吗?

处理完成: 1111 / 1111
平均 PSNR: 34.5106 ##(这个是对的)
平均 SSIM: 0.9545 ##(论文中是0.9713)

指标计算代码如下:
`import os
import cv2
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as psnr_loss
from skimage.metrics import structural_similarity as ssim_loss
from concurrent.futures import ProcessPoolExecutor
from tqdm import tqdm
import multiprocessing

# Define paths: GT_DIR holds the reference images, PRED_DIR the outputs to score.

GT_DIR = "/home/ubuntu/zyz/project/Datasets/test/GoPro/target"
PRED_DIR = "/home/ubuntu/zyz/project/EVSSM/results_final_1/GoPro_stage_1/GoPro"

def process_single_image(img_name):
    """Compute PSNR and SSIM for one reference / prediction image pair.

    ``img_name`` is looked up under both ``GT_DIR`` and ``PRED_DIR`` and each
    file is decoded with ``cv2.imread`` using its default flags.

    Args:
        img_name: File name shared by the reference and the predicted image.

    Returns:
        A ``(psnr, ssim)`` tuple, or ``None`` when the reference file does not
        exist, either image could not be loaded, or the two shapes differ.
    """
    path_gt = os.path.join(GT_DIR, img_name)
    path_pred = os.path.join(PRED_DIR, img_name)

    # 1. Skip names that have no reference counterpart.
    if not os.path.exists(path_gt):
        return None

    # 2. Load both images; a failed load is reported as None and handled below.
    img_gt = cv2.imread(path_gt)
    img_pred = cv2.imread(path_pred)

    if img_gt is None or img_pred is None:
        return None

    # 3. Both metrics need arrays of identical shape.
    if img_gt.shape != img_pred.shape:
        return None

    # 4. Compute the metrics.
    # data_range=255 assumes 8-bit pixel values -- TODO confirm for this data.
    psnr_val = psnr_loss(img_gt, img_pred, data_range=255)

    # NOTE(review): both calls below use skimage's default SSIM window. The
    # skimage documentation lists gaussian_weights=True, sigma=1.5 and
    # use_sample_covariance=False as the settings that match Wang et al.;
    # confirm which convention the reference numbers were produced with
    # before comparing against published values.
    #
    # skimage version compatibility: retry with the ``multichannel`` keyword
    # when the ``channel_axis`` call raises TypeError.
    try:
        ssim_val = ssim_loss(img_gt, img_pred, data_range=255, channel_axis=2)
    except TypeError:
        ssim_val = ssim_loss(img_gt, img_pred, data_range=255, multichannel=True)

    return psnr_val, ssim_val

def main():
    """Score every image found in ``PRED_DIR`` and print the average metrics.

    Lists the ``.png`` / ``.jpg`` / ``.jpeg`` files in ``PRED_DIR``, evaluates
    each with ``process_single_image`` in a process pool, drops the entries
    for which that function returned ``None``, and prints the mean PSNR and
    SSIM together with how many images were actually scored.

    Returns:
        None. Returns early (after printing a message) when either directory
        is missing or when no image produced a result.
    """
    if not os.path.exists(GT_DIR) or not os.path.exists(PRED_DIR):
        print("路径不存在,请检查。")
        return

    # Collect the image list from the prediction directory, in sorted order.
    img_list = sorted(
        f for f in os.listdir(PRED_DIR)
        if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    )
    print(f"找到 {len(img_list)} 张图片,准备使用多进程计算...")

    # Use every available CPU core.
    num_workers = multiprocessing.cpu_count()
    print(f"使用 CPU 核心数: {num_workers}")

    # Fan the per-image work out over a process pool; list() drains the
    # iterator so all results are available once the pool is closed.
    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        results = list(
            tqdm(executor.map(process_single_image, img_list), total=len(img_list))
        )

    # Keep only the images for which both metrics were computed.
    scored = [res for res in results if res is not None]
    if not scored:
        print("没有成功计算任何图片,请检查文件名对应关系。")
        return

    psnr_list, ssim_list = zip(*scored)
    avg_psnr = np.mean(psnr_list)
    avg_ssim = np.mean(ssim_list)

    print("\n" + "=" * 30)
    print(f"处理完成: {len(psnr_list)} / {len(img_list)}")
    print(f"平均 PSNR: {avg_psnr:.4f}")
    print(f"平均 SSIM: {avg_ssim:.4f}")
    print("=" * 30)

if name == 'main':
main()`

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions