# ovf/core/sk_iter.py
# (Header recovered from a repository web view: 2025-09-17 02:45:10 -04:00, 326 lines, 11 KiB, Python.)
# NOTE: the web view warned that this file contains ambiguous Unicode characters
# (presumably the Greek letter "μ" used in identifiers such as Qμ / Rμ); this appears intentional.
import numpy as np
from dataclasses import dataclass
from typing import List, Dict, Tuple
# ---------- 1. 生成初始极点(按论文 Re(-a_p) 小、Im 线性分布) ----------
def generate_starting_poles(n_pairs: int, beta_min: float, beta_max: float, alpha_scale: float = 0.01,
                            *, verbose: bool = True):
    """Generate complex-conjugate starting poles with linearly spaced imaginary parts.

    Each pair is ``p = -alpha + j*beta`` followed by ``conj(p)``, where
    ``beta = 2*pi*f`` for ``f`` in ``np.linspace(beta_min, beta_max, n_pairs)`` and
    ``alpha = alpha_scale * beta`` (the paper uses alpha_p = 0.01 * beta_p).

    Args:
        n_pairs: number of complex pairs; the total number of poles is ``2 * n_pairs``.
        beta_min, beta_max: frequency range to cover, in Hz. They are converted to
            rad/s internally (multiplied by ``2*pi``).
        alpha_scale: ratio between the magnitude of the real part and the imaginary
            part. Use alpha_scale > 0 with beta > 0 so that Re(p) < 0;
            ``muntz_laguerre_basis`` rejects poles with Re(p) >= 0.
        verbose: if True (default, the previous behavior), print the generated poles.

    Returns:
        list of complex poles ordered [p_1, conj(p_1), p_2, conj(p_2), ...]
        (positive imaginary part first within every pair).
    """
    betas = 2 * np.pi * np.linspace(beta_min, beta_max, n_pairs)  # Hz -> rad/s
    poles = []
    for b in betas:
        alpha = alpha_scale * b
        p = -alpha + 1j * b
        poles += [p, np.conj(p)]
    if verbose:
        print(f"生成 {len(poles)} 个初始极点 (复对) {poles}")
    return poles
# ---------- 2. Muntz-Laguerre 频率基 (与你现有 orthonormal_basis 中一致的核心) ----------
def muntz_laguerre_basis(freqs_hz: np.ndarray, stable_poles: List[complex]) -> np.ndarray:
    """Evaluate the Muntz-Laguerre rational basis on a frequency grid.

    Column 0 is the constant 1. Every real pole adds one column:
        sqrt(2*sigma) / (s - p) * A(s)
    and every complex pair (positive-imaginary pole immediately followed by its
    conjugate) adds two columns:
        sqrt(2*sigma) * (s -/+ |p|) / ((s - p)(s - p*)) * A(s)
    where sigma = -Re(p) and A(s) is the running product of the all-pass factors
    (s + p_j*)/(s - p_j) of the poles already processed. (This construction is
    intended to be orthonormal on the imaginary axis.)

    Args:
        freqs_hz: frequencies in Hz; the basis is evaluated at s = j*2*pi*f.
        stable_poles: poles with strictly negative real part.

    Returns:
        complex array of shape (Nf, 1 + len(stable_poles)).

    Raises:
        ValueError: a pole has Re >= 0, a positive-imaginary pole is not followed
            by its conjugate, or a negative-imaginary pole appears first.
    """
    s = 1j * 2 * np.pi * freqs_hz
    columns = [np.ones_like(s, dtype=complex)]
    allpass = np.ones_like(s, dtype=complex)  # running product of all-pass factors
    count = len(stable_poles)
    pos = 0
    while pos < count:
        pole = stable_poles[pos]
        if np.real(pole) >= 0:
            raise ValueError("极点需在左半平面")
        norm = np.sqrt(2 * -np.real(pole))
        if np.imag(pole) > 0:
            # Complex pair: the conjugate must be the very next entry.
            if pos + 1 >= count or not np.isclose(stable_poles[pos + 1], np.conj(pole)):
                raise ValueError("复对未配对")
            mate = stable_poles[pos + 1]
            modulus = np.abs(pole)
            denom = (s - pole) * (s - mate)
            columns += [norm * (s - modulus) / denom * allpass,
                        norm * (s + modulus) / denom * allpass]
            allpass = allpass * (s + mate) / (s - pole) * (s + pole) / (s - mate)
            pos += 2
        elif np.imag(pole) < 0:
            raise ValueError("负虚部复极点必须跟在正虚部之后")
        else:
            # Real pole: a single column.
            columns.append(norm / (s - pole) * allpass)
            allpass = allpass * (s + pole) / (s - pole)
            pos += 1
    return np.column_stack(columns)
# ---------- 3. 原始部分分式基 ----------
def raw_partial_fraction_basis(freqs_hz: np.ndarray, stable_poles: List[complex]) -> np.ndarray:
    """Evaluate the plain partial-fraction basis [1, 1/(s-p_1), ..., 1/(s-p_K)].

    Args:
        freqs_hz: frequencies in Hz; evaluated at s = j*2*pi*f.
        stable_poles: poles p_k, in the same order as the returned columns 1..K.

    Returns:
        complex array of shape (Nf, 1 + len(stable_poles)).
    """
    s = 1j * 2 * np.pi * freqs_hz
    constant = np.ones_like(s, dtype=complex)
    return np.column_stack([constant, *(1.0 / (s - pole) for pole in stable_poles)])
# ---------- 4. 变换矩阵 Φ_ml T = G_raw ----------
def compute_transform_matrix(Phi_ml: np.ndarray, G_raw: np.ndarray, weights: np.ndarray | None = None) -> np.ndarray:
# 加权最小二乘 T = (Φ^H W Φ)^{-1} Φ^H W G
if weights is None:
WPhi = Phi_ml
WG = G_raw
M = Phi_ml.conj().T @ WPhi
RHS = Phi_ml.conj().T @ WG
else:
w = weights[:, None]
M = Phi_ml.conj().T @ (w * Phi_ml)
RHS = Phi_ml.conj().T @ (w * G_raw)
T = np.linalg.solve(M, RHS)
return T # (P,P), Φ T = G
# ---------- 5. 参数多项式正交基 ----------
def generate_param_basis(param_matrix: np.ndarray, total_degree: int):
    """Build an orthonormal polynomial basis over the parameter samples.

    Every monomial x_1^{e_1} * ... * x_d^{e_d} with e_1 + ... + e_d <= total_degree
    is evaluated at the samples, and the resulting matrix is factorized with
    ``np.linalg.qr`` (reduced mode).

    Args:
        param_matrix: (Ns, d) parameter samples, one row per sample.
        total_degree: maximum total degree of the monomials.

    Returns:
        Qμ: orthonormal columns, (Ns, V) when Ns >= V. With Ns < V the reduced
            QR returns only Ns columns (and Rμ is then (Ns, V)).
        Rμ: upper-triangular factor, (V, V) when Ns >= V.
        exps: the V exponent tuples in column order (lexicographic, earlier
            variables varying slowest); for total_degree >= 0 the all-zero
            (constant) tuple comes first.
    """
    X = param_matrix
    n_samples, n_vars = X.shape

    def _exponents(var, budget):
        # Multi-indices for variables var..n_vars-1 whose sum is <= budget.
        if var == n_vars:
            yield ()
            return
        for k in range(budget + 1):
            for tail in _exponents(var + 1, budget - k):
                yield (k, *tail)

    exps = list(_exponents(0, total_degree))
    monomials = np.zeros((n_samples, len(exps)))
    for col, exponent in enumerate(exps):
        values = np.ones(n_samples)
        for j, power in enumerate(exponent):
            if power:
                values *= X[:, j] ** power
        monomials[:, col] = values
    Q, R = np.linalg.qr(monomials)
    return Q, R, exps
# ---------- 6. t=0 构建设计矩阵并解 C ----------
def fit_t0(H_list: List[np.ndarray], Phi_f: np.ndarray, Qμ: np.ndarray) -> np.ndarray:
    """Fit the numerator coefficients at t=0 (denominator fixed to D = 1).

    Solves, in the least-squares sense over every sample n and frequency f,
        H_n(f) ≈ sum_{p,v} C[p, v] * Phi_f[f, p] * Qμ[n, v].

    Args:
        H_list: Ns complex responses, each with Nf samples.
        Phi_f: (Nf, P) frequency basis.
        Qμ: (Ns, V) parameter basis; row n belongs to H_list[n].

    Returns:
        C of shape (P, V).

    Raises:
        ValueError: if len(H_list) != Qμ.shape[0] or a response does not have Nf samples.
    """
    Nf, P = Phi_f.shape
    Ns, V = Qμ.shape
    if len(H_list) != Ns:
        raise ValueError(f"H_list has {len(H_list)} samples but Qμ has {Ns} rows")
    responses = [np.asarray(Hn, dtype=complex).reshape(-1) for Hn in H_list]
    if any(Hn.shape[0] != Nf for Hn in responses):
        raise ValueError(f"each response must have {Nf} samples")
    # Row n*Nf + f, column p*V + v holds Phi_f[f, p] * Qμ[n, v] (matches C.ravel()).
    A = np.einsum('fp,nv->nfpv', Phi_f, Qμ).reshape(Ns * Nf, P * V)
    b = np.concatenate(responses) if responses else np.zeros(0, dtype=complex)
    x, *_ = np.linalg.lstsq(A, b, rcond=None)
    return x.reshape(P, V)
# ---------- 7. t=1 (首轮 SK) 解 C, Ct ----------
def fit_t1_SK(H_list: List[np.ndarray], Phi_f: np.ndarray, Qμ: np.ndarray,
              C_prev: np.ndarray, max_iter: int = 1, tol: float = 1e-3):
    """Linearized rational fit N/D ≈ H with Sanathanan-Koerner (SK) iterations.

    Model for sample n and frequency f:
        N_n(f) = sum_{p,v} C[p, v]  * Phi_f[f, p] * Qμ[n, v]
        D_n(f) = sum_{p,v} Ct[p, v] * Phi_f[f, p] * Qμ[n, v],   Ct[0, 0] fixed to 1.
    Each iteration solves the linear least-squares problem
        w_n(f) * (N_n(f) - D_n(f) * H_n(f)) = 0
    with w = 1 in the first iteration (initial D = 1) and w = 1 / D_prev in later
    iterations (SK reweighting by the previous denominator). Because Ct[0, 0] is
    fixed, its term H * Phi_f[:, 0] * Qμ[n, 0] is moved to the right-hand side.
    Without that term the system would be homogeneous, and its least-squares
    solution would be the trivial C = 0.

    Args:
        H_list: Ns complex responses, each with Nf samples.
        Phi_f: (Nf, P) frequency basis.
        Qμ: (Ns, V) parameter basis; row n belongs to H_list[n].
        C_prev: (P, V) previous numerator coefficients. Returned (as complex) if
            max_iter < 1; otherwise the numerator is re-solved from scratch.
        max_iter: maximum number of SK iterations.
        tol: stop once max |Ct_new - Ct| < tol.

    Returns:
        (C, Ct): complex arrays of shape (P, V).

    Raises:
        ValueError: if len(H_list) != Qμ.shape[0] or a response does not have Nf samples.
    """
    import warnings

    Nf, P = Phi_f.shape
    Ns, V = Qμ.shape
    if len(H_list) != Ns:
        raise ValueError(f"H_list has {len(H_list)} samples but Qμ has {Ns} rows")
    H = np.stack([np.asarray(Hn, dtype=complex).reshape(-1) for Hn in H_list])
    if H.shape != (Ns, Nf):
        raise ValueError(f"each response must have {Nf} samples")
    n_coef = P * V
    # basis[n, f, p*V + v] = Phi_f[f, p] * Qμ[n, v]  (same order as C.ravel()).
    basis = np.einsum('fp,nv->nfpv', Phi_f, Qμ).reshape(Ns, Nf, n_coef)

    Ct = np.zeros((P, V), dtype=complex)
    Ct[0, 0] = 1.0  # normalization: fixed, not an unknown
    C = np.array(C_prev, dtype=complex)
    weights = np.ones((Ns, Nf), dtype=complex)  # first pass: D = 1, no SK weighting
    for it in range(max_iter):
        if it > 0:
            D_prev = basis @ Ct.reshape(-1)  # (Ns, Nf) previous denominator
            with np.errstate(divide="ignore", invalid="ignore"):
                weights = 1.0 / D_prev
            if not np.all(np.isfinite(weights)):
                warnings.warn("fit_t1_SK: previous denominator vanishes at a sample; "
                              "stopping SK iterations early", RuntimeWarning)
                break
        wH = weights * H
        A = np.concatenate(
            [weights[..., None] * basis,        # numerator unknowns C (all P*V)
             -wH[..., None] * basis[..., 1:]],  # denominator unknowns Ct, except Ct[0, 0]
            axis=2,
        ).reshape(Ns * Nf, 2 * n_coef - 1)
        # The fixed Ct[0, 0] = 1 contribution of -D*H moves to the right-hand side.
        b = (wH * basis[..., 0]).reshape(Ns * Nf)
        x, *_ = np.linalg.lstsq(A, b, rcond=None)
        C = x[:n_coef].reshape(P, V)
        Ct_new = np.concatenate(([1.0], x[n_coef:])).reshape(P, V)
        diff = np.max(np.abs(Ct_new - Ct))  # convergence check on the denominator
        Ct = Ct_new
        if diff < tol:
            break
    return C, Ct
# ---------- 8. 假极点检测 (在 raw 基上) ----------
def detect_spurious_poles(C_ml: np.ndarray, Ct_ml: np.ndarray, T: np.ndarray,
                          Qμ: np.ndarray,
                          tau_cancel=1e-2, tau_small=1e-3, eps=1e-14):
    """Flag candidate spurious poles using coefficients mapped to the raw basis.

    Phi_ml @ T = G_raw implies C_raw = T^{-1} C_ml (and likewise for Ct). For each
    raw column k >= 1 (column 0 is the constant), the numerator and denominator
    values over the parameter samples,
        r_k = C_raw[k] @ Qμ.T,   q_k = Ct_raw[k] @ Qμ.T,
    are compared through
        eta_k = max|r_k - q_k| / scale_k,   scale_k = max(max|r_k|, max|q_k|, eps).
    With S_max the largest scale_k (1.0 if there are none), pole k is reported as
      * "cancel_spurious" if eta_k < tau_cancel and scale_k > tau_small * S_max,
      * "small_spurious"  if scale_k <= tau_small * S_max.

    Args:
        C_ml, Ct_ml: (P, V) numerator / denominator coefficients (Muntz-Laguerre basis).
        T: (P, P) transform with Phi_ml @ T = G_raw.
        Qμ: (Ns, V) parameter basis.
        tau_cancel, tau_small, eps: thresholds described above.

    Returns:
        dict with "cancel_spurious", "small_spurious" (lists of column indices),
        "metrics" (per-k diff/scale/eta), "S_max", "C_raw", "Ct_raw".
    """
    # Solve with T instead of forming T^{-1} explicitly (same result, better accuracy).
    C_raw = np.linalg.solve(T, C_ml)
    Ct_raw = np.linalg.solve(T, Ct_ml)
    P = C_ml.shape[0]
    r_vals = C_raw @ Qμ.T
    q_vals = Ct_raw @ Qμ.T
    metrics = {}
    for k in range(1, P):  # skip the constant column
        rv, qv = r_vals[k], q_vals[k]
        diff = np.max(np.abs(rv - qv))
        scale = np.max([np.max(np.abs(rv)), np.max(np.abs(qv)), eps])
        metrics[k] = {"diff": diff, "scale": scale, "eta": diff / scale}
    S_max = max((m["scale"] for m in metrics.values()), default=1.0)
    cancel_spurious = [k for k, m in metrics.items()
                       if m["eta"] < tau_cancel and m["scale"] > tau_small * S_max]
    small_spurious = [k for k, m in metrics.items()
                      if m["scale"] <= tau_small * S_max]
    return {
        "cancel_spurious": cancel_spurious,
        "small_spurious": small_spurious,
        "metrics": metrics,
        "S_max": S_max,
        "C_raw": C_raw,
        "Ct_raw": Ct_raw
    }
# ---------- 9. 统一入口 ----------
@dataclass
class OPVFResult:
    """Container for the results returned by ``opvf_from_H``."""
    C_ml: np.ndarray             # (P, V) numerator coefficients, Muntz-Laguerre basis
    Ct_ml: np.ndarray            # (P, V) denominator coefficients, Muntz-Laguerre basis
    C_raw: np.ndarray            # numerator coefficients mapped to the raw partial-fraction basis
    Ct_raw: np.ndarray           # denominator coefficients mapped to the raw basis
    T: np.ndarray                # transform with Phi_ml @ T ≈ G_raw
    Qμ: np.ndarray               # orthonormal parameter basis (Ns, V)
    Rμ: np.ndarray               # triangular QR factor of the monomial matrix
    exps: List[Tuple[int, ...]]  # monomial exponent tuples, one per parameter-basis column
    poles: List[complex]         # starting poles used to build the frequency bases
    spurious: Dict               # report returned by detect_spurious_poles
def opvf_from_H(H_list: List[np.ndarray],
                freqs_hz: np.ndarray,
                param_matrix: np.ndarray,
                total_degree: int,
                poles: List[complex],
                do_t1: bool = True) -> OPVFResult:
    """Run the full OPVF pipeline on sampled responses.

    Steps:
      1. Build the Muntz-Laguerre and raw partial-fraction frequency bases and the
         transform T between them (trapezoidal weights).
      2. Build the orthonormal parameter basis.
      3. Fit the t=0 numerator.
      4. Optionally run one SK step (t=1) for numerator and denominator.
      5. Flag spurious poles.

    Args:
        H_list: Ns complex responses, each of shape (Nf,).
        freqs_hz: (Nf,) sample frequencies in Hz.
        param_matrix: (Ns, d) parameter values as given. No normalization is
            applied internally, so the caller keeps control over scaling.
        total_degree: maximum total degree of the parameter polynomials.
        poles: starting stable poles (complex pairs ordered [p, conj(p)]).
        do_t1: if False, skip the SK step and use the trivial denominator
            (Ct[0, 0] = 1, all other entries 0).

    Returns:
        OPVFResult with the coefficients in both bases, T, the parameter basis
        and the spurious-pole report.

    Raises:
        ValueError: if len(H_list) differs from the number of rows of param_matrix.
    """
    n_param_rows = np.shape(param_matrix)[0]
    if len(H_list) != n_param_rows:
        raise ValueError(f"H_list has {len(H_list)} samples but param_matrix has {n_param_rows} rows")
    # Frequency bases
    Phi_ml = muntz_laguerre_basis(freqs_hz, poles)        # (Nf, P)
    G_raw = raw_partial_fraction_basis(freqs_hz, poles)   # (Nf, P)
    # Quadrature weights w ≈ Δf, since (1/2π)∫ φ φ* dω = ∫ φ φ* df ≈ Σ Δf φ φ*.
    w = _trap_weights(freqs_hz)
    T = compute_transform_matrix(Phi_ml, G_raw, w)
    # Parameter basis
    Qμ, Rμ, exps = generate_param_basis(param_matrix, total_degree)
    # t = 0: numerator only (D = 1)
    C_ml = fit_t0(H_list, Phi_ml, Qμ)
    if do_t1:
        C_ml, Ct_ml = fit_t1_SK(H_list, Phi_ml, Qμ, C_ml, max_iter=1)
    else:
        # Trivial denominator D = 1
        P, V = C_ml.shape
        Ct_ml = np.zeros((P, V), dtype=complex)
        Ct_ml[0, 0] = 1.0
    # Spurious-pole detection
    spurious = detect_spurious_poles(C_ml, Ct_ml, T, Qμ)
    return OPVFResult(
        C_ml=C_ml,
        Ct_ml=Ct_ml,
        C_raw=spurious["C_raw"],
        Ct_raw=spurious["Ct_raw"],
        T=T,
        Qμ=Qμ,
        Rμ=Rμ,
        exps=exps,
        poles=poles,
        spurious=spurious
    )
# ---------- 辅助: 梯形权 ----------
def _trap_weights(f: np.ndarray):
    """Return trapezoidal-rule quadrature weights for the 1-D sample points ``f``.

    w[0] = Δf_0 / 2, w[-1] = Δf_last / 2, and interior w[i] = (Δf_{i-1} + Δf_i) / 2,
    so that sum(w * g) approximates the integral of g over [f[0], f[-1]].
    A single point gets weight 1 and an empty grid gets an empty array.
    The result is always a float array. Integer-valued ``f`` would otherwise
    produce an integer array that truncates the 0.5 factors.
    """
    f = np.asarray(f, dtype=float)
    n = len(f)
    if n == 0:
        return np.zeros(0)
    if n == 1:
        return np.ones(1)
    df = np.diff(f)
    w = np.zeros(n)
    w[0] = 0.5 * df[0]
    w[-1] = 0.5 * df[-1]
    if n > 2:
        w[1:-1] = 0.5 * (df[:-1] + df[1:])
    return w
# ---------- 简单测试占位 ----------
if __name__ == "__main__":
    # Synthetic demo: Ns = 2 parameter samples, 200 frequency points.
    freqs = np.linspace(1e8, 5e9, 200)
    s = 1j * 2 * np.pi * freqs
    # "True" model: H = sum_k R_k / (s - p_k) with one complex pair.
    # NOTE(review): these poles are subtracted from s (rad/s), so Im = 1.2e9 rad/s
    # is about 0.19 GHz, while the starting poles below span 0.5-2 GHz. Confirm
    # whether 1.2 GHz (2*pi*1.2e9 rad/s) was intended.
    true_poles = [-0.5e3 + 1.2e9j, -0.5e3 - 1.2e9j]

    def synth(Rs):
        first = Rs[0] / (s - true_poles[0])
        second = Rs[1] / (s - true_poles[1])
        return first + second

    residue_sets = [[0.8 + 0.1j, 1.2 - 0.2j], [0.9 + 0.05j, 1.1 + 0.1j]]
    H_list = [synth(Rs) for Rs in residue_sets]
    params = np.array([[0.0], [1.0]])  # one parameter, two samples
    # Redundant starting pole set.
    # NOTE(review): an earlier comment said this set contains the true poles plus
    # extras, but it is generated independently and does not include true_poles.
    start_poles = generate_starting_poles(n_pairs=10, beta_min=5e8, beta_max=2.0e9, alpha_scale=0.000001)
    res = opvf_from_H(H_list, freqs, params, total_degree=1, poles=start_poles)
    report = res.spurious
    print("C_ml shape:", res.C_ml.shape)
    print("假极点(cancel):", report["cancel_spurious"])
    print("假极点(small):", report["small_spurious"])