feat: 为multiport ovf添加了rms_error_l1和rms_error_lmax,更新了绘图模块

This commit is contained in:
mayge
2025-10-01 05:06:34 -04:00
parent 4e25538a4a
commit 8f963e7101
8 changed files with 428 additions and 63 deletions

65
docs/poles_match.md Normal file
View File

@@ -0,0 +1,65 @@
下面把“极点识别→一对一配对”作为线性指派问题的优化目标,用公式列出(含常用代价设计与未匹配处理)。
记号
- 集合 A参考/上一时刻/阶次 k第 i 个模式的频率、阻尼、振型为 $f_i, ζ_i, φ_i$
- 集合 B估计/下一时刻/阶次 k+1第 j 个模式为 $f'_j, ζ'_j, φ'_j$
- 模态相似度 MAC:
$$
\mathrm{MAC}(\phi_i,\phi'_j)=\frac{\bigl|\phi_i^{H}\phi'_j\bigr|^2}{(\phi_i^{H}\phi_i)\,(\phi'_j{}^{H}\phi'_j)}\in[0,1]
$$
一、配对代价 C_ij 的常用设计
- 加权和(绝对差版本):
$$
C_{ij} = w_f\frac{|f_i-f'_j|}{f_{\mathrm{scale}}} + w_\zeta\,|\,\zeta_i-\zeta'_j\,|+ w_{\mathrm{mac}}\bigl(1-\mathrm{MAC}_{ij}\bigr)
$$
- 或加权平方差版本:
$$
C_{ij}= w_f\Bigl(\frac{f_i-f'_j}{f_{\mathrm{scale}}}\Bigr)^2+ w_\zeta\bigl(\zeta_i-\zeta'_j\bigr)^2+ w_{\mathrm{mac}}\bigl(1-\mathrm{MAC}_{ij}\bigr)^2
$$
- 门控Big-M 罚):
$$
C_{ij}\;:=\;
\begin{cases}
C_{ij}, & \frac{|f_i-f'_j|}{f_{\mathrm{scale}}}\le\tau_f,\;\;|\,\zeta_i-\zeta'_j\,|\le\tau_\zeta,\;\;\mathrm{MAC}_{ij}\ge\tau_{\mathrm{mac}}\\[4pt]
M, & \text{否则}
\end{cases}
\qquad(M\text{ 很大})
$$
二、允许“未匹配”的指派目标(显式未匹配变量)
令决策变量 x_{ij}\in\{0,1\} 表示 i 是否配给 ju_i,v_j 表示未匹配。
$$
\min_{x,u,v}\;
\sum_{i=1}^{m}\sum_{j=1}^{n} C_{ij}\,x_{ij}
+\tau_A\sum_{i=1}^{m} u_i
+\tau_B\sum_{j=1}^{n} v_j
$$
约束:
$$
\sum_{j=1}^{n} x_{ij}+u_i=1,\;\; \forall i=1,\dots,m
\qquad
\sum_{i=1}^{m} x_{ij}+v_j=1,\;\; \forall j=1,\dots,n
$$
$$
x_{ij}\in\{0,1\},\;\;u_i\in\{0,1\},\;\;v_j\in\{0,1\}
$$
三、允许“未匹配”的等价“补方阵”形式(虚拟结点)
取 L=\max(m,n),把代价矩阵补成
$\tilde C\in\mathbb{R}^{L\times L}$,对虚拟匹配赋常数代价 $C_{\mathrm{unmatch}}$,解标准 LSAP
$$
\min_{X\in\{0,1\}^{L\times L}}\;\langle \tilde C, X\rangle
\quad\text{s.t.}\quad
X\mathbf{1}=\mathbf{1},\;\; \mathbf{1}^\top X=\mathbf{1}^\top
$$
其中 X 为置换矩阵;落在补齐的虚拟行/列上即表示“未匹配”,其代价由 $C_{\mathrm{unmatch}}$ 决定。
四、最大化相似度的等价形式
若用相似度矩阵 $S_{ij}$(如 $S_{ij}= \alpha_{\mathrm{mac}}\mathrm{MAC}_{ij}-\alpha_f\frac{|f_i-f'_j|}{f_{\mathrm{scale}}}-\alpha_\zeta|\,\zeta_i-\zeta'_j\,|$),则:
$$
\max_{X}\;\sum_{ij} S_{ij} X_{ij}
\quad\Longleftrightarrow\quad
\min_{X}\;\sum_{ij} (-S_{ij}) X_{ij}
$$
并沿用上面的指派约束与“未匹配”处理。

View File

@@ -9,6 +9,8 @@ from ovf.core.sample import auto_select_multple_ports
from ovf.core.basis.MultiPortOrthonormalBasis import MultiPortOrthonormalBasis
from ovf.core.geometry.plot_poles import plot_poles_in_3d
from ovf.schemas.geometry.poles import PolesPlot3dUnit,PolesPlot3dDataUnit
from ovf.schemas.geometry.poles import PolesPlot2dUnit,PolesPlot2dDataUnit
from ovf.core.geometry.plot_poles import plot_poles_in_2d
import pickle
@dataclass
@@ -21,6 +23,7 @@ class GVFManager:
def __init__(self):
self.parameter_type: Literal["s","y","z"] = "s"
self.datasets: List[GVFDataUnit] = []
self.full_freqs: np.ndarray | None = None
def save(self,filepath:str):
os.makedirs(os.path.dirname(filepath),exist_ok=True)
@@ -34,7 +37,16 @@ class GVFManager:
assert isinstance(obj,GVFManager),"The loaded object is not a GVFManager instance."
return obj
def load_from_datasets(self,jsonfile:str,npoles_cplx:int=2,max_iterations:int=5,max_points:int|None=20,basis:type=MultiPortOrthonormalBasis,parameter_type:Literal["s","y","z"]="s"):
def load_from_datasets(self,
jsonfile:str,
npoles_cplx:int=2,
max_iterations:int=5,
max_points:int|None=20,
basis:type=MultiPortOrthonormalBasis,
parameter_type:Literal["s","y","z"]="s",
min_freqs:float|None=None,
max_freqs:float|None=None
):
self.parameter_type = parameter_type
with open(jsonfile,"r") as f:
datas = json.load(f)
@@ -47,17 +59,27 @@ class GVFManager:
full_freqences = network.f
min_index = 0
max_index = len(full_freqences)-1
for i in range(len(full_freqences)):
if min_freqs is not None and full_freqences[i] < min_freqs:
min_index = i + 1
if max_freqs is not None and full_freqences[i] > max_freqs:
max_index = i - 1
if parameter_type == "s":
sampled_points = network.s.reshape(-1,ports,ports)
sampled_points = network.s.reshape(-1,ports,ports)[min_index:max_index+1]
elif parameter_type == "y":
sampled_points = network.y.reshape(-1,ports,ports)
sampled_points = network.y.reshape(-1,ports,ports)[min_index:max_index+1]
elif parameter_type == "z":
sampled_points = network.z.reshape(-1,ports,ports)
sampled_points = network.z.reshape(-1,ports,ports)[min_index:max_index+1]
self.full_freqs = full_freqences[min_index:max_index+1]
if max_points:
H,freqs = auto_select_multple_ports(sampled_points,full_freqences,max_points=max_points)
H,freqs = auto_select_multple_ports(sampled_points,self.full_freqs,max_points=max_points)
else:
H,freqs = sampled_points,full_freqences
H,freqs = sampled_points,self.full_freqs
vf = VFManager(npoles_cplx=npoles_cplx,freqs=freqs,H=H,model=basis,iterations=K,verbose=False)
vf.fit()
@@ -93,7 +115,7 @@ class GVFManager:
plot_poles_in_3d(unit, save_path)
def plot_vf_responses_with_index(self,save_dir:str,index:int):
def plot_vf_responses_with_index(self,save_dir:str,index:int,freqrange:List[float]|np.ndarray|None=None):
ds = self.datasets[index]
id = f"{index}"
os.makedirs(save_dir,exist_ok=True)
@@ -101,9 +123,12 @@ class GVFManager:
vf.plot_metrics(show=False,save_path=f"{save_dir}/{id}")
full_freqences = vf.freqs
model_responses = vf.get_model_responses(full_freqences)
vf.plot_model_responses(show=False,save_path=f"{save_dir}/{id}")
if freqrange is None:
vf.plot_model_responses(full_freqs=self.full_freqs,show=False,save_path=f"{save_dir}/{id}")
else:
vf.plot_model_responses(full_freqs=np.array(freqrange).reshape(-1),show=False,save_path=f"{save_dir}/{id}")
def plot_all_vf_responses(self,save_dir:str):
def plot_all_vf_responses(self,save_dir:str,freqrange:List[float]|np.ndarray|None=None):
for index,ds in enumerate(self.datasets):
id = f"{index}"
os.makedirs(save_dir,exist_ok=True)
@@ -111,13 +136,25 @@ class GVFManager:
vf.plot_metrics(show=False,save_path=f"{save_dir}/{id}")
full_freqences = vf.freqs
model_responses = vf.get_model_responses(full_freqences)
vf.plot_model_responses(show=False,save_path=f"{save_dir}/{id}")
if freqrange is None:
vf.plot_model_responses(full_freqs=self.full_freqs,show=False,save_path=f"{save_dir}/{id}")
else:
vf.plot_model_responses(full_freqs=np.array(freqrange).reshape(-1),show=False,save_path=f"{save_dir}/{id}")
def plot_poles_in_2d(self,save_path:str):
unit = PolesPlot2dUnit(
title="Poles Distribution 2D",
datas=[
PolesPlot2dDataUnit(
poles=ds.vf_manager.poles,
geometries=ds.geometries
) for ds in self.datasets
],
x_label="Real Part",
y_label="Imaginary Part"
)
plot_poles_in_2d(unit, save_path)
def evaluate_dataset(self):
for ds in self.datasets:
rmse = ds.vf_manager.rms_error
condition = ds.vf_manager.condition
row_condition = ds.vf_manager.row_condition
col_condition = ds.vf_manager.col_condition

View File

@@ -42,6 +42,8 @@ class VFManager():
self.eigenval_row_condition = []
self.eigenval_col_condition = []
self.eigenval_rms_error = []
self.eigenval_rms_error_l1 = []
self.eigenval_rms_error_lmax = []
self.model_instance:MultiPortOrthonormalBasis|None = None
self.model_responses_freqs = None
@@ -120,6 +122,8 @@ class VFManager():
eigenval_row_condition=self.eigenval_row_condition,
eigenval_col_condition=self.eigenval_col_condition,
eigenval_rms_error=self.eigenval_rms_error,
eigenval_rms_error_l1=self.eigenval_rms_error_l1,
eigenval_rms_error_lmax=self.eigenval_rms_error_lmax,
fit_constant=self.fit_constant,
fit_proportional=self.fit_proportional,
dc_enforce=self.dc_enforce,
@@ -168,6 +172,8 @@ class VFManager():
instance.eigenval_condition = data["eigenval_condition"].tolist()
instance.eigenval_row_condition = data["eigenval_row_condition"].tolist()
instance.eigenval_col_condition = data["eigenval_col_condition"].tolist()
instance.eigenval_rms_error_l1 = data["eigenval_rms_error_l1"].tolist()
instance.eigenval_rms_error_lmax = data["eigenval_rms_error_lmax"].tolist()
instance.eigenval_rms_error = data["eigenval_rms_error"].tolist()
instance.fit_constant = bool(data["fit_constant"])
instance.fit_proportional = bool(data["fit_proportional"])
@@ -228,7 +234,9 @@ class VFManager():
eigenval_condition,\
eigenval_row_condition,\
eigenval_col_condition,\
eigenval_rms_error = self.model_instance.eigen_metric
eigenval_rms_error,\
eigenval_rms_error_l1,\
eigenval_rms_error_lmax= self.model_instance.eigen_metric
self.least_squares_condition.append(least_squares_condition)
self.least_squares_row_condition.append(least_squares_row_condition)
self.least_squares_col_condition.append(least_squares_col_condition)
@@ -237,18 +245,20 @@ class VFManager():
self.eigenval_row_condition.append(eigenval_row_condition)
self.eigenval_col_condition.append(eigenval_col_condition)
self.eigenval_rms_error.append(eigenval_rms_error)
self.eigenval_rms_error_l1.append(eigenval_rms_error_l1)
self.eigenval_rms_error_lmax.append(eigenval_rms_error_lmax)
return self.model_instance
def plot_metrics(self,show:bool=True,save_path=None):
plt.figure(figsize=(16, 12))
plt.subplot(4, 2, 1)
plt.figure(figsize=(20, 12))
plt.subplot(5, 2, 1)
plt.plot(
range(1, len(self.least_squares_condition) + 1),
self.least_squares_condition,
label='Least Squares Condition'
)
plt.legend()
plt.subplot(4, 2, 2)
plt.subplot(5, 2, 2)
plt.plot(
range(1, len(self.least_squares_row_condition) + 1),
self.least_squares_row_condition,
@@ -256,7 +266,7 @@ class VFManager():
)
plt.legend()
plt.subplot(4, 2, 3)
plt.subplot(5, 2, 3)
plt.plot(
range(1, len(self.least_squares_col_condition) + 1),
self.least_squares_col_condition,
@@ -264,7 +274,7 @@ class VFManager():
)
plt.legend()
plt.subplot(4, 2, 4)
plt.subplot(5, 2, 4)
plt.plot(
range(1, len(self.least_squares_rms_error) + 1),
self.least_squares_rms_error,
@@ -272,7 +282,7 @@ class VFManager():
)
plt.legend()
plt.subplot(4, 2, 5)
plt.subplot(5, 2, 5)
plt.plot(
range(1, len(self.eigenval_condition) + 1),
self.eigenval_condition,
@@ -280,7 +290,7 @@ class VFManager():
)
plt.legend()
plt.subplot(4, 2, 6)
plt.subplot(5, 2, 6)
plt.plot(
range(1, len(self.eigenval_row_condition) + 1),
self.eigenval_row_condition,
@@ -288,7 +298,7 @@ class VFManager():
)
plt.legend()
plt.subplot(4, 2, 7)
plt.subplot(5, 2, 7)
plt.plot(
range(1, len(self.eigenval_col_condition) + 1),
self.eigenval_col_condition,
@@ -296,13 +306,29 @@ class VFManager():
)
plt.legend()
plt.subplot(4, 2, 8)
plt.subplot(5, 2, 8)
plt.plot(
range(1, len(self.eigenval_rms_error) + 1),
self.eigenval_rms_error,
label='Eigenvalue RMS Error'
)
plt.legend()
plt.subplot(5, 2, 9)
plt.plot(
range(1, len(self.eigenval_rms_error_l1) + 1),
self.eigenval_rms_error_l1,
label='Eigenvalue RMS Error L1'
)
plt.legend()
plt.subplot(5, 2, 10)
plt.plot(
range(1, len(self.eigenval_rms_error_lmax) + 1),
self.eigenval_rms_error_lmax,
label='Eigenvalue RMS Error Lmax'
)
plt.legend()
if show:
plt.show()
@@ -312,33 +338,37 @@ class VFManager():
os.makedirs(save_path, exist_ok=True)
plt.savefig(f"{save_path}/fitting_metrics.png")
def plot_model_responses(self,show:bool=True,save_path=None):
def plot_model_responses(self,full_freqs,show:bool=True,save_path=None):
assert self.model_responses_freqs is not None and self.model_responses_H is not None, "Please run get_model_responses() first."
full_freqs = np.array(full_freqs)
model_responses = self.get_model_responses(full_freqs)
for i in range(self.nports):
for j in range(self.nports):
plt.figure(figsize=(12, 6))
plt.subplot(2, 2, 1)
plt.plot(self.freqs, np.abs(self.H[:,i,j]), 'o', ms=4, color='red', label='Input Samples')
plt.plot(self.model_responses_freqs, np.abs(self.model_responses_H[:,i,j]), '-', lw=2, color='k', label='Fit')
plt.plot(full_freqs, np.abs(model_responses[:,i,j]), '-', lw=2, color='k', label='Fit')
plt.title(f"Response i={i+1}, j={j+1}")
plt.ylabel("Magnitude")
plt.legend(loc="best")
plt.subplot(2, 2, 2)
plt.plot(self.freqs, np.angle(self.H[:,i,j],deg=True), 'o', ms=4, color='red', label='Input Samples')
plt.plot(self.model_responses_freqs, np.angle(self.model_responses_H[:,i,j],deg=True), '-', lw=2, color='k', label='Fit')
plt.plot(full_freqs, np.angle(model_responses[:,i,j],deg=True), '-', lw=2, color='k', label='Fit')
plt.title(f"Response i={i+1}, j={j+1}")
plt.ylabel("Phase (deg)")
plt.legend(loc="best")
plt.tight_layout()
plt.subplot(2, 2, 3)
plt.plot(self.freqs, np.real(self.H[:,i,j]), 'o', ms=4, color='red', label='Input Samples')
plt.plot(self.model_responses_freqs, np.real(self.model_responses_H[:,i,j]), '-', lw=2, color='k', label='Fit')
plt.plot(full_freqs, np.real(model_responses[:,i,j]), '-', lw=2, color='k', label='Fit')
plt.title(f"Response i={i+1}, j={j+1}")
plt.ylabel("Real Part")
plt.legend(loc="best")
plt.subplot(2, 2, 4)
plt.plot(self.freqs, np.imag(self.H[:,i,j]), 'o', ms=4, color='red', label='Input Samples')
plt.plot(self.model_responses_freqs, np.imag(self.model_responses_H[:,i,j]), '-', lw=2, color='k', label='Fit')
plt.plot(full_freqs, np.imag(model_responses[:,i,j]), '-', lw=2, color='k', label='Fit')
plt.title(f"Response i={i+1}, j={j+1}")
plt.ylabel("Imag Part")
plt.legend(loc="best")

View File

@@ -13,6 +13,8 @@ class MultiPortOrthonormalBasis:
self._eigenval_row_condition = None
self._eigenval_col_condition = None
self._eigenval_rms_error = None
self._eigenval_rms_error_l1 = None
self._eigenval_rms_error_lmax = None
self.Cr = None
self.dc_tol = 1e-18
@@ -28,8 +30,8 @@ class MultiPortOrthonormalBasis:
self.s = self.freqs * 2j * np.pi
self.pre_poles = pre_poles
self.Phi = self.generate_basis(self.s, self.pre_poles)
self.A = self.matrix_A(self.pre_poles)
self.Phi = self.generate_basis_confluent(self.s, self.pre_poles)
self.A = self.matrix_A_confluent(self.pre_poles)
self.B = self.vector_B(self.pre_poles)
self.C,self.w0,self.e = self.fit_denominator(self.H)
self.D = self.w0
@@ -65,8 +67,15 @@ class MultiPortOrthonormalBasis:
self._eigenval_condition,\
self._eigenval_row_condition,\
self._eigenval_col_condition,\
self._eigenval_rms_error = self._eigen_metric()
return self._eigenval_condition, self._eigenval_row_condition, self._eigenval_col_condition, self._eigenval_rms_error
self._eigenval_rms_error,\
self._eigenval_rms_error_l1,\
self._eigenval_rms_error_lmax= self._eigen_metric()
return self._eigenval_condition,\
self._eigenval_row_condition,\
self._eigenval_col_condition,\
self._eigenval_rms_error,\
self._eigenval_rms_error_l1,\
self._eigenval_rms_error_lmax
@property
@@ -101,12 +110,15 @@ class MultiPortOrthonormalBasis:
Hk = self.H
Hk_fitted = self.get_model_responses(self.freqs)
rms = np.sqrt(np.mean(np.abs(Hk - Hk_fitted)**2))
rms = np.sqrt(np.mean(np.abs(Hk - Hk_fitted)**2)) / np.sqrt(np.mean(np.abs(Hk)**2))
rms_l1 = np.mean(np.abs(Hk - Hk_fitted)) / np.mean(np.abs(Hk))
rms_lmax = np.max(np.abs(Hk - Hk_fitted)) / np.max(np.abs(Hk))
row_cond = cond_row_inf(self.A - self.B @ self.C)
col_cond = cond_col_one(self.A - self.B @ self.C)
return cond,row_cond,col_cond,rms
return cond,row_cond,col_cond,rms,rms_l1,rms_lmax
def _least_squares_metric(self,Q,R,x,b):
"""Return condition number and RMS error of least-squares matrix A and rhs b."""
@@ -174,6 +186,187 @@ class MultiPortOrthonormalBasis:
ap_list.append(ap)
Phi = np.column_stack(cols).astype(np.complex128)
return Phi
@staticmethod
def _group_poles_confluent(poles, tol=1e-12):
"""
把 LHP 极点数组 `poles`(复数,含共轭对)按重复进行分组。
返回一个 list每项
{"type":"real","ap":ap_RHP,"m":m}
{"type":"pair","ap":ap_RHP,"ap1":ap_RHP_conj,"m":m}
其中 ap = -pRHPRe>0m 是重数。
"""
z = np.asarray(poles, np.complex128).copy()
used = np.zeros(len(z), dtype=bool)
groups = []
i = 0
while i < len(z):
if used[i]:
i += 1
continue
p = z[i]
ap = -p # 你的代码里用 ap = -pRHP
if ap.real < 0:
raise ValueError("poles must be in the LHP")
if np.isclose(p.imag, 0.0, atol=tol):
# 实极点:统计重数
mask = (~used) & np.isclose(z.real, p.real, atol=tol) & np.isclose(z.imag, p.imag, atol=tol)
idx = np.where(mask)[0]
used[idx] = True
groups.append({"type": "real", "ap": ap, "m": len(idx)})
else:
# 找共轭
j = None
for k in range(i+1, len(z)):
if used[k]:
continue
if np.isclose(z[k], np.conj(p), atol=tol):
j = k; break
if j is None:
# 容错:没配上就当作单个复极点(很罕见);按一对处理也可
used[i] = True
groups.append({"type": "real", "ap": ap, "m": 1})
else:
# 统计这对的重数 m
used[i] = used[j] = True
m = 1
# 继续往后找更多相同位置的对
while True:
# 在未使用的剩余里各找一枚匹配
ii = None; jj = None
for k in range(len(z)):
if not used[k] and np.isclose(z[k], p, atol=tol):
ii = k; break
for k in range(len(z)):
if not used[k] and np.isclose(z[k], np.conj(p), atol=tol):
jj = k; break
if ii is not None and jj is not None:
used[ii] = used[jj] = True
m += 1
else:
break
ap1 = -np.conj(p) # = conj(ap)
groups.append({"type":"pair", "ap": ap, "ap1": ap1, "m": m})
i += 1
return groups
def generate_basis_confluent(self,s, poles, tol=1e-12):
"""
合流版正交基(与原 generate_basis 接口兼容)。
返回:
Phi : (K, ncol) complex
layout : 列描述(便于后续打包 C
"""
def _all_pass_factor(s, ap):
# ((s - ap*)/(s + ap)),在 s=jw 上幅度为1
return (s - np.conj(ap)) / (s + ap)
def _all_pass_cascade(s, ap_list):
res = 1.0 + 0.0j
for ap in ap_list:
res *= _all_pass_factor(s, ap)
return res
s = np.asarray(s, np.complex128).reshape(-1)
groups = self._group_poles_confluent(poles, tol=tol)
cols = []
layout = []
prev_aps = [] # 进入当前组之前的 all-pass 级联清单(含重复)
Gprev = _all_pass_cascade(s, prev_aps)
for g in groups:
if g["type"] == "real":
ap, m = g["ap"], g["m"]
# 基函数模板sqrt(2 Re ap) * Gprev * [(AP(ap))^r] * 1/(s+ap)
base = 1.0/(s + ap)
AP = _all_pass_factor(s, ap) # 当前极点的 all-pass
if m > 1:
pass
for r in range(m):
phi = np.sqrt(2.0*ap.real) * Gprev * (AP**r) * base
cols.append(phi)
layout.append({"type":"real", "ap":ap, "order":r+1})
# 更新 prev_aps & Gprev把这一组的 m 个 ap 都加入)
for _ in range(m):
prev_aps.append(ap)
Gprev *= _all_pass_factor(s, ap)
elif g["type"] == "pair":
ap, ap1, m = g["ap"], g["ap1"], g["m"]
# 你的原始两列phi1, phi2r=0
# 合流:在它们前乘 [(AP(ap)*AP(ap1))^r]
denom = (s + ap)*(s + ap1)
absap = np.abs(ap)
base1 = (s - absap)/denom
base2 = (s + absap)/denom
APpair = _all_pass_factor(s, ap) * _all_pass_factor(s, ap1) # 成对 all-pass
if m > 1:
pass
for r in range(m):
Gr = Gprev * (APpair**r)
phi1 = np.sqrt(2.0*ap.real) * Gr * base1
phi2 = np.sqrt(2.0*ap.real) * Gr * base2
cols += [phi1, phi2]
layout += [{"type":"pair_Re","ap":ap,"ap1":ap1,"order":r+1},
{"type":"pair_Im","ap":ap,"ap1":ap1,"order":r+1}]
# 更新 prev_aps & Gprev把该对重复 m 次都加入)
for _ in range(m):
prev_aps.append(ap); Gprev *= _all_pass_factor(s, ap)
prev_aps.append(ap1); Gprev *= _all_pass_factor(s, ap1)
else:
raise RuntimeError("unknown group type")
Phi = np.column_stack(cols).astype(np.complex128)
return Phi
def matrix_A_confluent(self,poles, tol=1e-12):
groups = self._group_poles_confluent(poles, tol=tol)
# 先计算总列数 N=基函数数)
N = 0
for g in groups:
if g["type"]=="real":
N += g["m"]
else: # pair
N += 2*g["m"]
A = np.zeros((N, N), float)
col_ptr = 0 # 当前要写入的列起始索引
for g in groups:
if g["type"] == "real":
ap, m = g["ap"], g["m"]
if m > 1:
pass
a = ap.real
for r in range(m):
c = col_ptr # 当前这一列的全局索引
# 对角项
A[c, c] = -a
# 尾部常数 2a对所有后续行
if c+1 < N:
A[c+1:, c] += 2.0*a
col_ptr += 1
elif g["type"] == "pair":
ap, ap1, m = g["ap"], g["ap1"], g["m"]
a = ap.real
absap = float(np.abs(ap))
for r in range(m):
c0 = col_ptr # 这一对的两列 c0, c0+1
# 2×2 核心块(放在对角)
A[c0:c0+2, c0:c0+2] = np.array([[-a, -a-absap],
[-a+absap, -a]], float)
# 尾部常数 2a对所有后续行逐行加到这两列
if c0+2 < N:
A[c0+2:, c0] += 2.0*a
A[c0+2:, c0+1] += 2.0*a
col_ptr += 2
else:
raise RuntimeError("unknown group type")
return A
def matrix_A(self, poles):
def A_col(p:np.complex128,index:int):

View File

@@ -5,28 +5,43 @@ from sklearn.linear_model import Ridge
import os
from ovf.schemas.geometry.basic import GeometryPlot2dUnit, GeometryPlot2dComplexDataUnit
def plot_2d(unit: GeometryPlot2dUnit, dirname: str):
os.makedirs(dirname, exist_ok=True)
fig = make_subplots(rows=1, cols=1)
def plot_2d(unit: GeometryPlot2dUnit, filename: str):
os.makedirs(os.path.dirname(filename), exist_ok=True)
scatters_data = []
colors = ['blue', 'orange', 'green', 'red', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan']
for index in range(len(unit.datas[0].x)):
x = [unit.datas[i].x[index] for i in range(len(unit.datas))]
y = [unit.datas[i].y[index] for i in range(len(unit.datas))]
fig.add_trace(
go.Scatter(
x=x,
y=y,
mode='markers',
name=data.label
),
row=1, col=1
)
fig.update_layout(
title=unit.title,
xaxis_title=unit.x_label,
yaxis_title=unit.y_label,
for pt in range(len(unit.datas)):
scatters_data.append(
go.Scatter(
x=[x[pt]],
y=[y[pt]],
mode='markers',
name=None,
marker=dict(size=8, color=colors[index%len(colors)]),
hovertext=f"Geometries: {unit.datas[pt].geometries}<br>X: {x[pt]}<br>Y: {y[pt]}"
)
)
layout = go.Layout(
title=f'{unit.title}',
xaxis=dict(title=f'{unit.x_label}'),
yaxis=dict(title=f'{unit.y_label}'),
showlegend=False
)
fig.write_html(f"{dirname}/plot_2d.html")
fig = go.Figure(data=scatters_data, layout=layout)
if filename.endswith('.html'):
fig.write_html(filename)
elif filename.endswith('.png'):
fig.write_image(filename,format='png')
elif filename.endswith('.pdf'):
fig.write_image(filename,format='pdf')
elif filename.endswith('.jpeg') or filename.endswith('.jpg'):
fig.write_image(filename,format='jpeg')
else:
fig.write_html(f"{filename}.html")

View File

@@ -3,9 +3,10 @@ import plotly.graph_objs as go
from plotly.offline import plot
import numpy as np
from sklearn.linear_model import Ridge
from ovf.schemas.geometry.poles import PolesPlot3dUnit, PolesPlot3dDataUnit
from ovf.schemas.geometry.basic import GeometryPlot3dDataUnit, GeometryPlot3dUnit
from ovf.schemas.geometry.poles import PolesPlot3dUnit, PolesPlot3dDataUnit, PolesPlot2dUnit, PolesPlot2dDataUnit
from ovf.schemas.geometry.basic import GeometryPlot3dDataUnit, GeometryPlot3dUnit,GeometryPlot2dUnit,GeometryPlot2dComplexDataUnit
from ovf.core.geometry.plot3d import get_plot_instance_in_3d
from ovf.core.geometry.plot2d import plot_2d
from plotly.subplots import make_subplots
import os
@@ -85,4 +86,16 @@ def plot_poles_in_3d(poles:PolesPlot3dUnit,dirname:str):
index += 1
def plot_poles_in_2d(poles:PolesPlot2dUnit,filename:str):
data = GeometryPlot2dUnit(
title=poles.title,
datas=[GeometryPlot2dComplexDataUnit(
x=[np.real(p) for p in d.poles],
y=[np.imag(p) for p in d.poles],
geometries=d.geometries
) for d in poles.datas],
x_label=poles.x_label,
y_label=poles.y_label
)
plot_2d(data, filename)

View File

@@ -25,7 +25,7 @@ class GeometryPlot3dUnit:
class GeometryPlot2dComplexDataUnit:
x: List[float]
y: List[float]
geometries: Dict[str,Union[float,int,str]]
geometries: Dict[str,float]
@dataclass

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List, Literal, Union
from typing import List, Literal, Union, Dict
import numpy as np
@dataclass
@@ -27,4 +27,16 @@ class PolesPlot3dUnit:
geometry_2_label: str = "geometry_2"
pole_scale: Literal["linear", "log"] = "linear"
geometry_1_scale: Literal["linear", "log"] = "linear"
geometry_2_scale: Literal["linear", "log"] = "linear"
geometry_2_scale: Literal["linear", "log"] = "linear"
@dataclass
class PolesPlot2dDataUnit:
poles: List[Union[complex,np.complex128]]
geometries: Dict[str,float]
@dataclass
class PolesPlot2dUnit:
title: str
datas: List[PolesPlot2dDataUnit]
x_label: str
y_label: str