Files
ovf/core/sk_iter.py
2025-09-15 11:41:55 -04:00

250 lines
9.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import numpy as np
from dataclasses import dataclass
from typing import List, Sequence, Tuple, Optional
@dataclass
class OPVFConfig:
    """Options controlling :class:`OrthonormalParametricVF` fitting.

    Attributes:
        n_poles: number of poles K.  NOTE(review): not read by the model code
            in this module section; presumably used when the frequency basis
            is built — confirm against callers.
        max_iter: maximum number of iterations run after the initial t = 0
            solve.
        tol: iteration stops once the largest relative change of the
            denominator samples between two iterations is below this value.
        add_constraint: append the averaged-denominator constraint row to each
            iteration's least-squares system.
        real_split: stack the real and imaginary parts of each iteration's
            system so that real-valued coefficients are solved for.
        lambda_reg: Tikhonov weight; when > 0, ``lambda_reg * I`` rows with a
            zero right-hand side are appended to each iteration's system.
        verbose: print per-iteration progress messages.
    """

    n_poles: int
    max_iter: int = 5
    tol: float = 1e-3
    add_constraint: bool = True
    real_split: bool = True
    lambda_reg: float = 0.0
    verbose: bool = True
# ---------- Parametric orthonormal basis ----------
class ParamOrthoBasis:
    """Orthonormal polynomial basis in the standardised parameters.

    All monomials of total degree <= ``total_degree`` are evaluated at the
    standardised parameter samples, giving M of shape (Ns, n_mono).  The
    reduced QR factorisation M = Q R provides ``Q`` (the orthonormal basis
    sampled at the training parameters) and ``R``, which :meth:`eval` uses to
    map a new parameter point into the same basis (psi = mono @ R^{-1}).
    """

    def __init__(self, params: np.ndarray, total_degree: int):
        """Build the basis from the parameter samples.

        Args:
            params: parameter samples, shape (Ns, d).  A 1-D array of length
                Ns is accepted and treated as a single parameter (d = 1).
            total_degree: maximum total degree of the monomials; must be >= 0.

        Raises:
            ValueError: if ``total_degree`` is negative.
        """
        if total_degree < 0:
            raise ValueError(f"total_degree must be >= 0, got {total_degree}")
        params = np.asarray(params)
        if params.ndim == 1:
            params = params[:, None]
        self.mean = params.mean(0)
        # Tiny offset so that a constant parameter does not divide by zero.
        self.std = params.std(0) + 1e-15
        self.exps = self._gen_exps(params.shape[1], total_degree)
        M = self._build_monomial_matrix(params)  # (Ns, n_mono)
        # Reduced QR: Q is (Ns, min(Ns, n_mono)), R is (min(Ns, n_mono), n_mono).
        Q, R = np.linalg.qr(M)
        self.Q = Q
        self.R = R

    def _gen_exps(self, dim, D):
        """Return every exponent tuple of length ``dim`` with total degree <= ``D``.

        Tuples come out in lexicographic order, e.g. dim=2, D=1 gives
        [(0, 0), (0, 1), (1, 0)].
        """
        exps = []

        def rec(cur, i, rem):
            if i == dim:
                exps.append(tuple(cur))
                return
            for k in range(rem + 1):
                cur.append(k)
                rec(cur, i + 1, rem - k)
                cur.pop()

        rec([], 0, D)
        return exps

    def _exponent_array(self) -> np.ndarray:
        """Exponents as an integer array of shape (n_mono, d)."""
        return np.array(self.exps, dtype=int).reshape(len(self.exps), self.mean.size)

    def _build_monomial_matrix(self, params):
        """Evaluate every monomial at every standardised sample; returns (Ns, n_mono)."""
        X = (np.asarray(params).reshape(-1, self.mean.size) - self.mean) / self.std
        # (Ns, 1, d) ** (1, n_mono, d) -> product over d gives monomial m at sample n.
        return np.prod(X[:, None, :] ** self._exponent_array()[None, :, :], axis=2)

    def eval(self, g: np.ndarray) -> np.ndarray:
        """Evaluate the orthonormal basis at one parameter point.

        Args:
            g: parameter point of shape (d,); a scalar is accepted when d == 1.

        Returns:
            psi(g), shape (Q.shape[1],).  At a training sample it reproduces the
            corresponding row of ``Q`` (up to rounding) when R has full row rank.

        Raises:
            ValueError: if ``g`` does not have exactly d entries.
        """
        g = np.atleast_1d(np.asarray(g))
        # Check before arithmetic: broadcasting would silently accept a
        # length-1 point when d > 1.
        if g.shape != self.mean.shape:
            raise ValueError(
                f"expected a parameter point of shape {self.mean.shape}, got {g.shape}"
            )
        x = (g - self.mean) / self.std
        mono = np.prod(x[None, :] ** self._exponent_array(), axis=1)  # (n_mono,)
        # psi = mono @ R^{-1}  <=>  R^T psi = mono.  lstsq avoids forming an
        # explicit inverse and also handles the non-square R produced when
        # Ns < n_mono, as well as a rank-deficient R.
        psi, *_ = np.linalg.lstsq(self.R.T, mono, rcond=None)
        return psi
# ---------- Orthonormal (rational) frequency basis ----------
class RationalFreqBasis:
    """Partial-fraction basis orthonormalised over a fixed frequency grid.

    The raw matrix G has a constant first column followed by one column
    1 / (s - a) per pole a, evaluated at s = j*2*pi*f.  Its reduced QR
    factorisation G = Q R gives the orthonormal basis samples ``Q``.
    """

    def __init__(self, freqs: np.ndarray, poles: np.ndarray):
        """Build and orthonormalise the rational basis on ``freqs``.

        Args:
            freqs: frequency grid, shape (Nf,); presumably in Hz, since the
                code forms s = j*2*pi*f — confirm with callers.
            poles: iterable of the K poles a_k.

        Sets ``Q`` (Nf, K+1 when Nf >= K+1), ``R`` with G = Q R, and keeps a
        reference to ``freqs``.
        """
        s_axis = 1j * 2 * np.pi * freqs
        raw = np.column_stack([np.ones_like(s_axis)] + [1.0 / (s_axis - pole) for pole in poles])
        # Uniform frequency weighting (w_f = 1); a per-frequency weight could
        # be applied to `raw` here before orthonormalising.
        q_mat, r_mat = np.linalg.qr(raw)
        self.Q = q_mat  # basis functions phi_k sampled on the grid
        self.R = r_mat  # raw = Q @ R
        self.freqs = freqs

    def eval(self) -> np.ndarray:
        """Return the orthonormal basis samples ``Q`` (the stored array, not a copy)."""
        return self.Q
# ---------- Main model ----------
class OrthonormalParametricVF:
    """Parametric vector fitting on orthonormal frequency and parameter bases.

    H(s, μ) = N(s, μ) / D(s, μ)
    N(s, μ) = Σ_k c_k(μ) φ_k(s);   D(s, μ) = Σ_k ĉ_k(μ) φ_k(s)

    φ_k(s) are the orthonormal rational frequency basis functions; the
    coefficients are expanded in the orthonormal parameter basis ψ_q(μ):

    c_k(μ) = Σ_q C[k, q] ψ_q(μ);   ĉ_k(μ) = Σ_q Ctilde[k, q] ψ_q(μ)

    ``C`` (numerator) and ``Ct`` (denominator) have shape (K+1, Qp).
    ``Ct[0, 0]`` is pinned to 1 and is never an unknown of the iterations.

    NOTE(review): with ``cfg.real_split`` the iterations solve for real
    coefficients while the frequency basis (a complex QR factor) is complex;
    confirm that this restriction is intended.
    """

    class FrequencyGridError(ValueError, AssertionError):
        """Raised by :meth:`evaluate` when frequencies are not the training grid.

        Also derives from AssertionError because this check used to be an
        ``assert``, so existing ``except AssertionError`` handlers keep working.
        """

    def __init__(self, cfg: "OPVFConfig", freq_basis: "RationalFreqBasis", param_basis: "ParamOrthoBasis"):
        """
        Args:
            cfg: fitting options (max_iter, tol, add_constraint, real_split,
                lambda_reg, verbose are read).
            freq_basis: provides ``Q`` (Nf, K+1), ``freqs`` and ``eval()``.
            param_basis: provides ``Q`` (Ns, Qp) at the training parameters and
                ``eval(g)`` for a parameter point.
        """
        self.cfg = cfg
        self.fb = freq_basis
        self.pb = param_basis
        Kp1 = self.fb.Q.shape[1]
        Qp = self.pb.Q.shape[1]
        self.C = np.zeros((Kp1, Qp), dtype=complex)  # numerator coefficients
        self.Ct = self._initial_ct(Kp1, Qp)  # denominator coefficients

    @staticmethod
    def _initial_ct(Kp1, Qp):
        """Initial denominator coefficients: ĉ_0 = 1 (Ct[0, 0] = 1), all others 0."""
        Ct = np.zeros((Kp1, Qp), dtype=complex)
        Ct[0, 0] = 1.0
        return Ct

    @staticmethod
    def _free_ct_mask(Kp1, Qp):
        """Mask of the denominator coefficients that are unknowns (all but the pinned Ct[0, 0])."""
        free = np.ones((Kp1, Qp), dtype=bool)
        free[0, 0] = False
        return free

    @staticmethod
    def _check_samples(H_samples, Nf, Ns):
        """Raise ValueError unless there are Ns responses, each of shape (Nf,)."""
        if len(H_samples) != Ns:
            raise ValueError(
                f"expected {Ns} response samples (one per parameter sample), got {len(H_samples)}"
            )
        for n, Hn in enumerate(H_samples):
            if np.shape(Hn) != (Nf,):
                raise ValueError(f"H_samples[{n}] has shape {np.shape(Hn)}, expected ({Nf},)")

    def _assemble_phi_param(self, params: np.ndarray) -> np.ndarray:
        """Return the parameter basis at the training samples, shape (Ns, Qp).

        The parameter basis already stores these samples as ``Q``; ``params``
        must be the samples it was built from (same order) and is only used to
        check the sample count.

        Raises:
            ValueError: if the number of samples in ``params`` differs from the
                number of rows of the parameter basis.
        """
        Φμ = self.pb.Q
        if params is not None and len(params) != Φμ.shape[0]:
            raise ValueError(
                f"params has {len(params)} samples but the parameter basis was built from {Φμ.shape[0]}"
            )
        return Φμ

    def fit(self, H_samples: List[np.ndarray], params: np.ndarray):
        """Fit C and Ct by iterated (Sanathanan–Koerner-style) linear least squares.

        Args:
            H_samples: list of length Ns; entry n is the complex response (Nf,)
                on the training frequency grid at parameter sample n.
            params: parameter samples (Ns, d); must be the samples the
                parameter basis was built from, in the same order.

        Returns:
            self.

        Raises:
            ValueError: if sample counts or response lengths do not match the
                bases.
        """
        Φf = self.fb.eval()  # (Nf, K+1)
        Nf, Kp1 = Φf.shape
        Φμ = self._assemble_phi_param(params)  # (Ns, Qp)
        Qp = Φμ.shape[1]
        self._check_samples(H_samples, Nf, Φμ.shape[0])
        # t = 0: fit the numerator alone (eq. 3, D = 1)
        self._solve_t0(H_samples, Φf, Φμ)
        D_prev = self._eval_D(Φf, Φμ)  # (Nf, Ns)
        for it in range(1, self.cfg.max_iter + 1):
            A, b = self._build_iter_system(H_samples, Φf, Φμ, D_prev)
            lam = self.cfg.lambda_reg
            if lam > 0:
                # Tikhonov regularisation: lam * I rows with zero right-hand side
                n_unknowns = A.shape[1]
                A = np.vstack([A, lam * np.eye(n_unknowns)])
                b = np.concatenate([b, np.zeros(n_unknowns)])
            x, *_ = np.linalg.lstsq(A, b, rcond=None)
            self._unpack_iter_solution(x, Kp1, Qp)
            D_new = self._eval_D(Φf, Φμ)
            # largest relative change of the denominator samples
            rel = np.max(np.abs(D_new - D_prev) / (np.abs(D_prev) + 1e-12))
            if self.cfg.verbose:
                print(f"[OPVF-Iter {it}] rel_change={rel:.3e}")
            D_prev = D_new
            if rel < self.cfg.tol:
                if self.cfg.verbose:
                    print(f"[OPVF] converged at {it}")
                break
        return self

    def _solve_t0(self, H_samples, Φf, Φμ):
        """Initial solve (t = 0, eq. 3): with D = 1, fit only C.

        Solves Σ_{k,q} C[k,q] φ_k(f) ψ_q(μ_n) ≈ H_n(f) over all samples and
        frequencies in the least-squares sense.  Ct is reset to its initial
        value (ĉ_0 = 1), so a repeated ``fit`` does not start from a stale
        denominator.

        NOTE(review): unlike the iterations, this solve ignores
        ``cfg.real_split`` and ``cfg.lambda_reg``.
        """
        Ns = len(H_samples)
        Nf, Kp1 = Φf.shape
        Qp = Φμ.shape[1]
        # A[(n, f), (k, q)] = φ_k(f) ψ_q(μ_n): sample-major rows, row-major (k, q) columns
        A = np.einsum('fk,nq->nfkq', Φf, Φμ[:Ns]).reshape(Ns * Nf, Kp1 * Qp)
        b = np.concatenate([np.asarray(Hn, dtype=complex).reshape(Nf) for Hn in H_samples])
        x, *_ = np.linalg.lstsq(A, b, rcond=None)
        self.C = x.reshape(Kp1, Qp)
        self.Ct = self._initial_ct(Kp1, Qp)

    def _build_iter_system(self, H_samples, Φf, Φμ, D_prev):
        """Assemble the linearised least-squares system for one iteration.

        For sample n and frequency f each row enforces
            (N(f, μ_n) - H_n(f) D(f, μ_n)) / D_prev(f, n) = 0.
        Unknowns are ordered [C.ravel(), Ct[free]] (row-major; ``free`` is
        every entry but Ct[0, 0]).  The pinned Ct[0, 0] term is known, so its
        contribution is moved to the right-hand side; without that the data
        rows would be homogeneous and their LS solution collapses to zero.

        Returns:
            (A, b).  With ``cfg.real_split`` the real and imaginary parts are
            stacked into a real system with twice the rows.
        """
        Nf, Kp1 = Φf.shape
        Qp = Φμ.shape[1]
        n_c = Kp1 * Qp
        free = self._free_ct_mask(Kp1, Qp).ravel()
        ct_fixed = self.Ct.ravel()[~free]  # values of the pinned coefficient(s)
        blocks, rhs = [], []
        for n, Hn in enumerate(H_samples):
            inv_dp = 1.0 / (D_prev[:, n] + 1e-12)  # (Nf,)
            # numerator columns φ_k(f) ψ_q(μ_n) / D_prev, column index k*Qp + q
            num_blk = np.einsum('fk,q->fkq', Φf, Φμ[n]).reshape(Nf, n_c) * inv_dp[:, None]
            # denominator columns -H_n(f) φ_k(f) ψ_q(μ_n) / D_prev
            den_blk = -np.asarray(Hn)[:, None] * num_blk
            blocks.append(np.hstack([num_blk, den_blk[:, free]]))
            # known Ct[0, 0] term goes to the right-hand side
            rhs.append(-(den_blk[:, ~free] @ ct_fixed))
        A = np.vstack(blocks)
        b = np.concatenate(rhs)
        if self.cfg.add_constraint:
            # Constraint row (eqs. (5)/(8)): averaged denominator at the
            # representative sample n = 0,
            #   Σ_{k,q} Ct[k,q] mean_f φ_k(f) ψ_q(μ_0) = K+1,
            # with the pinned Ct[0, 0] term moved to the right-hand side.
            # NOTE(review): Ct[0, 0] already fixes the scale, so this row
            # competes with the data rows in the solve; confirm it is wanted.
            weights = np.outer(Φf.mean(0), Φμ[0]).ravel()  # (Kp1*Qp,)
            row_c = np.zeros(A.shape[1], dtype=complex)
            row_c[n_c:] = weights[free]
            rhs_c = Kp1 - weights[~free] @ ct_fixed
            if self.cfg.real_split:
                # only the real part of the averaged denominator is constrained
                row_c, rhs_c = np.real(row_c), np.real(rhs_c)
            A = np.vstack([A, row_c])
            b = np.concatenate([b, [rhs_c]])
        if self.cfg.real_split:
            return np.vstack([np.real(A), np.imag(A)]), np.concatenate([np.real(b), np.imag(b)])
        return A, b

    def _unpack_iter_solution(self, x, Kp1, Qp):
        """Write the solution back into C and the free entries of Ct (Ct[0, 0] stays pinned)."""
        n_c = Kp1 * Qp
        self.C = x[:n_c].reshape(Kp1, Qp)
        Ct_new = self.Ct.copy()
        # Boolean-mask assignment fills the free entries in row-major order,
        # matching the column layout used by _build_iter_system.
        Ct_new[self._free_ct_mask(Kp1, Qp)] = x[n_c:]
        self.Ct = Ct_new

    def _eval_C_mu(self, phiμ):
        """Numerator coefficients c_k(μ) = C @ ψ(μ), shape (Kp1,)."""
        return self.C @ phiμ

    def _eval_Ct_mu(self, phiμ):
        """Denominator coefficients ĉ_k(μ) = Ct @ ψ(μ), shape (Kp1,)."""
        return self.Ct @ phiμ

    def _eval_D(self, Φf, Φμ):
        """Denominator samples D[f, n] = Σ_k ĉ_k(μ_n) φ_k(f), shape (Nf, Ns)."""
        return Φf @ (self.Ct @ Φμ.T)

    def evaluate(self, freqs: np.ndarray, g: np.ndarray):
        """Evaluate H(s, μ = g) = N / D on the training frequency grid.

        Args:
            freqs: must equal the grid used to build the frequency basis
                (simplification: no evaluation at new frequencies).
            g: parameter point, passed to the parameter basis ``eval``.

        Returns:
            Complex array of shape (Nf,).

        Raises:
            FrequencyGridError: (a ValueError and an AssertionError) if
                ``freqs`` is not the training grid.  It is raised explicitly so
                the check also runs under ``python -O``.
        """
        if np.shape(freqs) != np.shape(self.fb.freqs) or not np.allclose(freqs, self.fb.freqs):
            raise self.FrequencyGridError("频率需在训练网格上 (示例简化)")
        Φf = self.fb.Q  # (Nf, K+1)
        phiμ = self.pb.eval(g)  # (Qp,)
        num = Φf @ (self.C @ phiμ)
        den = Φf @ (self.Ct @ phiμ)
        return num / (den + 1e-15)