From 15d1c042feb199f5f4023897b9f604de22d580ab Mon Sep 17 00:00:00 2001 From: mayge Date: Thu, 2 Oct 2025 04:34:02 -0400 Subject: [PATCH] =?UTF-8?q?refactor:=20=E4=BF=AE=E5=A4=8D=E4=BA=86?= =?UTF-8?q?=E9=A2=91=E7=8E=87=E9=80=89=E7=82=B9=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ovf/core/GVFManager.py | 29 ++-- ovf/core/VFManager.py | 1 - ovf/core/geometry/plot2d.py | 3 +- ovf/core/geometry/plot_poles.py | 3 +- ovf/core/sample.py | 269 +++++++------------------------- ovf/schemas/geometry/basic.py | 1 + ovf/schemas/geometry/poles.py | 1 + 7 files changed, 83 insertions(+), 224 deletions(-) diff --git a/ovf/core/GVFManager.py b/ovf/core/GVFManager.py index 238f585..14c69e6 100644 --- a/ovf/core/GVFManager.py +++ b/ovf/core/GVFManager.py @@ -17,6 +17,7 @@ class GVFDataUnit: geometries: Dict[str,float] vf_manager:VFManager enabled:bool=True + id: int|str|None = None class GVFManager: def __init__(self): @@ -63,11 +64,12 @@ class GVFManager: min_index = 0 max_index = len(full_freqences)-1 - for i in range(len(full_freqences)): - if min_freqs is not None and full_freqences[i] < min_freqs: - min_index = i + 1 - if max_freqs is not None and full_freqences[i] > max_freqs: - max_index = i - 1 + for j in range(len(full_freqences)): + if min_freqs is not None and full_freqences[j] < min_freqs: + min_index = j + 1 + for j in range(len(full_freqences)-1,-1,-1): + if max_freqs is not None and full_freqences[j] > max_freqs: + max_index = j - 1 if parameter_type == "s": sampled_points = network.s.reshape(-1,ports,ports)[min_index:max_index+1] @@ -106,7 +108,8 @@ class GVFManager: vf.fit() geometries = datas[i]["parameters"] - self.datasets.append(GVFDataUnit(geometries=geometries, vf_manager=vf)) + id = datas[i]["id"] + self.datasets.append(GVFDataUnit(geometries=geometries, vf_manager=vf,id=id)) def plot_poles(self,save_path,degree:int,geometry_1:str,geometry_2:str): unit = PolesPlot3dUnit( @@ -137,9 +140,14 @@ class GVFManager: plot_poles_in_3d(unit, save_path) - def plot_vf_responses_with_index(self,save_dir:str,index:int,freqrange:List[float]|np.ndarray|None=None): + def plot_vf_responses_with_index(self,save_dir:str,id:int,freqrange:List[float]|np.ndarray|None=None): + index = 0 + for i in range(len(self.datasets)): + if self.datasets[i].id == id: + index = i + break + ds = self.datasets[index] - id = f"{index}" os.makedirs(save_dir,exist_ok=True) vf = ds.vf_manager vf.plot_metrics(show=False,save_path=f"{save_dir}/{id}") @@ -152,7 +160,7 @@ class GVFManager: def plot_all_vf_responses(self,save_dir:str,freqrange:List[float]|np.ndarray|None=None): for index,ds in enumerate(self.datasets): - id = f"{index}" + id = ds.id os.makedirs(save_dir,exist_ok=True) vf = ds.vf_manager vf.plot_metrics(show=False,save_path=f"{save_dir}/{id}") @@ -170,7 +178,8 @@ class GVFManager: datas=[ PolesPlot2dDataUnit( poles=ds.vf_manager.poles, - geometries=ds.geometries + geometries=ds.geometries, + id=ds.id ) for ds in self.datasets ], x_label="Real Part", diff --git a/ovf/core/VFManager.py b/ovf/core/VFManager.py index f85b5d8..4cd7574 100644 --- a/ovf/core/VFManager.py +++ b/ovf/core/VFManager.py @@ -29,7 +29,6 @@ class VFManager(): full_H, full_freqs, max_points=max_points if max_points is not None else len(full_freqs), - dc_exclude=dc_enforce ) self.freqs=freqs diff --git a/ovf/core/geometry/plot2d.py b/ovf/core/geometry/plot2d.py index 51c23e3..77c8cca 100644 --- a/ovf/core/geometry/plot2d.py +++ b/ovf/core/geometry/plot2d.py @@ -16,6 +16,7 @@ def plot_2d(unit: GeometryPlot2dUnit, filename: str): y = [unit.datas[i].y[index] for i in range(len(unit.datas))] for pt in range(len(unit.datas)): + id_text = f"Id: {unit.datas[pt].id}
" if unit.datas[pt].id is not None else "" scatters_data.append( go.Scatter( x=[x[pt]], @@ -23,7 +24,7 @@ def plot_2d(unit: GeometryPlot2dUnit, filename: str): mode='markers', name=None, marker=dict(size=8, color=colors[index%len(colors)]), - hovertext=f"Geometries: {unit.datas[pt].geometries}
X: {x[pt]}
Y: {y[pt]}" + hovertext=id_text+f"Geometries: {unit.datas[pt].geometries}
X: {x[pt]}
Y: {y[pt]}" ) ) layout = go.Layout( diff --git a/ovf/core/geometry/plot_poles.py b/ovf/core/geometry/plot_poles.py index fb904c6..b4993c9 100644 --- a/ovf/core/geometry/plot_poles.py +++ b/ovf/core/geometry/plot_poles.py @@ -92,7 +92,8 @@ def plot_poles_in_2d(poles:PolesPlot2dUnit,filename:str): datas=[GeometryPlot2dComplexDataUnit( x=[np.real(p) for p in d.poles], y=[np.imag(p) for p in d.poles], - geometries=d.geometries + geometries=d.geometries, + id=d.id ) for d in poles.datas], x_label=poles.x_label, y_label=poles.y_label diff --git a/ovf/core/sample.py b/ovf/core/sample.py index 234daad..1fd888a 100644 --- a/ovf/core/sample.py +++ b/ovf/core/sample.py @@ -1,138 +1,19 @@ import numpy as np -def _ensure_index_set(obj): - """ - 把 obj 规整为 {int, int, ...} 的集合。 - 支持: 单个标量、list/tuple/set、numpy 数组、以及嵌套的这些类型。 - """ - out = set() - if obj is None: - return out - # 可迭代?尝试逐个展开 - try: - it = iter(obj) - except TypeError: - # 标量 - out.add(int(obj)) - return out - - for x in it: - if isinstance(x, (set, list, tuple, np.ndarray)): - out.update(_ensure_index_set(x)) - else: - out.add(int(x)) - return out - -def _dc_outlier_indices(H, freq, - # 自适应阈值参数 - tail_prob=1e-4, # 目标误报率 α;越小阈值越严格 - baseline_high_frac=0.5, # 用最高的这部分频点作为“正常”基线 - gap_tol=1 # 连续段内允许的非异常“间隙”点数 - ): - """ - 自适应 DC 异常剔除: - - 在高频“基线”上用 MAD 估计 log(|Re H|), log(|H|) 的位置与尺度,给出阈值; - - 从低频向高频寻找连续超阈值的前缀区间,作为 DC 异常,返回其索引。 - """ - H = np.asarray(H, np.complex128).reshape(-1) - f = np.asarray(freq, float).reshape(-1) - N = f.size - if N == 0: - return np.array([], dtype=int) - - # 频率按绝对值排序,便于兼容可能存在的负频或非单调 - ff = np.abs(f) - order = np.argsort(ff) - inv_order = np.empty_like(order) - inv_order[order] = np.arange(N) - - # 选择“高频基线”集合 - if not (0.0 < baseline_high_frac <= 1.0): - baseline_high_frac = 0.5 - q = np.quantile(ff, 1.0 - baseline_high_frac) - baseline_mask = ff >= q - if not np.any(baseline_mask): - # 极端情形退化为全体 - baseline_mask[:] = True - - # 在 log10 域做稳健阈值 - eps = 1e-30 - x_real = np.log10(np.abs(H.real) + eps) - x_mag = np.log10(np.abs(H) + eps) - - def _mad_threshold(x, mask, alpha): - xb = x[mask] - med = np.median(xb) - mad = np.median(np.abs(xb - med)) - scale = 1.4826 * max(mad, 1e-16) - # 需要 scipy.stats.norm.isf;如不可用则退化常数近似 - try: - from scipy.stats import norm - z = float(norm.isf(alpha)) - except Exception: - # 粗略近似:alpha=1e-4→3.72, 1e-5→4.27, 1e-6→4.75 - z = 3.72 if alpha >= 1e-4 else (4.27 if alpha >= 1e-5 else 4.75) - return med + z * scale - - t_real = _mad_threshold(x_real, baseline_mask, tail_prob) - t_mag = _mad_threshold(x_mag, baseline_mask, tail_prob) - - # 判定“异常”:任一统计量超阈值即记为异常 - is_outlier = (x_real > t_real) | (x_mag > t_mag) - - # 在低频排序下,寻找从头开始的“连续异常段”(允许 gap_tol 个非异常间隙) - mask_ord = is_outlier[order] - run_end = 0 - gaps = 0 - for k in range(N): - if mask_ord[k]: - run_end = k + 1 - else: - gaps += 1 - if gaps > gap_tol: - break - dc_run_idx = order[:run_end] - - # 仅当这段确实集中在低频端且规模非零时才返回 - return np.asarray(dc_run_idx, dtype=int) - - def _auto_select_indices(H, freq, n_baseline=64, peak_prominence=0.05, peak_window=5, topgrad_q=0.98, max_points=25, - ensure_ends=True, - # 新增:DC 极大点自动剔除参数 - dc_exclude=True, - dc_frac=0.02, - dc_real_factor=1e3, - dc_mag_factor=1e3, - # 额外手动排除(可选) - extra_exclude=None): + ensure_ends=True): """返回选中的全局索引,避免直接切片导致多端口不对齐。""" H = np.asarray(H).astype(np.complex128).reshape(-1) f = np.asarray(freq).astype(float).reshape(-1) if H.size != f.size: raise ValueError("H and freq must have the same length.") N = f.size - - # ---- 计算需要排除的索引(自动频率截断 + 手动) ---- - exclude = set() - if dc_exclude: - exclude.update(_dc_outlier_indices(H, f).tolist()) - if extra_exclude is not None: - exclude.update(_ensure_index_set(extra_exclude)) - - all_idx = set(range(N)) - allowed = sorted(all_idx - exclude) - if not allowed: - # 兜底:如果全被排除,至少保留最后一个点 - allowed = [N - 1] - - # 无需裁剪时,也返回“已剔除极大点”的全集 if N < 4 or max_points is None or max_points >= N: - return np.array(allowed, dtype=int) + return np.arange(N, dtype=int) eps = 1e-16 mag = np.abs(H) @@ -146,50 +27,39 @@ def _auto_select_indices(H, freq, idx = set() if ensure_ends: - for end in (0, N - 1): - if end in all_idx and end not in exclude: - idx.add(end) + idx.update([0, N - 1]) if n_baseline > 0: - grid = np.linspace(lf[min(allowed)], lf[max(allowed)], n_baseline) + grid = np.linspace(lf.min(), lf.max(), n_baseline) base_idx = np.clip(np.searchsorted(lf, grid), 0, N - 1) - # 只保留 allowed 中的 - idx.update([b for b in np.unique(base_idx) if b in allowed]) + idx.update(np.unique(base_idx).tolist()) - # 峰值(在 allowed 上扩窗) try: from scipy.signal import find_peaks - dyn = logmag[list(allowed)].max() - logmag[list(allowed)].min() + dyn = logmag.max() - logmag.min() prom = peak_prominence * (dyn + 1e-12) peaks, _ = find_peaks(logmag, prominence=prom) except Exception: peaks = np.where((mag[1:-1] > mag[:-2]) & (mag[1:-1] > mag[2:]))[0] + 1 - peaks = np.asarray([p for p in np.atleast_1d(peaks) if p in allowed], dtype=int) - for p in peaks: + for p in np.atleast_1d(peaks): lo = max(0, int(p) - peak_window) hi = min(N, int(p) + peak_window + 1) - idx.update([q for q in range(lo, hi) if q in allowed]) + idx.update(range(lo, hi)) - # 斜率/相位剧变(仅在 allowed 内) - abs_dlog = np.abs(d_logmag[list(allowed)]) - abs_dph = np.abs(d_phase[list(allowed)]) - thr_slope = np.quantile(abs_dlog, topgrad_q) if abs_dlog.size else np.inf - thr_phase = np.quantile(abs_dph, topgrad_q) if abs_dph.size else np.inf - - for q in allowed: - if (abs(d_logmag[q]) >= thr_slope) or (abs(d_phase[q]) >= thr_phase): - idx.add(q) + thr_slope = np.quantile(np.abs(d_logmag), topgrad_q) + thr_phase = np.quantile(np.abs(d_phase), topgrad_q) + idx.update(np.where(np.abs(d_logmag) >= thr_slope)[0].tolist()) + idx.update(np.where(np.abs(d_phase) >= thr_phase)[0].tolist()) sel = np.array(sorted(idx), dtype=int) - # 超额则按优先级裁剪 if sel.size > max_points: priority = np.zeros(sel.size, dtype=int) if ensure_ends: priority[(sel == 0) | (sel == N - 1)] = 3 - if peaks.size: - mask = np.isin(sel, peaks) + if np.size(peaks): + mask = np.isin(sel, np.atleast_1d(peaks)) priority[mask] = np.maximum(priority[mask], 2) keep = [] @@ -209,65 +79,60 @@ def _auto_select_indices(H, freq, break sel = np.array(sorted(set(keep)), dtype=int) - # 不足则在 allowed 中均匀补点 if sel.size < max_points: - remaining = sorted(set(allowed) - set(sel)) - if remaining: - need = max_points - sel.size - step = max(1, int(np.ceil(len(remaining) / need))) - sel = np.concatenate([sel, np.array(remaining[::step][:need], dtype=int)]) - sel = np.array(sorted(set(sel)), dtype=int) + all_idx = set(range(N)) + missing = list(sorted(all_idx - set(sel))) + n_missing = max_points - sel.size + if n_missing > 0 and missing: + extra = np.linspace(0, len(missing) - 1, n_missing, dtype=int) + sel = np.concatenate([sel, np.array(missing)[extra]]) + sel = np.array(sorted(set(sel)), dtype=int) + if sel.size < max_points: + left = list(sorted(all_idx - set(sel))) + if left: + add = min(max_points - sel.size, len(left)) + sel = np.concatenate([sel, np.random.choice(left, add, replace=False)]) + sel = np.array(sorted(set(sel)), dtype=int) sel = sel[:max_points] - # 再次确保都在 allowed - sel = np.array([i for i in sel if i in allowed], dtype=int) - if sel.size == 0: - sel = np.array([allowed[0]], dtype=int) return sel def auto_select(H, freq, - n_baseline=64, - peak_prominence=0.05, - peak_window=5, - topgrad_q=0.98, - max_points=25, - ensure_ends=True, - # 新增:自动 DC 剔除参数透传 - dc_exclude=True, - dc_frac=0.02, - dc_real_factor=1e3, - dc_mag_factor=1e3, - extra_exclude=None): + n_baseline=64, # log-spaced backbone points + peak_prominence=0.05, # fraction of |H| dB dynamic range for peak detection + peak_window=5, # take ±peak_window samples around each peak + topgrad_q=0.98, # keep top 2% largest slope/phase-change points + max_points=25, # final cap on selected samples (None = no cap) + ensure_ends=True): sel = _auto_select_indices(H, freq, n_baseline=n_baseline, peak_prominence=peak_prominence, peak_window=peak_window, topgrad_q=topgrad_q, max_points=max_points, - ensure_ends=ensure_ends, - dc_exclude=dc_exclude, - dc_frac=dc_frac, - dc_real_factor=dc_real_factor, - dc_mag_factor=dc_mag_factor, - extra_exclude=extra_exclude) + ensure_ends=ensure_ends) H = np.asarray(H).astype(np.complex128).reshape(-1) f = np.asarray(freq).astype(float).reshape(-1) return H[sel], f[sel] - def auto_select_multple_ports(H, freq, - n_baseline=64, - peak_prominence=0.05, - peak_window=5, - topgrad_q=0.98, - max_points=25, - ensure_ends=True, - # 新增:自动 DC 剔除参数 - dc_exclude=True, - extra_exclude=None): + n_baseline=64, # log-spaced backbone points + peak_prominence=0.05, # fraction of |H| dB dynamic range for peak detection + peak_window=5, # take ±peak_window samples around each peak + topgrad_q=0.98, # keep top 2% largest slope/phase-change points + max_points=25, # final cap on selected samples (None = no cap) + ensure_ends=True): """ - 多端口统一选点,加入“自动频率截断(DC 极大点剔除)”。 + 多端口统一选点:为每个(i,j)先各自产出候选索引,然后做并集并按出现次数优先级裁剪到同一套索引, + 确保返回的 H_selected 与 freq_selected 全端口严格对齐。 + 输入: + H: (N, P, P) 复数频响 + freq: (N,) 频率 + 返回: + H_selected: (K, P, P) + freq_selected: (K,) + 其中 K == max_points(或当样本不足时为 N)。 """ H = np.asarray(H) f = np.asarray(freq).astype(float).reshape(-1) @@ -279,27 +144,11 @@ def auto_select_multple_ports(H, freq, if P1 != P2: raise ValueError("H must be square on ports (P x P).") - # ---- 汇总所有端口的“应剔除”索引 ---- - global_exclude = set() - if dc_exclude: - for i in range(P1): - for j in range(P2): - excl_ij = _dc_outlier_indices(H[:, i, j], f) - global_exclude.update(excl_ij.tolist()) - if extra_exclude is not None: - global_exclude.update(np.asarray(extra_exclude, dtype=int).tolist()) - - allowed = sorted(set(range(N)) - global_exclude) - if not allowed: - # 兜底 - allowed = [N - 1] - - # 边界:无需裁剪时,也只返回 allowed(已剔除 DC 极大点) + # 边界:样本太少或不需裁剪,直接返回全量且对齐 if N < 4 or max_points is None or max_points >= N: - sel_final = np.array(allowed, dtype=int) - return H[sel_final], f[sel_final] + return H.copy(), f.copy() - # 每个(i,j)各自选索引(带 exclude) + # 每个(i,j)各自选索引 counts = {} all_sel_sets = [] for i in range(P1): @@ -310,28 +159,26 @@ def auto_select_multple_ports(H, freq, peak_window=peak_window, topgrad_q=topgrad_q, max_points=max_points, - ensure_ends=ensure_ends, - dc_exclude=False, # 这里已全局剔除过 - extra_exclude=global_exclude) - # 仅保留 allowed 的 - sel = np.array([k for k in sel if k in allowed], dtype=int) + ensure_ends=ensure_ends) all_sel_sets.append(sel) for idx in sel.tolist(): counts[idx] = counts.get(idx, 0) + 1 - # 并集 + 频次优先到 max_points(限定在 allowed) - union_idx = sorted(set(np.concatenate(all_sel_sets)) & set(allowed)) + # 并集 + 频次优先裁剪到 max_points + union_idx = sorted(set(np.concatenate(all_sel_sets)) ) + # 如果并集不超过预算,必要时补点至 max_points(均匀抽取未选样本) if len(union_idx) <= max_points: sel_final = union_idx if len(sel_final) < max_points: - remaining = sorted(set(allowed) - set(sel_final)) + remaining = sorted(set(range(N)) - set(sel_final)) if remaining: need = max_points - len(sel_final) step = max(1, int(np.ceil(len(remaining) / need))) sel_final.extend(remaining[::step][:need]) sel_final = sorted(set(sel_final))[:max_points] else: + # 过多则按出现次数从高到低选,出现次数相同按索引位置靠前优先 sorted_by_score = sorted(union_idx, key=lambda k: (-counts.get(k, 0), k)) sel_final = sorted(sorted_by_score[:max_points]) diff --git a/ovf/schemas/geometry/basic.py b/ovf/schemas/geometry/basic.py index 0396fce..561d73b 100644 --- a/ovf/schemas/geometry/basic.py +++ b/ovf/schemas/geometry/basic.py @@ -26,6 +26,7 @@ class GeometryPlot2dComplexDataUnit: x: List[float] y: List[float] geometries: Dict[str,float] + id: Union[int,str]|None = None @dataclass diff --git a/ovf/schemas/geometry/poles.py b/ovf/schemas/geometry/poles.py index f9bd430..367bc27 100644 --- a/ovf/schemas/geometry/poles.py +++ b/ovf/schemas/geometry/poles.py @@ -33,6 +33,7 @@ class PolesPlot3dUnit: class PolesPlot2dDataUnit: poles: List[Union[complex,np.complex128]] geometries: Dict[str,float] + id: Union[int,str]|None = None @dataclass class PolesPlot2dUnit: