feat: 修改了自动选取采样点的函数
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -11,4 +11,5 @@ test/
|
||||
outputs/
|
||||
ovf.egg-info/
|
||||
dist/
|
||||
build/
|
||||
build/
|
||||
log/
|
||||
16731
logs/metrics.jsonl
Normal file
16731
logs/metrics.jsonl
Normal file
File diff suppressed because it is too large
Load Diff
@@ -5,7 +5,6 @@ import numpy as np
|
||||
import json
|
||||
import skrf as rf
|
||||
import os
|
||||
from ovf.core.sample import auto_select_multple_ports
|
||||
from ovf.core.basis.MultiPortOrthonormalBasis import MultiPortOrthonormalBasis
|
||||
from ovf.core.geometry.plot_poles import plot_poles_in_3d
|
||||
from ovf.schemas.geometry.poles import PolesPlot3dUnit,PolesPlot3dDataUnit
|
||||
@@ -24,6 +23,7 @@ class GVFManager:
|
||||
self.parameter_type: Literal["s","y","z"] = "s"
|
||||
self.datasets: List[GVFDataUnit] = []
|
||||
self.full_freqs: np.ndarray | None = None
|
||||
self.dc_enforce:bool = True
|
||||
|
||||
def save(self,filepath:str):
|
||||
os.makedirs(os.path.dirname(filepath),exist_ok=True)
|
||||
@@ -45,9 +45,11 @@ class GVFManager:
|
||||
basis:type=MultiPortOrthonormalBasis,
|
||||
parameter_type:Literal["s","y","z"]="s",
|
||||
min_freqs:float|None=None,
|
||||
max_freqs:float|None=None
|
||||
max_freqs:float|None=None,
|
||||
dc_enforce:bool=True
|
||||
):
|
||||
self.parameter_type = parameter_type
|
||||
self.dc_enforce = dc_enforce
|
||||
with open(jsonfile,"r") as f:
|
||||
datas = json.load(f)
|
||||
|
||||
@@ -76,11 +78,31 @@ class GVFManager:
|
||||
|
||||
self.full_freqs = full_freqences[min_index:max_index+1]
|
||||
|
||||
if max_points:
|
||||
H,freqs = auto_select_multple_ports(sampled_points,self.full_freqs,max_points=max_points)
|
||||
else:
|
||||
H,freqs = sampled_points,self.full_freqs
|
||||
vf = VFManager(npoles_cplx=npoles_cplx,freqs=freqs,H=H,model=basis,iterations=K,verbose=False)
|
||||
# if max_points:
|
||||
# H,freqs = auto_select_multple_ports(
|
||||
# sampled_points,
|
||||
# self.full_freqs,
|
||||
# max_points=max_points,
|
||||
# dc_exclude=self.dc_enforce
|
||||
# )
|
||||
# else:
|
||||
# # H,freqs = sampled_points,self.full_freqs
|
||||
# H,freqs = auto_select_multple_ports(
|
||||
# sampled_points,
|
||||
# self.full_freqs,
|
||||
# max_points=len(self.full_freqs),
|
||||
# dc_exclude=self.dc_enforce
|
||||
# )
|
||||
vf = VFManager(
|
||||
npoles_cplx=npoles_cplx,
|
||||
full_freqs=self.full_freqs,
|
||||
full_H=sampled_points,
|
||||
model=basis,
|
||||
iterations=K,
|
||||
verbose=False,
|
||||
dc_enforce=self.dc_enforce,
|
||||
max_points=max_points
|
||||
)
|
||||
vf.fit()
|
||||
|
||||
geometries = datas[i]["parameters"]
|
||||
|
||||
@@ -5,22 +5,32 @@ from ovf.core.basis.MultiPortOrthonormalBasis import MultiPortOrthonormalBasis
|
||||
from ovf.core.utils import generate_starting_poles
|
||||
import json
|
||||
import pickle
|
||||
from ovf.core.sample import auto_select_multple_ports
|
||||
|
||||
class VFManager():
|
||||
def __init__(
|
||||
self,
|
||||
npoles_cplx,
|
||||
freqs,
|
||||
H,
|
||||
full_freqs,
|
||||
full_H,
|
||||
model=MultiPortOrthonormalBasis,
|
||||
iterations:int=5,
|
||||
fit_constant:bool=True,
|
||||
fit_proportional:bool=False,
|
||||
dc_enforce:bool=False,
|
||||
passivity_enforce:bool=True,
|
||||
max_points:int|None=None,
|
||||
verbose:bool=True
|
||||
):
|
||||
|
||||
self.full_freqs = full_freqs
|
||||
self.full_H = full_H
|
||||
H,freqs = auto_select_multple_ports(
|
||||
full_H,
|
||||
full_freqs,
|
||||
max_points=max_points if max_points is not None else len(full_freqs),
|
||||
dc_exclude=dc_enforce
|
||||
)
|
||||
|
||||
self.freqs=freqs
|
||||
self.H=H
|
||||
@@ -152,8 +162,8 @@ class VFManager():
|
||||
def _load_npz(cls,filename):
|
||||
instance = cls(
|
||||
npoles_cplx=1, # 临时值,稍后会被覆盖
|
||||
freqs=np.array([]), # 临时值,稍后会被覆盖
|
||||
H=np.array([[]]), # 临时值,稍后会被覆盖
|
||||
full_freqs=np.array([]), # 临时值,稍后会被覆盖
|
||||
full_H=np.array([[]]), # 临时值,稍后会被覆盖
|
||||
model=MultiPortOrthonormalBasis, # 临时值,稍后会被覆盖
|
||||
iterations=1, # 临时值,稍后会被覆盖
|
||||
verbose=False # 临时值,稍后会被覆盖
|
||||
|
||||
@@ -1,19 +1,138 @@
|
||||
import numpy as np
|
||||
def _ensure_index_set(obj):
|
||||
"""
|
||||
把 obj 规整为 {int, int, ...} 的集合。
|
||||
支持: 单个标量、list/tuple/set、numpy 数组、以及嵌套的这些类型。
|
||||
"""
|
||||
out = set()
|
||||
if obj is None:
|
||||
return out
|
||||
# 可迭代?尝试逐个展开
|
||||
try:
|
||||
it = iter(obj)
|
||||
except TypeError:
|
||||
# 标量
|
||||
out.add(int(obj))
|
||||
return out
|
||||
|
||||
for x in it:
|
||||
if isinstance(x, (set, list, tuple, np.ndarray)):
|
||||
out.update(_ensure_index_set(x))
|
||||
else:
|
||||
out.add(int(x))
|
||||
return out
|
||||
|
||||
def _dc_outlier_indices(H, freq,
|
||||
# 自适应阈值参数
|
||||
tail_prob=1e-4, # 目标误报率 α;越小阈值越严格
|
||||
baseline_high_frac=0.5, # 用最高的这部分频点作为“正常”基线
|
||||
gap_tol=1 # 连续段内允许的非异常“间隙”点数
|
||||
):
|
||||
"""
|
||||
自适应 DC 异常剔除:
|
||||
- 在高频“基线”上用 MAD 估计 log(|Re H|), log(|H|) 的位置与尺度,给出阈值;
|
||||
- 从低频向高频寻找连续超阈值的前缀区间,作为 DC 异常,返回其索引。
|
||||
"""
|
||||
H = np.asarray(H, np.complex128).reshape(-1)
|
||||
f = np.asarray(freq, float).reshape(-1)
|
||||
N = f.size
|
||||
if N == 0:
|
||||
return np.array([], dtype=int)
|
||||
|
||||
# 频率按绝对值排序,便于兼容可能存在的负频或非单调
|
||||
ff = np.abs(f)
|
||||
order = np.argsort(ff)
|
||||
inv_order = np.empty_like(order)
|
||||
inv_order[order] = np.arange(N)
|
||||
|
||||
# 选择“高频基线”集合
|
||||
if not (0.0 < baseline_high_frac <= 1.0):
|
||||
baseline_high_frac = 0.5
|
||||
q = np.quantile(ff, 1.0 - baseline_high_frac)
|
||||
baseline_mask = ff >= q
|
||||
if not np.any(baseline_mask):
|
||||
# 极端情形退化为全体
|
||||
baseline_mask[:] = True
|
||||
|
||||
# 在 log10 域做稳健阈值
|
||||
eps = 1e-30
|
||||
x_real = np.log10(np.abs(H.real) + eps)
|
||||
x_mag = np.log10(np.abs(H) + eps)
|
||||
|
||||
def _mad_threshold(x, mask, alpha):
|
||||
xb = x[mask]
|
||||
med = np.median(xb)
|
||||
mad = np.median(np.abs(xb - med))
|
||||
scale = 1.4826 * max(mad, 1e-16)
|
||||
# 需要 scipy.stats.norm.isf;如不可用则退化常数近似
|
||||
try:
|
||||
from scipy.stats import norm
|
||||
z = float(norm.isf(alpha))
|
||||
except Exception:
|
||||
# 粗略近似:alpha=1e-4→3.72, 1e-5→4.27, 1e-6→4.75
|
||||
z = 3.72 if alpha >= 1e-4 else (4.27 if alpha >= 1e-5 else 4.75)
|
||||
return med + z * scale
|
||||
|
||||
t_real = _mad_threshold(x_real, baseline_mask, tail_prob)
|
||||
t_mag = _mad_threshold(x_mag, baseline_mask, tail_prob)
|
||||
|
||||
# 判定“异常”:任一统计量超阈值即记为异常
|
||||
is_outlier = (x_real > t_real) | (x_mag > t_mag)
|
||||
|
||||
# 在低频排序下,寻找从头开始的“连续异常段”(允许 gap_tol 个非异常间隙)
|
||||
mask_ord = is_outlier[order]
|
||||
run_end = 0
|
||||
gaps = 0
|
||||
for k in range(N):
|
||||
if mask_ord[k]:
|
||||
run_end = k + 1
|
||||
else:
|
||||
gaps += 1
|
||||
if gaps > gap_tol:
|
||||
break
|
||||
dc_run_idx = order[:run_end]
|
||||
|
||||
# 仅当这段确实集中在低频端且规模非零时才返回
|
||||
return np.asarray(dc_run_idx, dtype=int)
|
||||
|
||||
|
||||
def _auto_select_indices(H, freq,
|
||||
n_baseline=64,
|
||||
peak_prominence=0.05,
|
||||
peak_window=5,
|
||||
topgrad_q=0.98,
|
||||
max_points=25,
|
||||
ensure_ends=True):
|
||||
ensure_ends=True,
|
||||
# 新增:DC 极大点自动剔除参数
|
||||
dc_exclude=True,
|
||||
dc_frac=0.02,
|
||||
dc_real_factor=1e3,
|
||||
dc_mag_factor=1e3,
|
||||
# 额外手动排除(可选)
|
||||
extra_exclude=None):
|
||||
"""返回选中的全局索引,避免直接切片导致多端口不对齐。"""
|
||||
H = np.asarray(H).astype(np.complex128).reshape(-1)
|
||||
f = np.asarray(freq).astype(float).reshape(-1)
|
||||
if H.size != f.size:
|
||||
raise ValueError("H and freq must have the same length.")
|
||||
N = f.size
|
||||
|
||||
# ---- 计算需要排除的索引(自动频率截断 + 手动) ----
|
||||
exclude = set()
|
||||
if dc_exclude:
|
||||
exclude.update(_dc_outlier_indices(H, f).tolist())
|
||||
if extra_exclude is not None:
|
||||
exclude.update(_ensure_index_set(extra_exclude))
|
||||
|
||||
all_idx = set(range(N))
|
||||
allowed = sorted(all_idx - exclude)
|
||||
if not allowed:
|
||||
# 兜底:如果全被排除,至少保留最后一个点
|
||||
allowed = [N - 1]
|
||||
|
||||
# 无需裁剪时,也返回“已剔除极大点”的全集
|
||||
if N < 4 or max_points is None or max_points >= N:
|
||||
return np.arange(N, dtype=int)
|
||||
return np.array(allowed, dtype=int)
|
||||
|
||||
eps = 1e-16
|
||||
mag = np.abs(H)
|
||||
@@ -27,39 +146,50 @@ def _auto_select_indices(H, freq,
|
||||
|
||||
idx = set()
|
||||
if ensure_ends:
|
||||
idx.update([0, N - 1])
|
||||
for end in (0, N - 1):
|
||||
if end in all_idx and end not in exclude:
|
||||
idx.add(end)
|
||||
|
||||
if n_baseline > 0:
|
||||
grid = np.linspace(lf.min(), lf.max(), n_baseline)
|
||||
grid = np.linspace(lf[min(allowed)], lf[max(allowed)], n_baseline)
|
||||
base_idx = np.clip(np.searchsorted(lf, grid), 0, N - 1)
|
||||
idx.update(np.unique(base_idx).tolist())
|
||||
# 只保留 allowed 中的
|
||||
idx.update([b for b in np.unique(base_idx) if b in allowed])
|
||||
|
||||
# 峰值(在 allowed 上扩窗)
|
||||
try:
|
||||
from scipy.signal import find_peaks
|
||||
dyn = logmag.max() - logmag.min()
|
||||
dyn = logmag[list(allowed)].max() - logmag[list(allowed)].min()
|
||||
prom = peak_prominence * (dyn + 1e-12)
|
||||
peaks, _ = find_peaks(logmag, prominence=prom)
|
||||
except Exception:
|
||||
peaks = np.where((mag[1:-1] > mag[:-2]) & (mag[1:-1] > mag[2:]))[0] + 1
|
||||
|
||||
for p in np.atleast_1d(peaks):
|
||||
peaks = np.asarray([p for p in np.atleast_1d(peaks) if p in allowed], dtype=int)
|
||||
for p in peaks:
|
||||
lo = max(0, int(p) - peak_window)
|
||||
hi = min(N, int(p) + peak_window + 1)
|
||||
idx.update(range(lo, hi))
|
||||
idx.update([q for q in range(lo, hi) if q in allowed])
|
||||
|
||||
thr_slope = np.quantile(np.abs(d_logmag), topgrad_q)
|
||||
thr_phase = np.quantile(np.abs(d_phase), topgrad_q)
|
||||
idx.update(np.where(np.abs(d_logmag) >= thr_slope)[0].tolist())
|
||||
idx.update(np.where(np.abs(d_phase) >= thr_phase)[0].tolist())
|
||||
# 斜率/相位剧变(仅在 allowed 内)
|
||||
abs_dlog = np.abs(d_logmag[list(allowed)])
|
||||
abs_dph = np.abs(d_phase[list(allowed)])
|
||||
thr_slope = np.quantile(abs_dlog, topgrad_q) if abs_dlog.size else np.inf
|
||||
thr_phase = np.quantile(abs_dph, topgrad_q) if abs_dph.size else np.inf
|
||||
|
||||
for q in allowed:
|
||||
if (abs(d_logmag[q]) >= thr_slope) or (abs(d_phase[q]) >= thr_phase):
|
||||
idx.add(q)
|
||||
|
||||
sel = np.array(sorted(idx), dtype=int)
|
||||
|
||||
# 超额则按优先级裁剪
|
||||
if sel.size > max_points:
|
||||
priority = np.zeros(sel.size, dtype=int)
|
||||
if ensure_ends:
|
||||
priority[(sel == 0) | (sel == N - 1)] = 3
|
||||
if np.size(peaks):
|
||||
mask = np.isin(sel, np.atleast_1d(peaks))
|
||||
if peaks.size:
|
||||
mask = np.isin(sel, peaks)
|
||||
priority[mask] = np.maximum(priority[mask], 2)
|
||||
|
||||
keep = []
|
||||
@@ -79,60 +209,65 @@ def _auto_select_indices(H, freq,
|
||||
break
|
||||
sel = np.array(sorted(set(keep)), dtype=int)
|
||||
|
||||
# 不足则在 allowed 中均匀补点
|
||||
if sel.size < max_points:
|
||||
all_idx = set(range(N))
|
||||
missing = list(sorted(all_idx - set(sel)))
|
||||
n_missing = max_points - sel.size
|
||||
if n_missing > 0 and missing:
|
||||
extra = np.linspace(0, len(missing) - 1, n_missing, dtype=int)
|
||||
sel = np.concatenate([sel, np.array(missing)[extra]])
|
||||
sel = np.array(sorted(set(sel)), dtype=int)
|
||||
if sel.size < max_points:
|
||||
left = list(sorted(all_idx - set(sel)))
|
||||
if left:
|
||||
add = min(max_points - sel.size, len(left))
|
||||
sel = np.concatenate([sel, np.random.choice(left, add, replace=False)])
|
||||
sel = np.array(sorted(set(sel)), dtype=int)
|
||||
remaining = sorted(set(allowed) - set(sel))
|
||||
if remaining:
|
||||
need = max_points - sel.size
|
||||
step = max(1, int(np.ceil(len(remaining) / need)))
|
||||
sel = np.concatenate([sel, np.array(remaining[::step][:need], dtype=int)])
|
||||
sel = np.array(sorted(set(sel)), dtype=int)
|
||||
sel = sel[:max_points]
|
||||
|
||||
# 再次确保都在 allowed
|
||||
sel = np.array([i for i in sel if i in allowed], dtype=int)
|
||||
if sel.size == 0:
|
||||
sel = np.array([allowed[0]], dtype=int)
|
||||
return sel
|
||||
|
||||
|
||||
def auto_select(H, freq,
|
||||
n_baseline=64, # log-spaced backbone points
|
||||
peak_prominence=0.05, # fraction of |H| dB dynamic range for peak detection
|
||||
peak_window=5, # take ±peak_window samples around each peak
|
||||
topgrad_q=0.98, # keep top 2% largest slope/phase-change points
|
||||
max_points=25, # final cap on selected samples (None = no cap)
|
||||
ensure_ends=True):
|
||||
n_baseline=64,
|
||||
peak_prominence=0.05,
|
||||
peak_window=5,
|
||||
topgrad_q=0.98,
|
||||
max_points=25,
|
||||
ensure_ends=True,
|
||||
# 新增:自动 DC 剔除参数透传
|
||||
dc_exclude=True,
|
||||
dc_frac=0.02,
|
||||
dc_real_factor=1e3,
|
||||
dc_mag_factor=1e3,
|
||||
extra_exclude=None):
|
||||
sel = _auto_select_indices(H, freq,
|
||||
n_baseline=n_baseline,
|
||||
peak_prominence=peak_prominence,
|
||||
peak_window=peak_window,
|
||||
topgrad_q=topgrad_q,
|
||||
max_points=max_points,
|
||||
ensure_ends=ensure_ends)
|
||||
ensure_ends=ensure_ends,
|
||||
dc_exclude=dc_exclude,
|
||||
dc_frac=dc_frac,
|
||||
dc_real_factor=dc_real_factor,
|
||||
dc_mag_factor=dc_mag_factor,
|
||||
extra_exclude=extra_exclude)
|
||||
H = np.asarray(H).astype(np.complex128).reshape(-1)
|
||||
f = np.asarray(freq).astype(float).reshape(-1)
|
||||
return H[sel], f[sel]
|
||||
|
||||
|
||||
def auto_select_multple_ports(H, freq,
|
||||
n_baseline=64, # log-spaced backbone points
|
||||
peak_prominence=0.05, # fraction of |H| dB dynamic range for peak detection
|
||||
peak_window=5, # take ±peak_window samples around each peak
|
||||
topgrad_q=0.98, # keep top 2% largest slope/phase-change points
|
||||
max_points=25, # final cap on selected samples (None = no cap)
|
||||
ensure_ends=True):
|
||||
n_baseline=64,
|
||||
peak_prominence=0.05,
|
||||
peak_window=5,
|
||||
topgrad_q=0.98,
|
||||
max_points=25,
|
||||
ensure_ends=True,
|
||||
# 新增:自动 DC 剔除参数
|
||||
dc_exclude=True,
|
||||
extra_exclude=None):
|
||||
"""
|
||||
多端口统一选点:为每个(i,j)先各自产出候选索引,然后做并集并按出现次数优先级裁剪到同一套索引,
|
||||
确保返回的 H_selected 与 freq_selected 全端口严格对齐。
|
||||
输入:
|
||||
H: (N, P, P) 复数频响
|
||||
freq: (N,) 频率
|
||||
返回:
|
||||
H_selected: (K, P, P)
|
||||
freq_selected: (K,)
|
||||
其中 K == max_points(或当样本不足时为 N)。
|
||||
多端口统一选点,加入“自动频率截断(DC 极大点剔除)”。
|
||||
"""
|
||||
H = np.asarray(H)
|
||||
f = np.asarray(freq).astype(float).reshape(-1)
|
||||
@@ -144,11 +279,27 @@ def auto_select_multple_ports(H, freq,
|
||||
if P1 != P2:
|
||||
raise ValueError("H must be square on ports (P x P).")
|
||||
|
||||
# 边界:样本太少或不需裁剪,直接返回全量且对齐
|
||||
if N < 4 or max_points is None or max_points >= N:
|
||||
return H.copy(), f.copy()
|
||||
# ---- 汇总所有端口的“应剔除”索引 ----
|
||||
global_exclude = set()
|
||||
if dc_exclude:
|
||||
for i in range(P1):
|
||||
for j in range(P2):
|
||||
excl_ij = _dc_outlier_indices(H[:, i, j], f)
|
||||
global_exclude.update(excl_ij.tolist())
|
||||
if extra_exclude is not None:
|
||||
global_exclude.update(np.asarray(extra_exclude, dtype=int).tolist())
|
||||
|
||||
# 每个(i,j)各自选索引
|
||||
allowed = sorted(set(range(N)) - global_exclude)
|
||||
if not allowed:
|
||||
# 兜底
|
||||
allowed = [N - 1]
|
||||
|
||||
# 边界:无需裁剪时,也只返回 allowed(已剔除 DC 极大点)
|
||||
if N < 4 or max_points is None or max_points >= N:
|
||||
sel_final = np.array(allowed, dtype=int)
|
||||
return H[sel_final], f[sel_final]
|
||||
|
||||
# 每个(i,j)各自选索引(带 exclude)
|
||||
counts = {}
|
||||
all_sel_sets = []
|
||||
for i in range(P1):
|
||||
@@ -159,26 +310,28 @@ def auto_select_multple_ports(H, freq,
|
||||
peak_window=peak_window,
|
||||
topgrad_q=topgrad_q,
|
||||
max_points=max_points,
|
||||
ensure_ends=ensure_ends)
|
||||
ensure_ends=ensure_ends,
|
||||
dc_exclude=False, # 这里已全局剔除过
|
||||
extra_exclude=global_exclude)
|
||||
# 仅保留 allowed 的
|
||||
sel = np.array([k for k in sel if k in allowed], dtype=int)
|
||||
all_sel_sets.append(sel)
|
||||
for idx in sel.tolist():
|
||||
counts[idx] = counts.get(idx, 0) + 1
|
||||
|
||||
# 并集 + 频次优先裁剪到 max_points
|
||||
union_idx = sorted(set(np.concatenate(all_sel_sets)) )
|
||||
# 并集 + 频次优先到 max_points(限定在 allowed)
|
||||
union_idx = sorted(set(np.concatenate(all_sel_sets)) & set(allowed))
|
||||
|
||||
# 如果并集不超过预算,必要时补点至 max_points(均匀抽取未选样本)
|
||||
if len(union_idx) <= max_points:
|
||||
sel_final = union_idx
|
||||
if len(sel_final) < max_points:
|
||||
remaining = sorted(set(range(N)) - set(sel_final))
|
||||
remaining = sorted(set(allowed) - set(sel_final))
|
||||
if remaining:
|
||||
need = max_points - len(sel_final)
|
||||
step = max(1, int(np.ceil(len(remaining) / need)))
|
||||
sel_final.extend(remaining[::step][:need])
|
||||
sel_final = sorted(set(sel_final))[:max_points]
|
||||
else:
|
||||
# 过多则按出现次数从高到低选,出现次数相同按索引位置靠前优先
|
||||
sorted_by_score = sorted(union_idx, key=lambda k: (-counts.get(k, 0), k))
|
||||
sel_final = sorted(sorted_by_score[:max_points])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user