import numpy as np

# --- 预先计算并声明转换矩阵 (常量) ---
# 采用更清晰的命名规范，并预先计算逆矩阵以提高效率

# 矩阵：从 XYZ (D65) 转换到 LMS
XYZ_TO_LMS_MATRIX = np.array([
    [0.400227052325915, 0.707583738918745, -0.080809279274243],
    [-0.226301727628935, 1.165334160925190, 0.045699378859811],
    [0.0, 0.0, 0.918416046041877]
])
# 逆矩阵：从 LMS 返回 XYZ (D65)
LMS_TO_XYZ_MATRIX = np.linalg.inv(XYZ_TO_LMS_MATRIX)

# 矩阵：从非线性 LMS (LMS_TM) 转换到 Iab
LMS_TO_IAB_MATRIX = np.array([
    [65.57377049, 32.78688525, 1.63934426],
    [430.0, -470.0, 40.0],
    [49.0, 49.0, -98.0]
])
# 逆矩阵：从 Iab 返回非线性 LMS (LMS_TM)
IAB_TO_LMS_MATRIX = np.linalg.inv(LMS_TO_IAB_MATRIX)

CAT16_XYZ_TO_RGB = np.array([
    [0.401288, 0.650173, -0.051461],
    [-0.250268, 1.204414, 0.045854],
    [-0.002079, 0.048952, 0.953127]
])

# --- sCAM 模型参数 (模块化) ---
SCAM_SURROUND_PARAMS = {
    'avg':  {'F': 1.0, 'c': 0.52, 'Fm': 1.0},
    'dim':  {'F': 0.9, 'c': 0.50, 'Fm': 0.95},
    'dark': {'F': 0.8, 'c': 0.39, 'Fm': 0.85},
}
HUE_DATA = {
    'h_i': np.array([16.5987, 80.2763, 157.779, 219.7174, 376.5987]),
    'e_i': np.array([0.7, 0.6, 1.2, 0.9, 0.7]),
    'H_i': np.array([0, 100, 200, 300, 400])
}

# 预先计算逆矩阵，避免在函数中重复计算
CAT16_RGB_TO_XYZ = np.array([
    [1.86206786, -1.01125463,  0.14918677],
    [0.38752654,  0.62144744, -0.00897398],
    [-0.01584150, -0.03412294,  1.04996444]
])

CAT16_XYZ_TO_RGB_T = CAT16_XYZ_TO_RGB.T
CAT16_RGB_TO_XYZ_T = CAT16_RGB_TO_XYZ.T

XYZW_D65 = np.array([95.047, 100.0, 108.883])  # D65 白点的 XYZ 值

def xyz_to_sucs(xyz):
    """
    将 XYZ 颜色空间转换为 sUCS (Iab) 颜色空间。
    Args:
        xyz (array_like): 一个或多个XYZ值。输入假定为D65光源，值域为 [0, 100]。
                          可以是单个向量 [X, Y, Z] 或 Nx3 数组。
    Returns:
        numpy.ndarray: 对应的 sUCS (Iab) 值。
    """
    # 确保输入是Numpy数组
    xyz = np.asarray(xyz, dtype=np.float64)
    is_single_vector = (xyz.ndim == 1)
    if is_single_vector:
        xyz = xyz.reshape(1, -1)

    # 1. XYZ 归一化 (XYZ / 100)
    xyz_scaled = xyz / 100.0

    # 2. XYZ -> LMS
    lms = xyz_scaled @ XYZ_TO_LMS_MATRIX.T

    # 3. 非线性压缩 (γ = 0.43)
    lms_tm = np.sign(lms) * (np.abs(lms)**0.43)

    # 4. LMS -> Iab
    iab = lms_tm @ LMS_TO_IAB_MATRIX.T

    # 5. 计算最终的 a_1, b_1 (优化实现)
    I = iab[:, 0]
    a = iab[:, 1]
    b = iab[:, 2]
    
    C = np.sqrt(a**2 + b**2)
    C_1 = np.log1p(0.0447 * C) / 0.0252
    
    # 避免计算角度h，直接代数计算
    a_1 = np.divide(C_1 * a, C, out=np.zeros_like(C), where=C != 0)
    b_1 = np.divide(C_1 * b, C, out=np.zeros_like(C), where=C != 0)

    # 6. 合并结果
    iab_1 = np.stack((I, a_1, b_1), axis=1)

    return iab_1.flatten() if is_single_vector else iab_1

def sucs_to_xyz(iab_1):
    """
    将 sUCS (Iab) 颜色空间转换回 XYZ 颜色空间。
    Args:
        iab_1 (array_like): 一个或多个sUCS值 (I, a_1, b_1)。
                           可以是单个向量 [I, a_1, b_1] 或 Nx3 数组。
    Returns:
        numpy.ndarray: 对应的 XYZ (D65) 值，值域为 [0, 100]。
    """
    # 确保输入是Numpy数组
    iab_1 = np.asarray(iab_1, dtype=np.float64)
    is_single_vector = (iab_1.ndim == 1)
    if is_single_vector:
        iab_1 = iab_1.reshape(1, -1)
    
    # 1. 从 a_1, b_1 恢复 a, b (优化实现)
    I = iab_1[:, 0]
    a_1 = iab_1[:, 1]
    b_1 = iab_1[:, 2]

    C_1 = np.sqrt(a_1**2 + b_1**2)
    # 使用 np.expm1(x) 替换 exp(x)-1 以提高精度
    C = np.expm1(0.0252 * C_1) / 0.0447
    
    # 避免计算角度h，直接代数计算
    a = np.divide(C * a_1, C_1, out=np.zeros_like(C_1), where=C_1 != 0)
    b = np.divide(C * b_1, C_1, out=np.zeros_like(C_1), where=C_1 != 0)
    
    iab = np.stack((I, a, b), axis=1)

    # 2. Iab -> LMS_TM
    lms_tm = iab @ IAB_TO_LMS_MATRIX.T

    # 3. 逆向非线性压缩 (1 / 0.43)
    power = 1.0 / 0.43
    lms = np.sign(lms_tm) * (np.abs(lms_tm)**power)

    # 4. LMS -> XYZ_D65
    xyz_d65 = lms @ LMS_TO_XYZ_MATRIX.T
    
    # 5. 缩放回 [0, 100] 范围
    xyz = xyz_d65 * 100.0

    return xyz.flatten() if is_single_vector else xyz

def cat16(
    xyz_source: np.ndarray,
    xyz_w_source: np.ndarray,
    xyz_w_ref: np.ndarray,
    la: float,
    f_level: float
) -> np.ndarray:
    """
    使用 CAT16 色彩适应模型对XYZ颜色进行转换 (向量化实现)。
    此函数将一个或多个源颜色 (在源光照下) 转换到在目标 (参考) 光照下
    人眼所感知的对应颜色。该实现完全向量化，以高效处理 N 个颜色。
    Args:
        xyz_source (np.ndarray): 源颜色的 (N, 3) XYZ三刺激值数组。
        xyz_w_source (np.ndarray): 源光照白点的 (3,) XYZ值向量。
        xyz_w_ref (np.ndarray): 参考光照白点的 (3,) XYZ值向量。
        la_source (float): 源适应场的亮度 (单位: cd/m^2)。
        la_ref (float): 参考适应场的亮度 (单位: cd/m^2)。
        f_level (float): 适应程度因子 (0.8 到 1.0 之间)。
    Returns:
        np.ndarray: 在参考观看条件下对应的 (N, 3) XYZ三刺激值数组。
    """
    xyz_w_source = np.asarray(xyz_w_source, dtype=np.float64).reshape(1, 3)
    xyz_w_ref = np.asarray(xyz_w_ref, dtype=np.float64).reshape(1, 3)
    is_single_vector = (xyz_source.ndim == 1)
    if is_single_vector:
        xyz_source = xyz_source.reshape(1, 3)
    # --- 1. 计算源和参考条件下的适应度 D ---
    # D = F * (1 - (1/3.6) * exp((-L_A - 42) / 92))
    # 使用一个辅助函数或直接计算来避免代码重复
    d_source = f_level * (1.0 - (1.0 / 3.6) * np.exp((-la - 42.0) / 92.0))
    d_ref = f_level * (1.0 - (1.0 / 3.6) * np.exp((-la - 42.0) / 92.0))
    
    # 将 D 限制在 [0, 1] 区间内
    d_source = np.clip(d_source, 0.0, 1.0)
    d_ref = np.clip(d_ref, 0.0, 1.0)

    # --- 2. 将白点转换为锥细胞响应 (LMS/RGB) ---
    # (3, 3) @ (3,) -> (3,)
    rgb_w_source = xyz_w_source.dot(CAT16_XYZ_TO_RGB.T)
    rgb_w_ref = xyz_w_ref.dot(CAT16_XYZ_TO_RGB.T)

    # --- 3. 计算每个锥细胞的适应因子 (向量化) ---
    # D_factor = D * (Y_w / RGB_w) + 1 - D
    y_w_source = xyz_w_source.flatten()[1]
    y_w_ref = xyz_w_ref.flatten()[1]

    d_factor_source = d_source * (y_w_source / rgb_w_source) + 1.0 - d_source
    d_factor_ref = d_ref * (y_w_ref / rgb_w_ref) + 1.0 - d_ref

    # --- 4. 计算最终的适应比率 (向量化) ---
    # (3,) / (3,) -> (3,)
    adaptation_ratios = d_factor_source / d_factor_ref

    # --- 5. 主要变换步骤 (完全向量化) ---
    # 步骤 5a: 将源 XYZ 批量转换为锥细胞响应 (LMS/RGB)
    # (N, 3) @ (3, 3) -> (N, 3)
    rgb_source = xyz_source @ CAT16_XYZ_TO_RGB_T

    # 步骤 5b: 应用色彩适应变换
    # (N, 3) * (3,) -> (N, 3)  (NumPy广播机制)
    rgb_adapted = rgb_source * adaptation_ratios

    # 步骤 5c: 将适应后的锥细胞响应批量转换回 XYZ
    # (N, 3) @ (3, 3) -> (N, 3)
    xyz_adapted = rgb_adapted @ CAT16_RGB_TO_XYZ_T
    
    # --- 6. 后处理 ---
    # 将任何可能的负值修正为 0
    # np.maximum 比 np.clip(xyz, 0, None) 略微高效
    # 不做裁剪，直接返回
    return xyz_adapted

def xyz_to_scam(xyz, xyz_w, y_b, l_a, surround, mode='lite'):
    """
    将 XYZ 三刺激值转换为 sCAM (简单色彩外观模型) 坐标。
    此函数复现了 Li, M. & Luo, M. R. (2024) 的 sCAM 模型，并增加了 Depth(D) 计算。
    Args:
        xyz (array_like): 测试样本的一个或多个XYZ值。
        xyz_w (array_like): 测试光源的白点XYZ值。
        y_b (float): 背景的亮度因子。
        l_a (float): 适应场的亮度 (cd/m^2)。
        surround (str): 环绕观看条件 ('avg', 'dim', 'dark')。
        mode (str, optional): 控制输出的模式。
            - 'lite' (默认): 只返回核心的 JCh 坐标。
            - 'full': 返回 JCh 坐标和包含 H, D 的外部标度。
    Returns:
        numpy.ndarray or tuple[numpy.ndarray, numpy.ndarray]:
        - 如果 mode='lite': 返回一个 (n, 3) 的 ndarray，包含 [J, C, h]。
        - 如果 mode='full': 返回一个元组，包含两个 ndarray：
          - 第一个元素是 (n, 3) 的 [J, C, h]。
          - 第二个元素是 (n, 2) 的 [H, D] (外部标度)。
        如果输入是单个向量，返回的数组也会被相应地展平。
    """
    # --- 输入验证和处理 ---
    if mode not in ['lite', 'full']:
        raise ValueError("mode 参数必须是 'lite' 或 'full'。")
    
    xyz = np.asarray(xyz, dtype=np.float64)
    xyz_w = np.asarray(xyz_w, dtype=np.float64)
    is_single_vector = (xyz.ndim == 1)
    if is_single_vector:
        xyz = xyz.reshape(1, -1)
    if xyz_w.ndim == 1:
        xyz_w = xyz_w.reshape(1, -1)

    # --- 步骤 1-7: 计算核心 J, C, h ---
    # (这部分计算逻辑与之前版本相同)
    s_params = SCAM_SURROUND_PARAMS.get(surround.lower())
    if s_params is None:
        raise ValueError("surround 参数必须是 'avg', 'dim', 或 'dark' 之一。")
    F, c, Fm = s_params['F'], s_params['c'], s_params['Fm']
    
    n = y_b / xyz_w[:, 1, np.newaxis]
    z = 1.48 + np.sqrt(n)
    
    l_w = l_a * 100.0 / y_b
    xyz_wt = XYZW_D65
    xyz_d65 = cat16(xyz, xyz_w, xyz_wt, l_a, F)
    
    iab = xyz_to_sucs(xyz_d65)
    I, a, b = iab[:, 0], iab[:, 1], iab[:, 2]
    
    C_1 = np.sqrt(a**2 + b**2)
    
    L = 100 * (I / 100)**(c * z.flatten())
    J = L
    
    h = np.rad2deg(np.arctan2(b, a))
    h[h < 0] += 360

    # 将核心结果打包成 ndarray
    jch_output = np.stack((J, C_1, h), axis=1)

    # --- 根据模式返回结果 ---
    if mode == 'lite':
        return jch_output.flatten() if is_single_vector else jch_output

    # --- mode='full' 的额外计算 ---
    # 步骤 8: 计算色相构成 (H)
    h_i, e_i, H_i = HUE_DATA['h_i'], HUE_DATA['e_i'], HUE_DATA['H_i']
    h_mod = np.where(h < h_i[0], h + 360, h)
    indices = np.searchsorted(h_i, h_mod, side='right') - 1
    
    temp = (h_mod - h_i[indices]) / e_i[indices]
    temp1 = (h_i[indices + 1] - h_mod) / e_i[indices + 1]
    H = H_i[indices] + 100 * temp / (temp + temp1)

    # **新增**: 计算 Depth (D)
    # 使用 np.power(arr, 3) 确保对整个数组进行立方运算
    D = 1.3 * np.sqrt(np.power(100 - J, 2) + 1.6 * np.power(C_1, 2))

    # 将外部标度打包成 ndarray
    external_scales_output = np.stack((H, D), axis=1)
    
    if is_single_vector:
        return jch_output.flatten(), external_scales_output.flatten()
    
    return jch_output, external_scales_output

def scam_to_xyz(jch, xyz_w, y_b, l_a, surround):
    """
    Converts sCAM (J, C, h) perceptual correlates back to XYZ tristimulus values.
    This function performs the inverse operations of the xyz_to_scam model.
    Args:
        jch (array_like): One or more sCAM values. Input is assumed to be
                          [J, C, h]. Can be a single vector or an Nx3 array.
        xyz_w (array_like): Whitepoint XYZ values of the original illuminant.
        y_b (float): Luminance factor of the background.
        l_a (float): Luminance of the adapting field (in cd/m^2).
        surround (str): Viewing surround condition ('avg', 'dim', 'dark').
    Returns:
        numpy.ndarray: The corresponding XYZ values in the original illuminant.
    """
    # --- Input Validation and Processing ---
    jch = np.asarray(jch, dtype=np.float64)
    xyz_w = np.asarray(xyz_w, dtype=np.float64)
    is_single_vector = (jch.ndim == 1)
    if is_single_vector:
        jch = jch.reshape(1, -1)
    if xyz_w.ndim == 1:
        xyz_w = xyz_w.reshape(1, -1)

    s_params = SCAM_SURROUND_PARAMS.get(surround.lower())
    if s_params is None:
        raise ValueError("surround parameter must be one of 'avg', 'dim', or 'dark'.")
    F, c = s_params['F'], s_params['c']
    l_w = l_a * 100.0 / y_b
    # --- Step 1: Invert Lightness (J) to get I ---
    # This reverses the Stevens effect and surround adjustments.
    n = y_b / xyz_w[:, 1, np.newaxis]
    z = 1.48 + np.sqrt(n)
    
    J = jch[:, 0]
    I = 100.0 * (J / 100.0)**(1.0 / (c * z.flatten()))

    # --- Step 2: Convert Chroma and hue (Ch) to opponent components (a, b) ---
    C = jch[:, 1]
    h = jch[:, 2]
    h_rad = np.deg2rad(h)
    
    a = C * np.cos(h_rad)
    b = C * np.sin(h_rad)

    # --- Step 3: Reconstruct Iab and convert to XYZ_D65 ---
    # We now have the components for the sUCS space.
    iab = np.stack((I, a, b), axis=1)
    xyz_d65 = sucs_to_xyz(iab)
    xyz_wt = XYZW_D65
    # --- Step 4: Apply INVERSE Chromatic Adaptation (from D65 to original illuminant) ---
    # To reverse the adaptation, we use cat16 but swap the source and target whitepoints.
    # The source is now D65, and the target is the original illuminant's whitepoint.
    xyz_final = cat16(xyz_d65, xyz_wt, xyz_w, l_a, F)
    
    return xyz_final.flatten() if is_single_vector else xyz_final

# ==============================================================================

if __name__ == '__main__':

    # 定义 sCAM 模型的观看条件
    viewing_conditions = {
        'xyz_w':    np.array([95.047, 100.0, 108.883]), # 光源白点 (D65)
        'y_b':      20.0,  # 背景的亮度因子 (20% 灰色背景)
        'l_a':      64.0,  # 适应场的亮度 (cd/m^2)
        'surround': 'avg' # 观看环绕条件: 'avg', 'dim', 或 'dark'
    }

    J_list = [20, 35, 50, 65, 80]
    C_list = [5, 10, 20, 30, 40, 50, 60]
    h_list = list(range(0, 360, 30))  # 0, 30, ..., 330

    print("J\tC\th\tXYZ")
    for J in J_list:
        for C in C_list:
            for h in h_list:
                jch = np.array([J, C, h])
                xyz = scam_to_xyz(jch, **viewing_conditions)
                xyz_str = np.array2string(xyz, precision=6, floatmode='fixed', suppress_small=True)
                print(f"{J}\t{C}\t{h}\t{xyz_str}")