diff --git a/docs/poles_match.md b/docs/poles_match.md
new file mode 100644
index 0000000..4e6eaed
--- /dev/null
+++ b/docs/poles_match.md
@@ -0,0 +1,65 @@
+下面把“极点识别→一对一配对”作为线性指派问题的优化目标,用公式列出(含常用代价设计与未匹配处理)。
+
+记号
+- 集合 A(参考/上一时刻/阶次 k):第 i 个模式的频率、阻尼、振型为 $f_i, ζ_i, φ_i$
+- 集合 B(估计/下一时刻/阶次 k+1):第 j 个模式为 $f'_j, ζ'_j, φ'_j$
+- 模态相似度 MAC:
+$$
+\mathrm{MAC}(\phi_i,\phi'_j)=\frac{\bigl|\phi_i^{H}\phi'_j\bigr|^2}{(\phi_i^{H}\phi_i)\,(\phi'_j{}^{H}\phi'_j)}\in[0,1]
+$$
+
+一、配对代价 C_ij 的常用设计
+- 加权和(绝对差版本):
+$$
+C_{ij} = w_f\frac{|f_i-f'_j|}{f_{\mathrm{scale}}} + w_\zeta\,|\,\zeta_i-\zeta'_j\,|+ w_{\mathrm{mac}}\bigl(1-\mathrm{MAC}_{ij}\bigr)
+$$
+- 或加权平方差版本:
+$$
+C_{ij}= w_f\Bigl(\frac{f_i-f'_j}{f_{\mathrm{scale}}}\Bigr)^2+ w_\zeta\bigl(\zeta_i-\zeta'_j\bigr)^2+ w_{\mathrm{mac}}\bigl(1-\mathrm{MAC}_{ij}\bigr)^2
+$$
+- 门控(Big-M 罚):
+$$
+C_{ij}\;:=\;
+\begin{cases}
+C_{ij}, & \frac{|f_i-f'_j|}{f_{\mathrm{scale}}}\le\tau_f,\;\;|\,\zeta_i-\zeta'_j\,|\le\tau_\zeta,\;\;\mathrm{MAC}_{ij}\ge\tau_{\mathrm{mac}}\\[4pt]
+M, & \text{否则}
+\end{cases}
+\qquad(M\text{ 很大})
+$$
+
+二、允许“未匹配”的指派目标(显式未匹配变量)
+令决策变量 x_{ij}\in\{0,1\} 表示 i 是否配给 j;u_i,v_j 表示未匹配。
+$$
+\min_{x,u,v}\;
+\sum_{i=1}^{m}\sum_{j=1}^{n} C_{ij}\,x_{ij}
++\tau_A\sum_{i=1}^{m} u_i
++\tau_B\sum_{j=1}^{n} v_j
+$$
+约束:
+$$
+\sum_{j=1}^{n} x_{ij}+u_i=1,\;\; \forall i=1,\dots,m
+\qquad
+\sum_{i=1}^{m} x_{ij}+v_j=1,\;\; \forall j=1,\dots,n
+$$
+$$
+x_{ij}\in\{0,1\},\;\;u_i\in\{0,1\},\;\;v_j\in\{0,1\}
+$$
+
+三、允许“未匹配”的等价“补方阵”形式(虚拟结点)
+取 L=\max(m,n),把代价矩阵补成
+$\tilde C\in\mathbb{R}^{L\times L}$,对虚拟匹配赋常数代价 $C_{\mathrm{unmatch}}$,解标准 LSAP:
+$$
+\min_{X\in\{0,1\}^{L\times L}}\;\langle \tilde C, X\rangle
+\quad\text{s.t.}\quad
+X\mathbf{1}=\mathbf{1},\;\; \mathbf{1}^\top X=\mathbf{1}^\top
+$$
+其中 X 为置换矩阵;落在补齐的虚拟行/列上即表示“未匹配”,其代价由 $C_{\mathrm{unmatch}}$ 决定。
+
+四、最大化相似度的等价形式
+若用相似度矩阵 $S_{ij}$(如 $S_{ij}= \alpha_{\mathrm{mac}}\mathrm{MAC}_{ij}-\alpha_f\frac{|f_i-f'_j|}{f_{\mathrm{scale}}}-\alpha_\zeta|\,\zeta_i-\zeta'_j\,|$),则:
+$$
+\max_{X}\;\sum_{ij} S_{ij} X_{ij}
+\quad\Longleftrightarrow\quad
+\min_{X}\;\sum_{ij} (-S_{ij}) X_{ij}
+$$
+并沿用上面的指派约束与“未匹配”处理。
\ No newline at end of file
diff --git a/ovf/core/GVFManager.py b/ovf/core/GVFManager.py
index addcb41..648fe6a 100644
--- a/ovf/core/GVFManager.py
+++ b/ovf/core/GVFManager.py
@@ -9,6 +9,8 @@ from ovf.core.sample import auto_select_multple_ports
from ovf.core.basis.MultiPortOrthonormalBasis import MultiPortOrthonormalBasis
from ovf.core.geometry.plot_poles import plot_poles_in_3d
from ovf.schemas.geometry.poles import PolesPlot3dUnit,PolesPlot3dDataUnit
+from ovf.schemas.geometry.poles import PolesPlot2dUnit,PolesPlot2dDataUnit
+from ovf.core.geometry.plot_poles import plot_poles_in_2d
import pickle
@dataclass
@@ -21,6 +23,7 @@ class GVFManager:
def __init__(self):
self.parameter_type: Literal["s","y","z"] = "s"
self.datasets: List[GVFDataUnit] = []
+ self.full_freqs: np.ndarray | None = None
def save(self,filepath:str):
os.makedirs(os.path.dirname(filepath),exist_ok=True)
@@ -34,7 +37,16 @@ class GVFManager:
assert isinstance(obj,GVFManager),"The loaded object is not a GVFManager instance."
return obj
- def load_from_datasets(self,jsonfile:str,npoles_cplx:int=2,max_iterations:int=5,max_points:int|None=20,basis:type=MultiPortOrthonormalBasis,parameter_type:Literal["s","y","z"]="s"):
+ def load_from_datasets(self,
+ jsonfile:str,
+ npoles_cplx:int=2,
+ max_iterations:int=5,
+ max_points:int|None=20,
+ basis:type=MultiPortOrthonormalBasis,
+ parameter_type:Literal["s","y","z"]="s",
+ min_freqs:float|None=None,
+ max_freqs:float|None=None
+ ):
self.parameter_type = parameter_type
with open(jsonfile,"r") as f:
datas = json.load(f)
@@ -47,17 +59,27 @@ class GVFManager:
full_freqences = network.f
+ min_index = 0
+ max_index = len(full_freqences)-1
+ for i in range(len(full_freqences)):
+ if min_freqs is not None and full_freqences[i] < min_freqs:
+ min_index = i + 1
+ if max_freqs is not None and full_freqences[i] > max_freqs:
+ max_index = i - 1
+
if parameter_type == "s":
- sampled_points = network.s.reshape(-1,ports,ports)
+ sampled_points = network.s.reshape(-1,ports,ports)[min_index:max_index+1]
elif parameter_type == "y":
- sampled_points = network.y.reshape(-1,ports,ports)
+ sampled_points = network.y.reshape(-1,ports,ports)[min_index:max_index+1]
elif parameter_type == "z":
- sampled_points = network.z.reshape(-1,ports,ports)
+ sampled_points = network.z.reshape(-1,ports,ports)[min_index:max_index+1]
+
+ self.full_freqs = full_freqences[min_index:max_index+1]
if max_points:
- H,freqs = auto_select_multple_ports(sampled_points,full_freqences,max_points=max_points)
+ H,freqs = auto_select_multple_ports(sampled_points,self.full_freqs,max_points=max_points)
else:
- H,freqs = sampled_points,full_freqences
+ H,freqs = sampled_points,self.full_freqs
vf = VFManager(npoles_cplx=npoles_cplx,freqs=freqs,H=H,model=basis,iterations=K,verbose=False)
vf.fit()
@@ -93,7 +115,7 @@ class GVFManager:
plot_poles_in_3d(unit, save_path)
- def plot_vf_responses_with_index(self,save_dir:str,index:int):
+ def plot_vf_responses_with_index(self,save_dir:str,index:int,freqrange:List[float]|np.ndarray|None=None):
ds = self.datasets[index]
id = f"{index}"
os.makedirs(save_dir,exist_ok=True)
@@ -101,9 +123,12 @@ class GVFManager:
vf.plot_metrics(show=False,save_path=f"{save_dir}/{id}")
full_freqences = vf.freqs
model_responses = vf.get_model_responses(full_freqences)
- vf.plot_model_responses(show=False,save_path=f"{save_dir}/{id}")
+ if freqrange is None:
+ vf.plot_model_responses(full_freqs=self.full_freqs,show=False,save_path=f"{save_dir}/{id}")
+ else:
+ vf.plot_model_responses(full_freqs=np.array(freqrange).reshape(-1),show=False,save_path=f"{save_dir}/{id}")
- def plot_all_vf_responses(self,save_dir:str):
+ def plot_all_vf_responses(self,save_dir:str,freqrange:List[float]|np.ndarray|None=None):
for index,ds in enumerate(self.datasets):
id = f"{index}"
os.makedirs(save_dir,exist_ok=True)
@@ -111,13 +136,25 @@ class GVFManager:
vf.plot_metrics(show=False,save_path=f"{save_dir}/{id}")
full_freqences = vf.freqs
model_responses = vf.get_model_responses(full_freqences)
- vf.plot_model_responses(show=False,save_path=f"{save_dir}/{id}")
+ if freqrange is None:
+ vf.plot_model_responses(full_freqs=self.full_freqs,show=False,save_path=f"{save_dir}/{id}")
+ else:
+ vf.plot_model_responses(full_freqs=np.array(freqrange).reshape(-1),show=False,save_path=f"{save_dir}/{id}")
+
+
+ def plot_poles_in_2d(self,save_path:str):
+ unit = PolesPlot2dUnit(
+ title="Poles Distribution 2D",
+ datas=[
+ PolesPlot2dDataUnit(
+ poles=ds.vf_manager.poles,
+ geometries=ds.geometries
+ ) for ds in self.datasets
+ ],
+ x_label="Real Part",
+ y_label="Imaginary Part"
+ )
+ plot_poles_in_2d(unit, save_path)
- def evaluate_dataset(self):
- for ds in self.datasets:
- rmse = ds.vf_manager.rms_error
- condition = ds.vf_manager.condition
- row_condition = ds.vf_manager.row_condition
- col_condition = ds.vf_manager.col_condition
-
+
diff --git a/ovf/core/VFManager.py b/ovf/core/VFManager.py
index d9b3840..775cb2f 100644
--- a/ovf/core/VFManager.py
+++ b/ovf/core/VFManager.py
@@ -42,6 +42,8 @@ class VFManager():
self.eigenval_row_condition = []
self.eigenval_col_condition = []
self.eigenval_rms_error = []
+ self.eigenval_rms_error_l1 = []
+ self.eigenval_rms_error_lmax = []
self.model_instance:MultiPortOrthonormalBasis|None = None
self.model_responses_freqs = None
@@ -120,6 +122,8 @@ class VFManager():
eigenval_row_condition=self.eigenval_row_condition,
eigenval_col_condition=self.eigenval_col_condition,
eigenval_rms_error=self.eigenval_rms_error,
+ eigenval_rms_error_l1=self.eigenval_rms_error_l1,
+ eigenval_rms_error_lmax=self.eigenval_rms_error_lmax,
fit_constant=self.fit_constant,
fit_proportional=self.fit_proportional,
dc_enforce=self.dc_enforce,
@@ -168,6 +172,8 @@ class VFManager():
instance.eigenval_condition = data["eigenval_condition"].tolist()
instance.eigenval_row_condition = data["eigenval_row_condition"].tolist()
instance.eigenval_col_condition = data["eigenval_col_condition"].tolist()
+ instance.eigenval_rms_error_l1 = data["eigenval_rms_error_l1"].tolist()
+ instance.eigenval_rms_error_lmax = data["eigenval_rms_error_lmax"].tolist()
instance.eigenval_rms_error = data["eigenval_rms_error"].tolist()
instance.fit_constant = bool(data["fit_constant"])
instance.fit_proportional = bool(data["fit_proportional"])
@@ -228,7 +234,9 @@ class VFManager():
eigenval_condition,\
eigenval_row_condition,\
eigenval_col_condition,\
- eigenval_rms_error = self.model_instance.eigen_metric
+ eigenval_rms_error,\
+ eigenval_rms_error_l1,\
+ eigenval_rms_error_lmax= self.model_instance.eigen_metric
self.least_squares_condition.append(least_squares_condition)
self.least_squares_row_condition.append(least_squares_row_condition)
self.least_squares_col_condition.append(least_squares_col_condition)
@@ -237,18 +245,20 @@ class VFManager():
self.eigenval_row_condition.append(eigenval_row_condition)
self.eigenval_col_condition.append(eigenval_col_condition)
self.eigenval_rms_error.append(eigenval_rms_error)
+ self.eigenval_rms_error_l1.append(eigenval_rms_error_l1)
+ self.eigenval_rms_error_lmax.append(eigenval_rms_error_lmax)
return self.model_instance
def plot_metrics(self,show:bool=True,save_path=None):
- plt.figure(figsize=(16, 12))
- plt.subplot(4, 2, 1)
+ plt.figure(figsize=(20, 12))
+ plt.subplot(5, 2, 1)
plt.plot(
range(1, len(self.least_squares_condition) + 1),
self.least_squares_condition,
label='Least Squares Condition'
)
plt.legend()
- plt.subplot(4, 2, 2)
+ plt.subplot(5, 2, 2)
plt.plot(
range(1, len(self.least_squares_row_condition) + 1),
self.least_squares_row_condition,
@@ -256,7 +266,7 @@ class VFManager():
)
plt.legend()
- plt.subplot(4, 2, 3)
+ plt.subplot(5, 2, 3)
plt.plot(
range(1, len(self.least_squares_col_condition) + 1),
self.least_squares_col_condition,
@@ -264,7 +274,7 @@ class VFManager():
)
plt.legend()
- plt.subplot(4, 2, 4)
+ plt.subplot(5, 2, 4)
plt.plot(
range(1, len(self.least_squares_rms_error) + 1),
self.least_squares_rms_error,
@@ -272,7 +282,7 @@ class VFManager():
)
plt.legend()
- plt.subplot(4, 2, 5)
+ plt.subplot(5, 2, 5)
plt.plot(
range(1, len(self.eigenval_condition) + 1),
self.eigenval_condition,
@@ -280,7 +290,7 @@ class VFManager():
)
plt.legend()
- plt.subplot(4, 2, 6)
+ plt.subplot(5, 2, 6)
plt.plot(
range(1, len(self.eigenval_row_condition) + 1),
self.eigenval_row_condition,
@@ -288,7 +298,7 @@ class VFManager():
)
plt.legend()
- plt.subplot(4, 2, 7)
+ plt.subplot(5, 2, 7)
plt.plot(
range(1, len(self.eigenval_col_condition) + 1),
self.eigenval_col_condition,
@@ -296,13 +306,29 @@ class VFManager():
)
plt.legend()
- plt.subplot(4, 2, 8)
+ plt.subplot(5, 2, 8)
plt.plot(
range(1, len(self.eigenval_rms_error) + 1),
self.eigenval_rms_error,
label='Eigenvalue RMS Error'
)
plt.legend()
+
+ plt.subplot(5, 2, 9)
+ plt.plot(
+ range(1, len(self.eigenval_rms_error_l1) + 1),
+ self.eigenval_rms_error_l1,
+ label='Eigenvalue RMS Error L1'
+ )
+ plt.legend()
+
+ plt.subplot(5, 2, 10)
+ plt.plot(
+ range(1, len(self.eigenval_rms_error_lmax) + 1),
+ self.eigenval_rms_error_lmax,
+ label='Eigenvalue RMS Error Lmax'
+ )
+ plt.legend()
if show:
plt.show()
@@ -312,33 +338,37 @@ class VFManager():
os.makedirs(save_path, exist_ok=True)
plt.savefig(f"{save_path}/fitting_metrics.png")
- def plot_model_responses(self,show:bool=True,save_path=None):
+ def plot_model_responses(self,full_freqs,show:bool=True,save_path=None):
assert self.model_responses_freqs is not None and self.model_responses_H is not None, "Please run get_model_responses() first."
+
+ full_freqs = np.array(full_freqs)
+ model_responses = self.get_model_responses(full_freqs)
+
for i in range(self.nports):
for j in range(self.nports):
plt.figure(figsize=(12, 6))
plt.subplot(2, 2, 1)
plt.plot(self.freqs, np.abs(self.H[:,i,j]), 'o', ms=4, color='red', label='Input Samples')
- plt.plot(self.model_responses_freqs, np.abs(self.model_responses_H[:,i,j]), '-', lw=2, color='k', label='Fit')
+ plt.plot(full_freqs, np.abs(model_responses[:,i,j]), '-', lw=2, color='k', label='Fit')
plt.title(f"Response i={i+1}, j={j+1}")
plt.ylabel("Magnitude")
plt.legend(loc="best")
plt.subplot(2, 2, 2)
plt.plot(self.freqs, np.angle(self.H[:,i,j],deg=True), 'o', ms=4, color='red', label='Input Samples')
- plt.plot(self.model_responses_freqs, np.angle(self.model_responses_H[:,i,j],deg=True), '-', lw=2, color='k', label='Fit')
+ plt.plot(full_freqs, np.angle(model_responses[:,i,j],deg=True), '-', lw=2, color='k', label='Fit')
plt.title(f"Response i={i+1}, j={j+1}")
plt.ylabel("Phase (deg)")
plt.legend(loc="best")
plt.tight_layout()
plt.subplot(2, 2, 3)
plt.plot(self.freqs, np.real(self.H[:,i,j]), 'o', ms=4, color='red', label='Input Samples')
- plt.plot(self.model_responses_freqs, np.real(self.model_responses_H[:,i,j]), '-', lw=2, color='k', label='Fit')
+ plt.plot(full_freqs, np.real(model_responses[:,i,j]), '-', lw=2, color='k', label='Fit')
plt.title(f"Response i={i+1}, j={j+1}")
plt.ylabel("Real Part")
plt.legend(loc="best")
plt.subplot(2, 2, 4)
plt.plot(self.freqs, np.imag(self.H[:,i,j]), 'o', ms=4, color='red', label='Input Samples')
- plt.plot(self.model_responses_freqs, np.imag(self.model_responses_H[:,i,j]), '-', lw=2, color='k', label='Fit')
+ plt.plot(full_freqs, np.imag(model_responses[:,i,j]), '-', lw=2, color='k', label='Fit')
plt.title(f"Response i={i+1}, j={j+1}")
plt.ylabel("Imag Part")
plt.legend(loc="best")
diff --git a/ovf/core/basis/MultiPortOrthonormalBasis.py b/ovf/core/basis/MultiPortOrthonormalBasis.py
index 6275083..f6dd242 100644
--- a/ovf/core/basis/MultiPortOrthonormalBasis.py
+++ b/ovf/core/basis/MultiPortOrthonormalBasis.py
@@ -13,6 +13,8 @@ class MultiPortOrthonormalBasis:
self._eigenval_row_condition = None
self._eigenval_col_condition = None
self._eigenval_rms_error = None
+ self._eigenval_rms_error_l1 = None
+ self._eigenval_rms_error_lmax = None
self.Cr = None
self.dc_tol = 1e-18
@@ -28,8 +30,8 @@ class MultiPortOrthonormalBasis:
self.s = self.freqs * 2j * np.pi
self.pre_poles = pre_poles
- self.Phi = self.generate_basis(self.s, self.pre_poles)
- self.A = self.matrix_A(self.pre_poles)
+ self.Phi = self.generate_basis_confluent(self.s, self.pre_poles)
+ self.A = self.matrix_A_confluent(self.pre_poles)
self.B = self.vector_B(self.pre_poles)
self.C,self.w0,self.e = self.fit_denominator(self.H)
self.D = self.w0
@@ -65,8 +67,15 @@ class MultiPortOrthonormalBasis:
self._eigenval_condition,\
self._eigenval_row_condition,\
self._eigenval_col_condition,\
- self._eigenval_rms_error = self._eigen_metric()
- return self._eigenval_condition, self._eigenval_row_condition, self._eigenval_col_condition, self._eigenval_rms_error
+ self._eigenval_rms_error,\
+ self._eigenval_rms_error_l1,\
+ self._eigenval_rms_error_lmax= self._eigen_metric()
+ return self._eigenval_condition,\
+ self._eigenval_row_condition,\
+ self._eigenval_col_condition,\
+ self._eigenval_rms_error,\
+ self._eigenval_rms_error_l1,\
+ self._eigenval_rms_error_lmax
@property
@@ -101,12 +110,15 @@ class MultiPortOrthonormalBasis:
Hk = self.H
Hk_fitted = self.get_model_responses(self.freqs)
- rms = np.sqrt(np.mean(np.abs(Hk - Hk_fitted)**2))
+ rms = np.sqrt(np.mean(np.abs(Hk - Hk_fitted)**2)) / np.sqrt(np.mean(np.abs(Hk)**2))
+
+ rms_l1 = np.mean(np.abs(Hk - Hk_fitted)) / np.mean(np.abs(Hk))
+ rms_lmax = np.max(np.abs(Hk - Hk_fitted)) / np.max(np.abs(Hk))
row_cond = cond_row_inf(self.A - self.B @ self.C)
col_cond = cond_col_one(self.A - self.B @ self.C)
- return cond,row_cond,col_cond,rms
+ return cond,row_cond,col_cond,rms,rms_l1,rms_lmax
def _least_squares_metric(self,Q,R,x,b):
"""Return condition number and RMS error of least-squares matrix A and rhs b."""
@@ -174,6 +186,187 @@ class MultiPortOrthonormalBasis:
ap_list.append(ap)
Phi = np.column_stack(cols).astype(np.complex128)
return Phi
+
+ @staticmethod
+ def _group_poles_confluent(poles, tol=1e-12):
+ """
+ 把 LHP 极点数组 `poles`(复数,含共轭对)按重复进行分组。
+ 返回一个 list,每项:
+ {"type":"real","ap":ap_RHP,"m":m}
+ {"type":"pair","ap":ap_RHP,"ap1":ap_RHP_conj,"m":m}
+ 其中 ap = -p(RHP,Re>0),m 是重数。
+ """
+ z = np.asarray(poles, np.complex128).copy()
+ used = np.zeros(len(z), dtype=bool)
+ groups = []
+ i = 0
+ while i < len(z):
+ if used[i]:
+ i += 1
+ continue
+ p = z[i]
+ ap = -p # 你的代码里用 ap = -p(RHP)
+ if ap.real < 0:
+ raise ValueError("poles must be in the LHP")
+ if np.isclose(p.imag, 0.0, atol=tol):
+ # 实极点:统计重数
+ mask = (~used) & np.isclose(z.real, p.real, atol=tol) & np.isclose(z.imag, p.imag, atol=tol)
+ idx = np.where(mask)[0]
+ used[idx] = True
+ groups.append({"type": "real", "ap": ap, "m": len(idx)})
+ else:
+ # 找共轭
+ j = None
+ for k in range(i+1, len(z)):
+ if used[k]:
+ continue
+ if np.isclose(z[k], np.conj(p), atol=tol):
+ j = k; break
+ if j is None:
+ # 容错:没配上就当作单个复极点(很罕见);按一对处理也可
+ used[i] = True
+ groups.append({"type": "real", "ap": ap, "m": 1})
+ else:
+ # 统计这对的重数 m
+ used[i] = used[j] = True
+ m = 1
+ # 继续往后找更多相同位置的对
+ while True:
+ # 在未使用的剩余里各找一枚匹配
+ ii = None; jj = None
+ for k in range(len(z)):
+ if not used[k] and np.isclose(z[k], p, atol=tol):
+ ii = k; break
+ for k in range(len(z)):
+ if not used[k] and np.isclose(z[k], np.conj(p), atol=tol):
+ jj = k; break
+ if ii is not None and jj is not None:
+ used[ii] = used[jj] = True
+ m += 1
+ else:
+ break
+ ap1 = -np.conj(p) # = conj(ap)
+ groups.append({"type":"pair", "ap": ap, "ap1": ap1, "m": m})
+ i += 1
+ return groups
+ def generate_basis_confluent(self,s, poles, tol=1e-12):
+ """
+ 合流版正交基(与原 generate_basis 接口兼容)。
+ 返回:
+ Phi : (K, ncol) complex
+ layout : 列描述(便于后续打包 C)
+ """
+ def _all_pass_factor(s, ap):
+ # ((s - ap*)/(s + ap)),在 s=jw 上幅度为1
+ return (s - np.conj(ap)) / (s + ap)
+
+ def _all_pass_cascade(s, ap_list):
+ res = 1.0 + 0.0j
+ for ap in ap_list:
+ res *= _all_pass_factor(s, ap)
+ return res
+
+ s = np.asarray(s, np.complex128).reshape(-1)
+ groups = self._group_poles_confluent(poles, tol=tol)
+ cols = []
+ layout = []
+ prev_aps = [] # 进入当前组之前的 all-pass 级联清单(含重复)
+ Gprev = _all_pass_cascade(s, prev_aps)
+
+ for g in groups:
+ if g["type"] == "real":
+ ap, m = g["ap"], g["m"]
+ # 基函数模板:sqrt(2 Re ap) * Gprev * [(AP(ap))^r] * 1/(s+ap)
+ base = 1.0/(s + ap)
+ AP = _all_pass_factor(s, ap) # 当前极点的 all-pass
+ if m > 1:
+ pass
+ for r in range(m):
+ phi = np.sqrt(2.0*ap.real) * Gprev * (AP**r) * base
+ cols.append(phi)
+ layout.append({"type":"real", "ap":ap, "order":r+1})
+ # 更新 prev_aps & Gprev(把这一组的 m 个 ap 都加入)
+ for _ in range(m):
+ prev_aps.append(ap)
+ Gprev *= _all_pass_factor(s, ap)
+
+ elif g["type"] == "pair":
+ ap, ap1, m = g["ap"], g["ap1"], g["m"]
+ # 你的原始两列:phi1, phi2(r=0)
+ # 合流:在它们前乘 [(AP(ap)*AP(ap1))^r]
+ denom = (s + ap)*(s + ap1)
+ absap = np.abs(ap)
+ base1 = (s - absap)/denom
+ base2 = (s + absap)/denom
+ APpair = _all_pass_factor(s, ap) * _all_pass_factor(s, ap1) # 成对 all-pass
+ if m > 1:
+ pass
+ for r in range(m):
+ Gr = Gprev * (APpair**r)
+ phi1 = np.sqrt(2.0*ap.real) * Gr * base1
+ phi2 = np.sqrt(2.0*ap.real) * Gr * base2
+ cols += [phi1, phi2]
+ layout += [{"type":"pair_Re","ap":ap,"ap1":ap1,"order":r+1},
+ {"type":"pair_Im","ap":ap,"ap1":ap1,"order":r+1}]
+ # 更新 prev_aps & Gprev(把该对重复 m 次都加入)
+ for _ in range(m):
+ prev_aps.append(ap); Gprev *= _all_pass_factor(s, ap)
+ prev_aps.append(ap1); Gprev *= _all_pass_factor(s, ap1)
+
+ else:
+ raise RuntimeError("unknown group type")
+
+ Phi = np.column_stack(cols).astype(np.complex128)
+ return Phi
+
+ def matrix_A_confluent(self,poles, tol=1e-12):
+ groups = self._group_poles_confluent(poles, tol=tol)
+
+ # 先计算总列数 N(=基函数数)
+ N = 0
+ for g in groups:
+ if g["type"]=="real":
+ N += g["m"]
+ else: # pair
+ N += 2*g["m"]
+
+ A = np.zeros((N, N), float)
+
+ col_ptr = 0 # 当前要写入的列起始索引
+ for g in groups:
+ if g["type"] == "real":
+ ap, m = g["ap"], g["m"]
+ if m > 1:
+ pass
+ a = ap.real
+ for r in range(m):
+ c = col_ptr # 当前这一列的全局索引
+ # 对角项
+ A[c, c] = -a
+ # 尾部常数 2a:对所有后续行
+ if c+1 < N:
+ A[c+1:, c] += 2.0*a
+ col_ptr += 1
+
+ elif g["type"] == "pair":
+ ap, ap1, m = g["ap"], g["ap1"], g["m"]
+ a = ap.real
+ absap = float(np.abs(ap))
+ for r in range(m):
+ c0 = col_ptr # 这一对的两列 c0, c0+1
+ # 2×2 核心块(放在对角)
+ A[c0:c0+2, c0:c0+2] = np.array([[-a, -a-absap],
+ [-a+absap, -a]], float)
+ # 尾部常数 2a:对所有后续行(逐行加到这两列)
+ if c0+2 < N:
+ A[c0+2:, c0] += 2.0*a
+ A[c0+2:, c0+1] += 2.0*a
+ col_ptr += 2
+
+ else:
+ raise RuntimeError("unknown group type")
+
+ return A
def matrix_A(self, poles):
def A_col(p:np.complex128,index:int):
diff --git a/ovf/core/geometry/plot2d.py b/ovf/core/geometry/plot2d.py
index 4d825f5..51c23e3 100644
--- a/ovf/core/geometry/plot2d.py
+++ b/ovf/core/geometry/plot2d.py
@@ -5,28 +5,43 @@ from sklearn.linear_model import Ridge
import os
from ovf.schemas.geometry.basic import GeometryPlot2dUnit, GeometryPlot2dComplexDataUnit
-def plot_2d(unit: GeometryPlot2dUnit, dirname: str):
- os.makedirs(dirname, exist_ok=True)
-
- fig = make_subplots(rows=1, cols=1)
+def plot_2d(unit: GeometryPlot2dUnit, filename: str):
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+ scatters_data = []
+
+ colors = ['blue', 'orange', 'green', 'red', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan']
for index in range(len(unit.datas[0].x)):
x = [unit.datas[i].x[index] for i in range(len(unit.datas))]
y = [unit.datas[i].y[index] for i in range(len(unit.datas))]
- fig.add_trace(
- go.Scatter(
- x=x,
- y=y,
- mode='markers',
- name=data.label
- ),
- row=1, col=1
- )
-
- fig.update_layout(
- title=unit.title,
- xaxis_title=unit.x_label,
- yaxis_title=unit.y_label,
+
+ for pt in range(len(unit.datas)):
+ scatters_data.append(
+ go.Scatter(
+ x=[x[pt]],
+ y=[y[pt]],
+ mode='markers',
+ name=None,
+ marker=dict(size=8, color=colors[index%len(colors)]),
+ hovertext=f"Geometries: {unit.datas[pt].geometries}
X: {x[pt]}
Y: {y[pt]}"
+ )
+ )
+ layout = go.Layout(
+ title=f'{unit.title}',
+ xaxis=dict(title=f'{unit.x_label}'),
+ yaxis=dict(title=f'{unit.y_label}'),
+ showlegend=False
)
-
- fig.write_html(f"{dirname}/plot_2d.html")
\ No newline at end of file
+
+ fig = go.Figure(data=scatters_data, layout=layout)
+
+ if filename.endswith('.html'):
+ fig.write_html(filename)
+ elif filename.endswith('.png'):
+ fig.write_image(filename,format='png')
+ elif filename.endswith('.pdf'):
+ fig.write_image(filename,format='pdf')
+ elif filename.endswith('.jpeg') or filename.endswith('.jpg'):
+ fig.write_image(filename,format='jpeg')
+ else:
+ fig.write_html(f"{filename}.html")
\ No newline at end of file
diff --git a/ovf/core/geometry/plot_poles.py b/ovf/core/geometry/plot_poles.py
index 0d6c57c..fb904c6 100644
--- a/ovf/core/geometry/plot_poles.py
+++ b/ovf/core/geometry/plot_poles.py
@@ -3,9 +3,10 @@ import plotly.graph_objs as go
from plotly.offline import plot
import numpy as np
from sklearn.linear_model import Ridge
-from ovf.schemas.geometry.poles import PolesPlot3dUnit, PolesPlot3dDataUnit
-from ovf.schemas.geometry.basic import GeometryPlot3dDataUnit, GeometryPlot3dUnit
+from ovf.schemas.geometry.poles import PolesPlot3dUnit, PolesPlot3dDataUnit, PolesPlot2dUnit, PolesPlot2dDataUnit
+from ovf.schemas.geometry.basic import GeometryPlot3dDataUnit, GeometryPlot3dUnit,GeometryPlot2dUnit,GeometryPlot2dComplexDataUnit
from ovf.core.geometry.plot3d import get_plot_instance_in_3d
+from ovf.core.geometry.plot2d import plot_2d
from plotly.subplots import make_subplots
import os
@@ -85,4 +86,16 @@ def plot_poles_in_3d(poles:PolesPlot3dUnit,dirname:str):
index += 1
+def plot_poles_in_2d(poles:PolesPlot2dUnit,filename:str):
+ data = GeometryPlot2dUnit(
+ title=poles.title,
+ datas=[GeometryPlot2dComplexDataUnit(
+ x=[np.real(p) for p in d.poles],
+ y=[np.imag(p) for p in d.poles],
+ geometries=d.geometries
+ ) for d in poles.datas],
+ x_label=poles.x_label,
+ y_label=poles.y_label
+ )
+ plot_2d(data, filename)
\ No newline at end of file
diff --git a/ovf/schemas/geometry/basic.py b/ovf/schemas/geometry/basic.py
index d53a0a8..0396fce 100644
--- a/ovf/schemas/geometry/basic.py
+++ b/ovf/schemas/geometry/basic.py
@@ -25,7 +25,7 @@ class GeometryPlot3dUnit:
class GeometryPlot2dComplexDataUnit:
x: List[float]
y: List[float]
- geometries: Dict[str,Union[float,int,str]]
+ geometries: Dict[str,float]
@dataclass
diff --git a/ovf/schemas/geometry/poles.py b/ovf/schemas/geometry/poles.py
index 3301596..f9bd430 100644
--- a/ovf/schemas/geometry/poles.py
+++ b/ovf/schemas/geometry/poles.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
-from typing import List, Literal, Union
+from typing import List, Literal, Union, Dict
import numpy as np
@dataclass
@@ -27,4 +27,16 @@ class PolesPlot3dUnit:
geometry_2_label: str = "geometry_2"
pole_scale: Literal["linear", "log"] = "linear"
geometry_1_scale: Literal["linear", "log"] = "linear"
- geometry_2_scale: Literal["linear", "log"] = "linear"
\ No newline at end of file
+ geometry_2_scale: Literal["linear", "log"] = "linear"
+
+@dataclass
+class PolesPlot2dDataUnit:
+ poles: List[Union[complex,np.complex128]]
+ geometries: Dict[str,float]
+
+@dataclass
+class PolesPlot2dUnit:
+ title: str
+ datas: List[PolesPlot2dDataUnit]
+ x_label: str
+ y_label: str
\ No newline at end of file