176 lines
7.8 KiB
Python
176 lines
7.8 KiB
Python
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from .basis.MultiPortOrthonormalBasis import MultiPortOrthonormalBasis
|
|
from .utils import generate_starting_poles
|
|
|
|
class VFManager:
    """Drive an iterative rational fit of multi-port frequency-response data.

    Workflow:
        1. ``levi()`` builds an initial model from poles produced by
           ``generate_starting_poles``.
        2. ``sk_iteration()`` rebuilds the model ``iterations`` times, each time
           seeding it with the previous instance's ``next_poles`` and using its
           ``Dt`` as ``weights``.
        3. ``get_model_responses()`` evaluates the fitted model on a frequency
           grid, after which ``plot_model_responses()`` compares it with ``H``.

    ``fit()`` runs steps 1 and 2. Per-iteration diagnostics reported by each
    model instance are accumulated in list attributes (see ``_METRICS``) and
    can be plotted with ``plot_metrics()``.

    NOTE(review): the names (VF, levi, sk_iteration) suggest vector fitting
    with a Levy initial estimate and Sanathanan-Koerner iterations; this is
    inferred from naming only -- confirm against the model class.
    """

    # Diagnostic attribute names copied from each model instance after every
    # SK iteration, paired with the legend label used by plot_metrics().
    # The order is the row-major subplot order of the 4x2 metrics figure.
    _METRICS = (
        ("least_squares_condition", "Least Squares Condition"),
        ("least_squares_row_condition", "Least Squares Row Condition"),
        ("least_squares_col_condition", "Least Squares Col Condition"),
        ("least_squares_rms_error", "Least Squares RMS Error"),
        ("eigenval_condition", "Eigenvalue Condition"),
        ("eigenval_row_condition", "Eigenvalue Row Condition"),
        ("eigenval_col_condition", "Eigenvalue Col Condition"),
        ("eigenval_rms_error", "Eigenvalue RMS Error"),
    )

    def __init__(
        self,
        npoles_cplx,
        freqs,
        H,
        model=MultiPortOrthonormalBasis,
        iterations: int = 5,
        fit_constant: bool = True,
        fit_proportional: bool = False,
        dc_enforce: bool = False,
        passivity_enforce: bool = True,
        verbose: bool = True,
    ):
        """Store the data and fitting options; no computation happens here.

        Args:
            npoles_cplx: Pole count passed as the first argument to
                ``generate_starting_poles`` in ``levi()``.
            freqs: Frequency samples of ``H``. Plotted as the x-axis against
                ``H[:, i, j]``, so its length should match ``H.shape[0]``.
            H: Sampled responses, indexed as ``H[:, i, j]`` for port pairs
                ``i, j < H.shape[1]`` -- i.e. expected shape
                ``(len(freqs), nports, nports)``.
            model: Callable (typically a class) constructed with keyword
                arguments ``H, freqs, poles, [weights], fit_constant,
                fit_proportional, dc_enforce, passivity_enforce``. Its result
                must expose ``next_poles``, ``Dt``, ``A``, ``B``, ``C``, ``D``,
                ``Dt_Dt_1``, the attributes listed in ``_METRICS`` and a
                ``get_model_responses(freqs)`` method.
            iterations: Number of refits performed by ``sk_iteration()``.
            fit_constant, fit_proportional, dc_enforce, passivity_enforce:
                Forwarded unchanged to every ``model(...)`` construction.
            verbose: Print per-iteration model state and save-path messages.
        """
        self.freqs = freqs
        self.H = H
        self.iterations = iterations
        self.fit_constant = fit_constant
        self.fit_proportional = fit_proportional
        self.dc_enforce = dc_enforce
        self.passivity_enforce = passivity_enforce
        self.verbose = verbose

        self.nports = H.shape[1]
        self.npoles_cplx = npoles_cplx

        # Per-iteration diagnostic histories (one entry appended per SK
        # iteration). Names must match the first column of _METRICS.
        self.least_squares_condition = []
        self.least_squares_row_condition = []
        self.least_squares_col_condition = []
        self.least_squares_rms_error = []
        self.eigenval_condition = []
        self.eigenval_row_condition = []
        self.eigenval_col_condition = []
        self.eigenval_rms_error = []

        # Fitting state, populated by levi() / sk_iteration().
        self.poles = None
        self.weights = None
        self.model_instance = None

        # Populated by get_model_responses().
        self.model_responses_freqs = None
        self.model_responses_H = None

        self.model = model

    def fit(self):
        """Run ``levi()`` followed by ``sk_iteration()``.

        Returns:
            The final fitted model instance (also kept in
            ``self.model_instance``).
        """
        self.levi()
        self.model_instance = self.sk_iteration()
        # Return the fitted instance, not the model class stored in self.model.
        return self.model_instance

    def levi(self, beta_min: float = 1e4, beta_max=None):
        """Build the initial model from generated starting poles.

        Args:
            beta_min: ``beta_min`` argument for ``generate_starting_poles``.
            beta_max: ``beta_max`` argument for ``generate_starting_poles``.
                Defaults to 1.1 times the largest value in ``self.freqs``.

        Returns:
            The new model instance (also stored in ``self.model_instance``).
        """
        if beta_max is None:
            # 10% margin above the highest frequency sample; max() rather than
            # freqs[-1] so the bound does not depend on the sample ordering.
            beta_max = np.max(self.freqs) * 1.1
        self.poles = generate_starting_poles(
            self.npoles_cplx, beta_min=beta_min, beta_max=beta_max
        )
        self.model_instance = self.model(
            H=self.H,
            freqs=self.freqs,
            poles=self.poles,
            fit_constant=self.fit_constant,
            fit_proportional=self.fit_proportional,
            dc_enforce=self.dc_enforce,
            passivity_enforce=self.passivity_enforce,
        )
        return self.model_instance

    def sk_iteration(self):
        """Refit the model ``self.iterations`` times from the previous result.

        Each pass constructs a new model seeded with the previous instance's
        ``next_poles`` and weighted by its ``Dt``, then appends that
        instance's diagnostics to the history lists.

        Returns:
            The last model instance (also stored in ``self.model_instance``).

        Raises:
            RuntimeError: If ``levi()`` has not been run yet.
        """
        # Checked once up front; a real exception (not assert) so the guard
        # survives ``python -O``.
        if self.model_instance is None:
            raise RuntimeError("Please run levi() first.")

        for i in range(self.iterations):
            self.poles = self.model_instance.next_poles
            self.weights = self.model_instance.Dt
            self.model_instance = self.model(
                H=self.H,
                freqs=self.freqs,
                poles=self.poles,
                weights=self.weights,
                fit_constant=self.fit_constant,
                fit_proportional=self.fit_proportional,
                dc_enforce=self.dc_enforce,
                passivity_enforce=self.passivity_enforce,
            )
            if self.verbose:
                self._print_iteration(i)
            self._record_metrics()
        return self.model_instance

    def _print_iteration(self, i):
        """Print the current model's A/B/C/D, next poles and Dt diagnostics."""
        current = self.model_instance
        print(f"Iteration {i+1}/{self.iterations}")
        print("A:", current.A)
        print("B:", current.B)
        print("C:", current.C)
        print("D:", current.D)
        print("next_poles:", current.next_poles)
        print("Dt:", current.Dt)
        print("Dt/Dt_1:", np.linalg.norm(current.Dt_Dt_1))

    def _record_metrics(self):
        """Append the current model's diagnostics to the history lists."""
        for attr, _label in self._METRICS:
            getattr(self, attr).append(getattr(self.model_instance, attr))

    def _save_and_show(self, fig, show, save_path, filename, message):
        """Save *fig* to ``save_path/filename`` (if given), then optionally show it.

        Saving happens BEFORE ``plt.show()``: with interactive backends the
        figure is torn down when its window closes, so saving afterwards
        writes a blank image. A figure that was only saved (not shown) is
        closed so repeated calls do not accumulate open figures.
        """
        if save_path is not None:
            if self.verbose:
                print(message)
            fig.savefig(f"{save_path}/{filename}")
        if show:
            plt.show()
        elif save_path is not None:
            plt.close(fig)

    def plot_metrics(self, show: bool = True, save_path=None):
        """Plot the per-iteration diagnostics on a 4x2 grid.

        The x-axis is the 1-based iteration number, matching the verbose
        "Iteration k/N" output.

        Args:
            show: Call ``plt.show()`` after drawing.
            save_path: Directory to write ``fitting_metrics.png`` into, or
                ``None`` to skip saving.
        """
        fig, axes = plt.subplots(4, 2, figsize=(16, 12))
        for ax, (attr, label) in zip(axes.flat, self._METRICS):
            history = getattr(self, attr)
            ax.plot(range(1, len(history) + 1), history, label=label)
            ax.set_xlabel("Iteration")
            ax.legend()
        self._save_and_show(
            fig,
            show,
            save_path,
            "fitting_metrics.png",
            f"Saving metrics plot to {save_path}/fitting_metrics.png",
        )

    def plot_model_responses(self, show: bool = True, save_path=None):
        """Compare sampled ``H`` with the model response for every port pair.

        One 2x2 figure (magnitude, phase in degrees, real part, imaginary
        part) is produced per ``(i, j)`` port pair.

        Args:
            show: Call ``plt.show()`` after each figure is drawn.
            save_path: Directory to write ``response_{i}_{j}.png`` files into
                (1-based port indices), or ``None`` to skip saving.

        Raises:
            RuntimeError: If ``get_model_responses()`` has not been run yet.
        """
        if self.model_responses_freqs is None or self.model_responses_H is None:
            raise RuntimeError("Please run get_model_responses() first.")

        # (transform, y-axis label) for each panel, in row-major subplot order.
        panels = (
            (np.abs, "Magnitude"),
            (lambda z: np.angle(z, deg=True), "Phase (deg)"),
            (np.real, "Real Part"),
            (np.imag, "Imag Part"),
        )
        for i in range(self.nports):
            for j in range(self.nports):
                samples = self.H[:, i, j]
                fitted = self.model_responses_H[:, i, j]
                title = f"Response i={i+1}, j={j+1}"

                fig, axes = plt.subplots(2, 2, figsize=(12, 6))
                for ax, (transform, ylabel) in zip(axes.flat, panels):
                    ax.plot(self.freqs, transform(samples), 'o', ms=4, color='red', label='Input Samples')
                    ax.plot(self.model_responses_freqs, transform(fitted), '-', lw=2, color='k', label='Fit')
                    ax.set_title(title)
                    ax.set_ylabel(ylabel)
                    ax.legend(loc="best")
                # Lay out once, after all four panels exist.
                fig.tight_layout()

                self._save_and_show(
                    fig,
                    show,
                    save_path,
                    f"response_{i+1}_{j+1}.png",
                    f"Saving response plot for port {i+1},{j+1} to {save_path}/response_{i+1}_{j+1}.png",
                )

    def get_model_responses(self, freqs):
        """Evaluate the current model at *freqs* and cache the result for plotting.

        Args:
            freqs: Frequencies forwarded to the model's
                ``get_model_responses``.

        Returns:
            Whatever the model's ``get_model_responses(freqs)`` returns (also
            stored in ``self.model_responses_H``).

        Raises:
            RuntimeError: If no model has been built yet.
        """
        if self.model_instance is None:
            raise RuntimeError("Please run levi() and sk_iteration() first.")
        self.model_responses_freqs = freqs
        self.model_responses_H = self.model_instance.get_model_responses(freqs)
        return self.model_responses_H