diff --git a/docs/poles_match.md b/docs/poles_match.md new file mode 100644 index 0000000..4e6eaed --- /dev/null +++ b/docs/poles_match.md @@ -0,0 +1,65 @@ +下面把“极点识别→一对一配对”作为线性指派问题的优化目标,用公式列出(含常用代价设计与未匹配处理)。 + +记号 +- 集合 A(参考/上一时刻/阶次 k):第 i 个模式的频率、阻尼、振型为 $f_i, ζ_i, φ_i$ +- 集合 B(估计/下一时刻/阶次 k+1):第 j 个模式为 $f'_j, ζ'_j, φ'_j$ +- 模态相似度 MAC: +$$ +\mathrm{MAC}(\phi_i,\phi'_j)=\frac{\bigl|\phi_i^{H}\phi'_j\bigr|^2}{(\phi_i^{H}\phi_i)\,(\phi'_j{}^{H}\phi'_j)}\in[0,1] +$$ + +一、配对代价 C_ij 的常用设计 +- 加权和(绝对差版本): +$$ +C_{ij} = w_f\frac{|f_i-f'_j|}{f_{\mathrm{scale}}} + w_\zeta\,|\,\zeta_i-\zeta'_j\,|+ w_{\mathrm{mac}}\bigl(1-\mathrm{MAC}_{ij}\bigr) +$$ +- 或加权平方差版本: +$$ +C_{ij}= w_f\Bigl(\frac{f_i-f'_j}{f_{\mathrm{scale}}}\Bigr)^2+ w_\zeta\bigl(\zeta_i-\zeta'_j\bigr)^2+ w_{\mathrm{mac}}\bigl(1-\mathrm{MAC}_{ij}\bigr)^2 +$$ +- 门控(Big-M 罚): +$$ +C_{ij}\;:=\; +\begin{cases} +C_{ij}, & \frac{|f_i-f'_j|}{f_{\mathrm{scale}}}\le\tau_f,\;\;|\,\zeta_i-\zeta'_j\,|\le\tau_\zeta,\;\;\mathrm{MAC}_{ij}\ge\tau_{\mathrm{mac}}\\[4pt] +M, & \text{否则} +\end{cases} +\qquad(M\text{ 很大}) +$$ + +二、允许“未匹配”的指派目标(显式未匹配变量) +令决策变量 x_{ij}\in\{0,1\} 表示 i 是否配给 j;u_i,v_j 表示未匹配。 +$$ +\min_{x,u,v}\; +\sum_{i=1}^{m}\sum_{j=1}^{n} C_{ij}\,x_{ij} ++\tau_A\sum_{i=1}^{m} u_i ++\tau_B\sum_{j=1}^{n} v_j +$$ +约束: +$$ +\sum_{j=1}^{n} x_{ij}+u_i=1,\;\; \forall i=1,\dots,m +\qquad +\sum_{i=1}^{m} x_{ij}+v_j=1,\;\; \forall j=1,\dots,n +$$ +$$ +x_{ij}\in\{0,1\},\;\;u_i\in\{0,1\},\;\;v_j\in\{0,1\} +$$ + +三、允许“未匹配”的等价“补方阵”形式(虚拟结点) +取 L=\max(m,n),把代价矩阵补成 +$\tilde C\in\mathbb{R}^{L\times L}$,对虚拟匹配赋常数代价 $C_{\mathrm{unmatch}}$,解标准 LSAP: +$$ +\min_{X\in\{0,1\}^{L\times L}}\;\langle \tilde C, X\rangle +\quad\text{s.t.}\quad +X\mathbf{1}=\mathbf{1},\;\; \mathbf{1}^\top X=\mathbf{1}^\top +$$ +其中 X 为置换矩阵;落在补齐的虚拟行/列上即表示“未匹配”,其代价由 $C_{\mathrm{unmatch}}$ 决定。 + +四、最大化相似度的等价形式 +若用相似度矩阵 $S_{ij}$(如 $S_{ij}= \alpha_{\mathrm{mac}}\mathrm{MAC}_{ij}-\alpha_f\frac{|f_i-f'_j|}{f_{\mathrm{scale}}}-\alpha_\zeta|\,\zeta_i-\zeta'_j\,|$),则: +$$ +\max_{X}\;\sum_{ij} S_{ij} X_{ij} +\quad\Longleftrightarrow\quad +\min_{X}\;\sum_{ij} (-S_{ij}) X_{ij} +$$ +并沿用上面的指派约束与“未匹配”处理。 \ No newline at end of file diff --git a/ovf/core/GVFManager.py b/ovf/core/GVFManager.py index addcb41..648fe6a 100644 --- a/ovf/core/GVFManager.py +++ b/ovf/core/GVFManager.py @@ -9,6 +9,8 @@ from ovf.core.sample import auto_select_multple_ports from ovf.core.basis.MultiPortOrthonormalBasis import MultiPortOrthonormalBasis from ovf.core.geometry.plot_poles import plot_poles_in_3d from ovf.schemas.geometry.poles import PolesPlot3dUnit,PolesPlot3dDataUnit +from ovf.schemas.geometry.poles import PolesPlot2dUnit,PolesPlot2dDataUnit +from ovf.core.geometry.plot_poles import plot_poles_in_2d import pickle @dataclass @@ -21,6 +23,7 @@ class GVFManager: def __init__(self): self.parameter_type: Literal["s","y","z"] = "s" self.datasets: List[GVFDataUnit] = [] + self.full_freqs: np.ndarray | None = None def save(self,filepath:str): os.makedirs(os.path.dirname(filepath),exist_ok=True) @@ -34,7 +37,16 @@ class GVFManager: assert isinstance(obj,GVFManager),"The loaded object is not a GVFManager instance." return obj - def load_from_datasets(self,jsonfile:str,npoles_cplx:int=2,max_iterations:int=5,max_points:int|None=20,basis:type=MultiPortOrthonormalBasis,parameter_type:Literal["s","y","z"]="s"): + def load_from_datasets(self, + jsonfile:str, + npoles_cplx:int=2, + max_iterations:int=5, + max_points:int|None=20, + basis:type=MultiPortOrthonormalBasis, + parameter_type:Literal["s","y","z"]="s", + min_freqs:float|None=None, + max_freqs:float|None=None + ): self.parameter_type = parameter_type with open(jsonfile,"r") as f: datas = json.load(f) @@ -47,17 +59,27 @@ class GVFManager: full_freqences = network.f + min_index = 0 + max_index = len(full_freqences)-1 + for i in range(len(full_freqences)): + if min_freqs is not None and full_freqences[i] < min_freqs: + min_index = i + 1 + if max_freqs is not None and full_freqences[i] > max_freqs: + max_index = i - 1 + if parameter_type == "s": - sampled_points = network.s.reshape(-1,ports,ports) + sampled_points = network.s.reshape(-1,ports,ports)[min_index:max_index+1] elif parameter_type == "y": - sampled_points = network.y.reshape(-1,ports,ports) + sampled_points = network.y.reshape(-1,ports,ports)[min_index:max_index+1] elif parameter_type == "z": - sampled_points = network.z.reshape(-1,ports,ports) + sampled_points = network.z.reshape(-1,ports,ports)[min_index:max_index+1] + + self.full_freqs = full_freqences[min_index:max_index+1] if max_points: - H,freqs = auto_select_multple_ports(sampled_points,full_freqences,max_points=max_points) + H,freqs = auto_select_multple_ports(sampled_points,self.full_freqs,max_points=max_points) else: - H,freqs = sampled_points,full_freqences + H,freqs = sampled_points,self.full_freqs vf = VFManager(npoles_cplx=npoles_cplx,freqs=freqs,H=H,model=basis,iterations=K,verbose=False) vf.fit() @@ -93,7 +115,7 @@ class GVFManager: plot_poles_in_3d(unit, save_path) - def plot_vf_responses_with_index(self,save_dir:str,index:int): + def plot_vf_responses_with_index(self,save_dir:str,index:int,freqrange:List[float]|np.ndarray|None=None): ds = self.datasets[index] id = f"{index}" os.makedirs(save_dir,exist_ok=True) @@ -101,9 +123,12 @@ class GVFManager: vf.plot_metrics(show=False,save_path=f"{save_dir}/{id}") full_freqences = vf.freqs model_responses = vf.get_model_responses(full_freqences) - vf.plot_model_responses(show=False,save_path=f"{save_dir}/{id}") + if freqrange is None: + vf.plot_model_responses(full_freqs=self.full_freqs,show=False,save_path=f"{save_dir}/{id}") + else: + vf.plot_model_responses(full_freqs=np.array(freqrange).reshape(-1),show=False,save_path=f"{save_dir}/{id}") - def plot_all_vf_responses(self,save_dir:str): + def plot_all_vf_responses(self,save_dir:str,freqrange:List[float]|np.ndarray|None=None): for index,ds in enumerate(self.datasets): id = f"{index}" os.makedirs(save_dir,exist_ok=True) @@ -111,13 +136,25 @@ class GVFManager: vf.plot_metrics(show=False,save_path=f"{save_dir}/{id}") full_freqences = vf.freqs model_responses = vf.get_model_responses(full_freqences) - vf.plot_model_responses(show=False,save_path=f"{save_dir}/{id}") + if freqrange is None: + vf.plot_model_responses(full_freqs=self.full_freqs,show=False,save_path=f"{save_dir}/{id}") + else: + vf.plot_model_responses(full_freqs=np.array(freqrange).reshape(-1),show=False,save_path=f"{save_dir}/{id}") + + + def plot_poles_in_2d(self,save_path:str): + unit = PolesPlot2dUnit( + title="Poles Distribution 2D", + datas=[ + PolesPlot2dDataUnit( + poles=ds.vf_manager.poles, + geometries=ds.geometries + ) for ds in self.datasets + ], + x_label="Real Part", + y_label="Imaginary Part" + ) + plot_poles_in_2d(unit, save_path) - def evaluate_dataset(self): - for ds in self.datasets: - rmse = ds.vf_manager.rms_error - condition = ds.vf_manager.condition - row_condition = ds.vf_manager.row_condition - col_condition = ds.vf_manager.col_condition - + diff --git a/ovf/core/VFManager.py b/ovf/core/VFManager.py index d9b3840..775cb2f 100644 --- a/ovf/core/VFManager.py +++ b/ovf/core/VFManager.py @@ -42,6 +42,8 @@ class VFManager(): self.eigenval_row_condition = [] self.eigenval_col_condition = [] self.eigenval_rms_error = [] + self.eigenval_rms_error_l1 = [] + self.eigenval_rms_error_lmax = [] self.model_instance:MultiPortOrthonormalBasis|None = None self.model_responses_freqs = None @@ -120,6 +122,8 @@ class VFManager(): eigenval_row_condition=self.eigenval_row_condition, eigenval_col_condition=self.eigenval_col_condition, eigenval_rms_error=self.eigenval_rms_error, + eigenval_rms_error_l1=self.eigenval_rms_error_l1, + eigenval_rms_error_lmax=self.eigenval_rms_error_lmax, fit_constant=self.fit_constant, fit_proportional=self.fit_proportional, dc_enforce=self.dc_enforce, @@ -168,6 +172,8 @@ class VFManager(): instance.eigenval_condition = data["eigenval_condition"].tolist() instance.eigenval_row_condition = data["eigenval_row_condition"].tolist() instance.eigenval_col_condition = data["eigenval_col_condition"].tolist() + instance.eigenval_rms_error_l1 = data["eigenval_rms_error_l1"].tolist() + instance.eigenval_rms_error_lmax = data["eigenval_rms_error_lmax"].tolist() instance.eigenval_rms_error = data["eigenval_rms_error"].tolist() instance.fit_constant = bool(data["fit_constant"]) instance.fit_proportional = bool(data["fit_proportional"]) @@ -228,7 +234,9 @@ class VFManager(): eigenval_condition,\ eigenval_row_condition,\ eigenval_col_condition,\ - eigenval_rms_error = self.model_instance.eigen_metric + eigenval_rms_error,\ + eigenval_rms_error_l1,\ + eigenval_rms_error_lmax= self.model_instance.eigen_metric self.least_squares_condition.append(least_squares_condition) self.least_squares_row_condition.append(least_squares_row_condition) self.least_squares_col_condition.append(least_squares_col_condition) @@ -237,18 +245,20 @@ class VFManager(): self.eigenval_row_condition.append(eigenval_row_condition) self.eigenval_col_condition.append(eigenval_col_condition) self.eigenval_rms_error.append(eigenval_rms_error) + self.eigenval_rms_error_l1.append(eigenval_rms_error_l1) + self.eigenval_rms_error_lmax.append(eigenval_rms_error_lmax) return self.model_instance def plot_metrics(self,show:bool=True,save_path=None): - plt.figure(figsize=(16, 12)) - plt.subplot(4, 2, 1) + plt.figure(figsize=(20, 12)) + plt.subplot(5, 2, 1) plt.plot( range(1, len(self.least_squares_condition) + 1), self.least_squares_condition, label='Least Squares Condition' ) plt.legend() - plt.subplot(4, 2, 2) + plt.subplot(5, 2, 2) plt.plot( range(1, len(self.least_squares_row_condition) + 1), self.least_squares_row_condition, @@ -256,7 +266,7 @@ class VFManager(): ) plt.legend() - plt.subplot(4, 2, 3) + plt.subplot(5, 2, 3) plt.plot( range(1, len(self.least_squares_col_condition) + 1), self.least_squares_col_condition, @@ -264,7 +274,7 @@ class VFManager(): ) plt.legend() - plt.subplot(4, 2, 4) + plt.subplot(5, 2, 4) plt.plot( range(1, len(self.least_squares_rms_error) + 1), self.least_squares_rms_error, @@ -272,7 +282,7 @@ class VFManager(): ) plt.legend() - plt.subplot(4, 2, 5) + plt.subplot(5, 2, 5) plt.plot( range(1, len(self.eigenval_condition) + 1), self.eigenval_condition, @@ -280,7 +290,7 @@ class VFManager(): ) plt.legend() - plt.subplot(4, 2, 6) + plt.subplot(5, 2, 6) plt.plot( range(1, len(self.eigenval_row_condition) + 1), self.eigenval_row_condition, @@ -288,7 +298,7 @@ class VFManager(): ) plt.legend() - plt.subplot(4, 2, 7) + plt.subplot(5, 2, 7) plt.plot( range(1, len(self.eigenval_col_condition) + 1), self.eigenval_col_condition, @@ -296,13 +306,29 @@ class VFManager(): ) plt.legend() - plt.subplot(4, 2, 8) + plt.subplot(5, 2, 8) plt.plot( range(1, len(self.eigenval_rms_error) + 1), self.eigenval_rms_error, label='Eigenvalue RMS Error' ) plt.legend() + + plt.subplot(5, 2, 9) + plt.plot( + range(1, len(self.eigenval_rms_error_l1) + 1), + self.eigenval_rms_error_l1, + label='Eigenvalue RMS Error L1' + ) + plt.legend() + + plt.subplot(5, 2, 10) + plt.plot( + range(1, len(self.eigenval_rms_error_lmax) + 1), + self.eigenval_rms_error_lmax, + label='Eigenvalue RMS Error Lmax' + ) + plt.legend() if show: plt.show() @@ -312,33 +338,37 @@ class VFManager(): os.makedirs(save_path, exist_ok=True) plt.savefig(f"{save_path}/fitting_metrics.png") - def plot_model_responses(self,show:bool=True,save_path=None): + def plot_model_responses(self,full_freqs,show:bool=True,save_path=None): assert self.model_responses_freqs is not None and self.model_responses_H is not None, "Please run get_model_responses() first." + + full_freqs = np.array(full_freqs) + model_responses = self.get_model_responses(full_freqs) + for i in range(self.nports): for j in range(self.nports): plt.figure(figsize=(12, 6)) plt.subplot(2, 2, 1) plt.plot(self.freqs, np.abs(self.H[:,i,j]), 'o', ms=4, color='red', label='Input Samples') - plt.plot(self.model_responses_freqs, np.abs(self.model_responses_H[:,i,j]), '-', lw=2, color='k', label='Fit') + plt.plot(full_freqs, np.abs(model_responses[:,i,j]), '-', lw=2, color='k', label='Fit') plt.title(f"Response i={i+1}, j={j+1}") plt.ylabel("Magnitude") plt.legend(loc="best") plt.subplot(2, 2, 2) plt.plot(self.freqs, np.angle(self.H[:,i,j],deg=True), 'o', ms=4, color='red', label='Input Samples') - plt.plot(self.model_responses_freqs, np.angle(self.model_responses_H[:,i,j],deg=True), '-', lw=2, color='k', label='Fit') + plt.plot(full_freqs, np.angle(model_responses[:,i,j],deg=True), '-', lw=2, color='k', label='Fit') plt.title(f"Response i={i+1}, j={j+1}") plt.ylabel("Phase (deg)") plt.legend(loc="best") plt.tight_layout() plt.subplot(2, 2, 3) plt.plot(self.freqs, np.real(self.H[:,i,j]), 'o', ms=4, color='red', label='Input Samples') - plt.plot(self.model_responses_freqs, np.real(self.model_responses_H[:,i,j]), '-', lw=2, color='k', label='Fit') + plt.plot(full_freqs, np.real(model_responses[:,i,j]), '-', lw=2, color='k', label='Fit') plt.title(f"Response i={i+1}, j={j+1}") plt.ylabel("Real Part") plt.legend(loc="best") plt.subplot(2, 2, 4) plt.plot(self.freqs, np.imag(self.H[:,i,j]), 'o', ms=4, color='red', label='Input Samples') - plt.plot(self.model_responses_freqs, np.imag(self.model_responses_H[:,i,j]), '-', lw=2, color='k', label='Fit') + plt.plot(full_freqs, np.imag(model_responses[:,i,j]), '-', lw=2, color='k', label='Fit') plt.title(f"Response i={i+1}, j={j+1}") plt.ylabel("Imag Part") plt.legend(loc="best") diff --git a/ovf/core/basis/MultiPortOrthonormalBasis.py b/ovf/core/basis/MultiPortOrthonormalBasis.py index 6275083..f6dd242 100644 --- a/ovf/core/basis/MultiPortOrthonormalBasis.py +++ b/ovf/core/basis/MultiPortOrthonormalBasis.py @@ -13,6 +13,8 @@ class MultiPortOrthonormalBasis: self._eigenval_row_condition = None self._eigenval_col_condition = None self._eigenval_rms_error = None + self._eigenval_rms_error_l1 = None + self._eigenval_rms_error_lmax = None self.Cr = None self.dc_tol = 1e-18 @@ -28,8 +30,8 @@ class MultiPortOrthonormalBasis: self.s = self.freqs * 2j * np.pi self.pre_poles = pre_poles - self.Phi = self.generate_basis(self.s, self.pre_poles) - self.A = self.matrix_A(self.pre_poles) + self.Phi = self.generate_basis_confluent(self.s, self.pre_poles) + self.A = self.matrix_A_confluent(self.pre_poles) self.B = self.vector_B(self.pre_poles) self.C,self.w0,self.e = self.fit_denominator(self.H) self.D = self.w0 @@ -65,8 +67,15 @@ class MultiPortOrthonormalBasis: self._eigenval_condition,\ self._eigenval_row_condition,\ self._eigenval_col_condition,\ - self._eigenval_rms_error = self._eigen_metric() - return self._eigenval_condition, self._eigenval_row_condition, self._eigenval_col_condition, self._eigenval_rms_error + self._eigenval_rms_error,\ + self._eigenval_rms_error_l1,\ + self._eigenval_rms_error_lmax= self._eigen_metric() + return self._eigenval_condition,\ + self._eigenval_row_condition,\ + self._eigenval_col_condition,\ + self._eigenval_rms_error,\ + self._eigenval_rms_error_l1,\ + self._eigenval_rms_error_lmax @property @@ -101,12 +110,15 @@ class MultiPortOrthonormalBasis: Hk = self.H Hk_fitted = self.get_model_responses(self.freqs) - rms = np.sqrt(np.mean(np.abs(Hk - Hk_fitted)**2)) + rms = np.sqrt(np.mean(np.abs(Hk - Hk_fitted)**2)) / np.sqrt(np.mean(np.abs(Hk)**2)) + + rms_l1 = np.mean(np.abs(Hk - Hk_fitted)) / np.mean(np.abs(Hk)) + rms_lmax = np.max(np.abs(Hk - Hk_fitted)) / np.max(np.abs(Hk)) row_cond = cond_row_inf(self.A - self.B @ self.C) col_cond = cond_col_one(self.A - self.B @ self.C) - return cond,row_cond,col_cond,rms + return cond,row_cond,col_cond,rms,rms_l1,rms_lmax def _least_squares_metric(self,Q,R,x,b): """Return condition number and RMS error of least-squares matrix A and rhs b.""" @@ -174,6 +186,187 @@ class MultiPortOrthonormalBasis: ap_list.append(ap) Phi = np.column_stack(cols).astype(np.complex128) return Phi + + @staticmethod + def _group_poles_confluent(poles, tol=1e-12): + """ + 把 LHP 极点数组 `poles`(复数,含共轭对)按重复进行分组。 + 返回一个 list,每项: + {"type":"real","ap":ap_RHP,"m":m} + {"type":"pair","ap":ap_RHP,"ap1":ap_RHP_conj,"m":m} + 其中 ap = -p(RHP,Re>0),m 是重数。 + """ + z = np.asarray(poles, np.complex128).copy() + used = np.zeros(len(z), dtype=bool) + groups = [] + i = 0 + while i < len(z): + if used[i]: + i += 1 + continue + p = z[i] + ap = -p # 你的代码里用 ap = -p(RHP) + if ap.real < 0: + raise ValueError("poles must be in the LHP") + if np.isclose(p.imag, 0.0, atol=tol): + # 实极点:统计重数 + mask = (~used) & np.isclose(z.real, p.real, atol=tol) & np.isclose(z.imag, p.imag, atol=tol) + idx = np.where(mask)[0] + used[idx] = True + groups.append({"type": "real", "ap": ap, "m": len(idx)}) + else: + # 找共轭 + j = None + for k in range(i+1, len(z)): + if used[k]: + continue + if np.isclose(z[k], np.conj(p), atol=tol): + j = k; break + if j is None: + # 容错:没配上就当作单个复极点(很罕见);按一对处理也可 + used[i] = True + groups.append({"type": "real", "ap": ap, "m": 1}) + else: + # 统计这对的重数 m + used[i] = used[j] = True + m = 1 + # 继续往后找更多相同位置的对 + while True: + # 在未使用的剩余里各找一枚匹配 + ii = None; jj = None + for k in range(len(z)): + if not used[k] and np.isclose(z[k], p, atol=tol): + ii = k; break + for k in range(len(z)): + if not used[k] and np.isclose(z[k], np.conj(p), atol=tol): + jj = k; break + if ii is not None and jj is not None: + used[ii] = used[jj] = True + m += 1 + else: + break + ap1 = -np.conj(p) # = conj(ap) + groups.append({"type":"pair", "ap": ap, "ap1": ap1, "m": m}) + i += 1 + return groups + def generate_basis_confluent(self,s, poles, tol=1e-12): + """ + 合流版正交基(与原 generate_basis 接口兼容)。 + 返回: + Phi : (K, ncol) complex + layout : 列描述(便于后续打包 C) + """ + def _all_pass_factor(s, ap): + # ((s - ap*)/(s + ap)),在 s=jw 上幅度为1 + return (s - np.conj(ap)) / (s + ap) + + def _all_pass_cascade(s, ap_list): + res = 1.0 + 0.0j + for ap in ap_list: + res *= _all_pass_factor(s, ap) + return res + + s = np.asarray(s, np.complex128).reshape(-1) + groups = self._group_poles_confluent(poles, tol=tol) + cols = [] + layout = [] + prev_aps = [] # 进入当前组之前的 all-pass 级联清单(含重复) + Gprev = _all_pass_cascade(s, prev_aps) + + for g in groups: + if g["type"] == "real": + ap, m = g["ap"], g["m"] + # 基函数模板:sqrt(2 Re ap) * Gprev * [(AP(ap))^r] * 1/(s+ap) + base = 1.0/(s + ap) + AP = _all_pass_factor(s, ap) # 当前极点的 all-pass + if m > 1: + pass + for r in range(m): + phi = np.sqrt(2.0*ap.real) * Gprev * (AP**r) * base + cols.append(phi) + layout.append({"type":"real", "ap":ap, "order":r+1}) + # 更新 prev_aps & Gprev(把这一组的 m 个 ap 都加入) + for _ in range(m): + prev_aps.append(ap) + Gprev *= _all_pass_factor(s, ap) + + elif g["type"] == "pair": + ap, ap1, m = g["ap"], g["ap1"], g["m"] + # 你的原始两列:phi1, phi2(r=0) + # 合流:在它们前乘 [(AP(ap)*AP(ap1))^r] + denom = (s + ap)*(s + ap1) + absap = np.abs(ap) + base1 = (s - absap)/denom + base2 = (s + absap)/denom + APpair = _all_pass_factor(s, ap) * _all_pass_factor(s, ap1) # 成对 all-pass + if m > 1: + pass + for r in range(m): + Gr = Gprev * (APpair**r) + phi1 = np.sqrt(2.0*ap.real) * Gr * base1 + phi2 = np.sqrt(2.0*ap.real) * Gr * base2 + cols += [phi1, phi2] + layout += [{"type":"pair_Re","ap":ap,"ap1":ap1,"order":r+1}, + {"type":"pair_Im","ap":ap,"ap1":ap1,"order":r+1}] + # 更新 prev_aps & Gprev(把该对重复 m 次都加入) + for _ in range(m): + prev_aps.append(ap); Gprev *= _all_pass_factor(s, ap) + prev_aps.append(ap1); Gprev *= _all_pass_factor(s, ap1) + + else: + raise RuntimeError("unknown group type") + + Phi = np.column_stack(cols).astype(np.complex128) + return Phi + + def matrix_A_confluent(self,poles, tol=1e-12): + groups = self._group_poles_confluent(poles, tol=tol) + + # 先计算总列数 N(=基函数数) + N = 0 + for g in groups: + if g["type"]=="real": + N += g["m"] + else: # pair + N += 2*g["m"] + + A = np.zeros((N, N), float) + + col_ptr = 0 # 当前要写入的列起始索引 + for g in groups: + if g["type"] == "real": + ap, m = g["ap"], g["m"] + if m > 1: + pass + a = ap.real + for r in range(m): + c = col_ptr # 当前这一列的全局索引 + # 对角项 + A[c, c] = -a + # 尾部常数 2a:对所有后续行 + if c+1 < N: + A[c+1:, c] += 2.0*a + col_ptr += 1 + + elif g["type"] == "pair": + ap, ap1, m = g["ap"], g["ap1"], g["m"] + a = ap.real + absap = float(np.abs(ap)) + for r in range(m): + c0 = col_ptr # 这一对的两列 c0, c0+1 + # 2×2 核心块(放在对角) + A[c0:c0+2, c0:c0+2] = np.array([[-a, -a-absap], + [-a+absap, -a]], float) + # 尾部常数 2a:对所有后续行(逐行加到这两列) + if c0+2 < N: + A[c0+2:, c0] += 2.0*a + A[c0+2:, c0+1] += 2.0*a + col_ptr += 2 + + else: + raise RuntimeError("unknown group type") + + return A def matrix_A(self, poles): def A_col(p:np.complex128,index:int): diff --git a/ovf/core/geometry/plot2d.py b/ovf/core/geometry/plot2d.py index 4d825f5..51c23e3 100644 --- a/ovf/core/geometry/plot2d.py +++ b/ovf/core/geometry/plot2d.py @@ -5,28 +5,43 @@ from sklearn.linear_model import Ridge import os from ovf.schemas.geometry.basic import GeometryPlot2dUnit, GeometryPlot2dComplexDataUnit -def plot_2d(unit: GeometryPlot2dUnit, dirname: str): - os.makedirs(dirname, exist_ok=True) - - fig = make_subplots(rows=1, cols=1) +def plot_2d(unit: GeometryPlot2dUnit, filename: str): + os.makedirs(os.path.dirname(filename), exist_ok=True) + scatters_data = [] + + colors = ['blue', 'orange', 'green', 'red', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan'] for index in range(len(unit.datas[0].x)): x = [unit.datas[i].x[index] for i in range(len(unit.datas))] y = [unit.datas[i].y[index] for i in range(len(unit.datas))] - fig.add_trace( - go.Scatter( - x=x, - y=y, - mode='markers', - name=data.label - ), - row=1, col=1 - ) - - fig.update_layout( - title=unit.title, - xaxis_title=unit.x_label, - yaxis_title=unit.y_label, + + for pt in range(len(unit.datas)): + scatters_data.append( + go.Scatter( + x=[x[pt]], + y=[y[pt]], + mode='markers', + name=None, + marker=dict(size=8, color=colors[index%len(colors)]), + hovertext=f"Geometries: {unit.datas[pt].geometries}
X: {x[pt]}
Y: {y[pt]}" + ) + ) + layout = go.Layout( + title=f'{unit.title}', + xaxis=dict(title=f'{unit.x_label}'), + yaxis=dict(title=f'{unit.y_label}'), + showlegend=False ) - - fig.write_html(f"{dirname}/plot_2d.html") \ No newline at end of file + + fig = go.Figure(data=scatters_data, layout=layout) + + if filename.endswith('.html'): + fig.write_html(filename) + elif filename.endswith('.png'): + fig.write_image(filename,format='png') + elif filename.endswith('.pdf'): + fig.write_image(filename,format='pdf') + elif filename.endswith('.jpeg') or filename.endswith('.jpg'): + fig.write_image(filename,format='jpeg') + else: + fig.write_html(f"{filename}.html") \ No newline at end of file diff --git a/ovf/core/geometry/plot_poles.py b/ovf/core/geometry/plot_poles.py index 0d6c57c..fb904c6 100644 --- a/ovf/core/geometry/plot_poles.py +++ b/ovf/core/geometry/plot_poles.py @@ -3,9 +3,10 @@ import plotly.graph_objs as go from plotly.offline import plot import numpy as np from sklearn.linear_model import Ridge -from ovf.schemas.geometry.poles import PolesPlot3dUnit, PolesPlot3dDataUnit -from ovf.schemas.geometry.basic import GeometryPlot3dDataUnit, GeometryPlot3dUnit +from ovf.schemas.geometry.poles import PolesPlot3dUnit, PolesPlot3dDataUnit, PolesPlot2dUnit, PolesPlot2dDataUnit +from ovf.schemas.geometry.basic import GeometryPlot3dDataUnit, GeometryPlot3dUnit,GeometryPlot2dUnit,GeometryPlot2dComplexDataUnit from ovf.core.geometry.plot3d import get_plot_instance_in_3d +from ovf.core.geometry.plot2d import plot_2d from plotly.subplots import make_subplots import os @@ -85,4 +86,16 @@ def plot_poles_in_3d(poles:PolesPlot3dUnit,dirname:str): index += 1 +def plot_poles_in_2d(poles:PolesPlot2dUnit,filename:str): + data = GeometryPlot2dUnit( + title=poles.title, + datas=[GeometryPlot2dComplexDataUnit( + x=[np.real(p) for p in d.poles], + y=[np.imag(p) for p in d.poles], + geometries=d.geometries + ) for d in poles.datas], + x_label=poles.x_label, + y_label=poles.y_label + ) + plot_2d(data, filename) \ No newline at end of file diff --git a/ovf/schemas/geometry/basic.py b/ovf/schemas/geometry/basic.py index d53a0a8..0396fce 100644 --- a/ovf/schemas/geometry/basic.py +++ b/ovf/schemas/geometry/basic.py @@ -25,7 +25,7 @@ class GeometryPlot3dUnit: class GeometryPlot2dComplexDataUnit: x: List[float] y: List[float] - geometries: Dict[str,Union[float,int,str]] + geometries: Dict[str,float] @dataclass diff --git a/ovf/schemas/geometry/poles.py b/ovf/schemas/geometry/poles.py index 3301596..f9bd430 100644 --- a/ovf/schemas/geometry/poles.py +++ b/ovf/schemas/geometry/poles.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import List, Literal, Union +from typing import List, Literal, Union, Dict import numpy as np @dataclass @@ -27,4 +27,16 @@ class PolesPlot3dUnit: geometry_2_label: str = "geometry_2" pole_scale: Literal["linear", "log"] = "linear" geometry_1_scale: Literal["linear", "log"] = "linear" - geometry_2_scale: Literal["linear", "log"] = "linear" \ No newline at end of file + geometry_2_scale: Literal["linear", "log"] = "linear" + +@dataclass +class PolesPlot2dDataUnit: + poles: List[Union[complex,np.complex128]] + geometries: Dict[str,float] + +@dataclass +class PolesPlot2dUnit: + title: str + datas: List[PolesPlot2dDataUnit] + x_label: str + y_label: str \ No newline at end of file