chore: 分离了sweep和vf部分,vf部分准备写为包
This commit is contained in:
176
core/VFManager.py
Normal file
176
core/VFManager.py
Normal file
@@ -0,0 +1,176 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from .basis.MultiPortOrthonormalBasis import MultiPortOrthonormalBasis
|
||||
from .utils import generate_starting_poles
|
||||
|
||||
class VFManager():
|
||||
def __init__(
|
||||
self,
|
||||
npoles_cplx,
|
||||
freqs,
|
||||
H,
|
||||
model=MultiPortOrthonormalBasis,
|
||||
iterations:int=5,
|
||||
fit_constant:bool=True,
|
||||
fit_proportional:bool=False,
|
||||
dc_enforce:bool=False,
|
||||
passivity_enforce:bool=True,
|
||||
verbose:bool=True
|
||||
):
|
||||
|
||||
|
||||
self.freqs=freqs
|
||||
self.H=H
|
||||
self.iterations=iterations
|
||||
self.fit_constant=fit_constant
|
||||
self.fit_proportional=fit_proportional
|
||||
self.dc_enforce=dc_enforce
|
||||
self.passivity_enforce=passivity_enforce
|
||||
self.verbose=verbose
|
||||
|
||||
self.nports = H.shape[1]
|
||||
self.npoles_cplx = npoles_cplx
|
||||
|
||||
self.least_squares_condition = []
|
||||
self.least_squares_row_condition = []
|
||||
self.least_squares_col_condition = []
|
||||
self.least_squares_rms_error = []
|
||||
self.eigenval_condition = []
|
||||
self.eigenval_row_condition = []
|
||||
self.eigenval_col_condition = []
|
||||
self.eigenval_rms_error = []
|
||||
|
||||
self.model_instance = None
|
||||
self.model_responses_freqs = None
|
||||
self.model_responses_H = None
|
||||
|
||||
self.model=model
|
||||
|
||||
def fit(self):
|
||||
self.levi()
|
||||
self.model_instance = self.sk_iteration()
|
||||
return self.model
|
||||
|
||||
def levi(self):
|
||||
self.poles = generate_starting_poles(self.npoles_cplx,beta_min=1e4,beta_max=self.freqs[-1]*1.1)
|
||||
self.model_instance=self.model(
|
||||
H=self.H,
|
||||
freqs=self.freqs,
|
||||
poles=self.poles,
|
||||
fit_constant=self.fit_constant,
|
||||
fit_proportional=self.fit_proportional,
|
||||
dc_enforce=self.dc_enforce,
|
||||
passivity_enforce=self.passivity_enforce
|
||||
)
|
||||
return self.model_instance
|
||||
|
||||
def sk_iteration(self):
|
||||
for i in range(self.iterations):
|
||||
assert self.model_instance is not None ,"Please run levi() first."
|
||||
self.poles = self.model_instance.next_poles
|
||||
self.weights = self.model_instance.Dt
|
||||
self.model_instance = self.model(
|
||||
H=self.H,
|
||||
freqs=self.freqs,
|
||||
poles=self.poles,
|
||||
weights=self.weights,
|
||||
fit_constant=self.fit_constant,
|
||||
fit_proportional=self.fit_proportional,
|
||||
dc_enforce=self.dc_enforce,
|
||||
passivity_enforce=self.passivity_enforce
|
||||
)
|
||||
if self.verbose:
|
||||
print(f"Iteration {i+1}/{self.iterations}")
|
||||
print("A:",self.model_instance.A)
|
||||
print("B:",self.model_instance.B)
|
||||
print("C:",self.model_instance.C)
|
||||
print("D:",self.model_instance.D)
|
||||
print("next_pozles:",self.model_instance.next_poles)
|
||||
print("Dt:",self.model_instance.Dt)
|
||||
print("Dt/Dt_1:",np.linalg.norm(self.model_instance.Dt_Dt_1))
|
||||
self.least_squares_condition.append(self.model_instance.least_squares_condition)
|
||||
self.least_squares_row_condition.append(self.model_instance.least_squares_row_condition)
|
||||
self.least_squares_col_condition.append(self.model_instance.least_squares_col_condition)
|
||||
self.least_squares_rms_error.append(self.model_instance.least_squares_rms_error)
|
||||
self.eigenval_condition.append(self.model_instance.eigenval_condition)
|
||||
self.eigenval_row_condition.append(self.model_instance.eigenval_row_condition)
|
||||
self.eigenval_col_condition.append(self.model_instance.eigenval_col_condition)
|
||||
self.eigenval_rms_error.append(self.model_instance.eigenval_rms_error)
|
||||
return self.model_instance
|
||||
|
||||
def plot_metrics(self,show:bool=True,save_path=None):
|
||||
plt.figure(figsize=(16, 12))
|
||||
plt.subplot(4, 2, 1)
|
||||
plt.plot(self.least_squares_condition, label='Least Squares Condition')
|
||||
plt.legend()
|
||||
plt.subplot(4, 2, 2)
|
||||
plt.plot(self.least_squares_row_condition, label='Least Squares Row Condition')
|
||||
plt.legend()
|
||||
plt.subplot(4, 2, 3)
|
||||
plt.plot(self.least_squares_col_condition, label='Least Squares Col Condition')
|
||||
plt.legend()
|
||||
plt.subplot(4, 2, 4)
|
||||
plt.plot(self.least_squares_rms_error, label='Least Squares RMS Error')
|
||||
plt.legend()
|
||||
plt.subplot(4, 2, 5)
|
||||
plt.plot(self.eigenval_condition, label='Eigenvalue Condition')
|
||||
plt.legend()
|
||||
plt.subplot(4, 2, 6)
|
||||
plt.plot(self.eigenval_row_condition, label='Eigenvalue Row Condition')
|
||||
plt.legend()
|
||||
plt.subplot(4, 2, 7)
|
||||
plt.plot(self.eigenval_col_condition, label='Eigenvalue Col Condition')
|
||||
plt.legend()
|
||||
plt.subplot(4, 2, 8)
|
||||
plt.plot(self.eigenval_rms_error, label='Eigenvalue RMS Error')
|
||||
plt.legend()
|
||||
if show:
|
||||
plt.show()
|
||||
if save_path is not None:
|
||||
if self.verbose:
|
||||
print(f"Saving metrics plot to {save_path}/fitting_metrics.png")
|
||||
plt.savefig(f"{save_path}/fitting_metrics.png")
|
||||
|
||||
def plot_model_responses(self,show:bool=True,save_path=None):
|
||||
assert self.model_responses_freqs is not None and self.model_responses_H is not None, "Please run get_model_responses() first."
|
||||
for i in range(self.nports):
|
||||
for j in range(self.nports):
|
||||
plt.figure(figsize=(12, 6))
|
||||
plt.subplot(2, 2, 1)
|
||||
plt.plot(self.freqs, np.abs(self.H[:,i,j]), 'o', ms=4, color='red', label='Input Samples')
|
||||
plt.plot(self.model_responses_freqs, np.abs(self.model_responses_H[:,i,j]), '-', lw=2, color='k', label='Fit')
|
||||
plt.title(f"Response i={i+1}, j={j+1}")
|
||||
plt.ylabel("Magnitude")
|
||||
plt.legend(loc="best")
|
||||
plt.subplot(2, 2, 2)
|
||||
plt.plot(self.freqs, np.angle(self.H[:,i,j],deg=True), 'o', ms=4, color='red', label='Input Samples')
|
||||
plt.plot(self.model_responses_freqs, np.angle(self.model_responses_H[:,i,j],deg=True), '-', lw=2, color='k', label='Fit')
|
||||
plt.title(f"Response i={i+1}, j={j+1}")
|
||||
plt.ylabel("Phase (deg)")
|
||||
plt.legend(loc="best")
|
||||
plt.tight_layout()
|
||||
plt.subplot(2, 2, 3)
|
||||
plt.plot(self.freqs, np.real(self.H[:,i,j]), 'o', ms=4, color='red', label='Input Samples')
|
||||
plt.plot(self.model_responses_freqs, np.real(self.model_responses_H[:,i,j]), '-', lw=2, color='k', label='Fit')
|
||||
plt.title(f"Response i={i+1}, j={j+1}")
|
||||
plt.ylabel("Real Part")
|
||||
plt.legend(loc="best")
|
||||
plt.subplot(2, 2, 4)
|
||||
plt.plot(self.freqs, np.imag(self.H[:,i,j]), 'o', ms=4, color='red', label='Input Samples')
|
||||
plt.plot(self.model_responses_freqs, np.imag(self.model_responses_H[:,i,j]), '-', lw=2, color='k', label='Fit')
|
||||
plt.title(f"Response i={i+1}, j={j+1}")
|
||||
plt.ylabel("Imag Part")
|
||||
plt.legend(loc="best")
|
||||
plt.tight_layout()
|
||||
if show:
|
||||
plt.show()
|
||||
if save_path is not None:
|
||||
if self.verbose:
|
||||
print(f"Saving response plot for port {i+1},{j+1} to {save_path}/response_{i+1}_{j+1}.png")
|
||||
plt.savefig(f"{save_path}/response_{i+1}_{j+1}.png")
|
||||
|
||||
def get_model_responses(self,freqs):
|
||||
assert self.model_instance is not None ,"Please run levi() and sk_iteration() first."
|
||||
self.model_responses_freqs = freqs
|
||||
self.model_responses_H = self.model_instance.get_model_responses(freqs)
|
||||
return self.model_responses_H
|
||||
Reference in New Issue
Block a user