import matplotlib.pyplot as plt
import numpy as np
from .basis.MultiPortOrthonormalBasis import MultiPortOrthonormalBasis
from .utils import generate_starting_poles


class VFManager:
    """Drive an iterative rational fit of sampled multiport responses.

    Workflow: ``levi()`` builds an initial model from generated starting
    poles, then ``sk_iteration()`` repeatedly rebuilds the model using the
    previous instance's ``next_poles`` as poles and its ``Dt`` as weights.
    ``fit()`` runs both steps. Diagnostics reported by each model instance
    are accumulated in the ``least_squares_*`` / ``eigenval_*`` lists and can
    be visualised with ``plot_metrics()``; model responses can be evaluated
    with ``get_model_responses()`` and compared to the data with
    ``plot_model_responses()``.

    NOTE(review): the names suggest Vector Fitting with a Levy-style initial
    step and Sanathanan-Koerner iterations; the numerics themselves live in
    ``model`` and are not visible here.
    """

    # (attribute name, plot label) for every per-iteration diagnostic. The
    # same attribute names are read from each model instance and stored as
    # history lists on the manager; order matches the plot_metrics() grid.
    _METRICS = (
        ("least_squares_condition", "Least Squares Condition"),
        ("least_squares_row_condition", "Least Squares Row Condition"),
        ("least_squares_col_condition", "Least Squares Col Condition"),
        ("least_squares_rms_error", "Least Squares RMS Error"),
        ("eigenval_condition", "Eigenvalue Condition"),
        ("eigenval_row_condition", "Eigenvalue Row Condition"),
        ("eigenval_col_condition", "Eigenvalue Col Condition"),
        ("eigenval_rms_error", "Eigenvalue RMS Error"),
    )

    def __init__(
        self,
        npoles_cplx,
        freqs,
        H,
        model=MultiPortOrthonormalBasis,
        iterations: int = 5,
        fit_constant: bool = True,
        fit_proportional: bool = False,
        dc_enforce: bool = False,
        passivity_enforce: bool = True,
        verbose: bool = True,
    ):
        """Store the data and fitting options; no fitting happens here.

        Args:
            npoles_cplx: Forwarded to ``generate_starting_poles`` in ``levi()``
                (presumably the number of complex poles — confirm).
            freqs: Sample frequencies of ``H``.
            H: Sampled responses, indexed ``H[k, i, j]`` with
                ``i, j < H.shape[1]`` (so presumably shaped
                ``(len(freqs), nports, nports)``).
            model: Class instantiated for every fitting step.
            iterations: Number of passes performed by ``sk_iteration()``.
            fit_constant, fit_proportional, dc_enforce, passivity_enforce:
                Forwarded unchanged to every ``model`` instantiation.
            verbose: Print per-iteration model state and save messages.
        """
        self.freqs = freqs
        self.H = H
        self.iterations = iterations
        self.fit_constant = fit_constant
        self.fit_proportional = fit_proportional
        self.dc_enforce = dc_enforce
        self.passivity_enforce = passivity_enforce
        self.verbose = verbose
        self.nports = H.shape[1]
        self.npoles_cplx = npoles_cplx
        self.model = model

        # Current poles / weights, filled in by levi() and sk_iteration().
        self.poles = None
        self.weights = None

        self._reset_metrics()

        self.model_instance = None
        self.model_responses_freqs = None
        self.model_responses_H = None

    def fit(self):
        """Run ``levi()`` then ``sk_iteration()``.

        Returns:
            The final fitted model instance (the same object stored in
            ``self.model_instance``).
        """
        self.levi()
        self.model_instance = self.sk_iteration()
        # Return the fitted instance, not the model class it was built from.
        return self.model_instance

    def levi(self, beta_min=1e4, beta_max=None):
        """Build the initial model from freshly generated starting poles.

        Any metric history from a previous fit is discarded, since a new fit
        starts here.

        Args:
            beta_min: Forwarded to ``generate_starting_poles`` (meaning is
                defined there; default preserves the previous hard-coded value).
            beta_max: Forwarded to ``generate_starting_poles``; defaults to
                1.1 times the largest sample frequency.

        Returns:
            The initial model instance (also stored in ``self.model_instance``).
        """
        if beta_max is None:
            # Use the largest frequency rather than freqs[-1] so that
            # unsorted frequency vectors are handled too.
            beta_max = np.max(self.freqs) * 1.1
        self._reset_metrics()
        self.poles = generate_starting_poles(
            self.npoles_cplx, beta_min=beta_min, beta_max=beta_max
        )
        # No weights are passed for the initial step, matching the original call.
        self.model_instance = self._build_model()
        return self.model_instance

    def sk_iteration(self):
        """Refine the model ``self.iterations`` times.

        Each pass rebuilds the model with the previous instance's
        ``next_poles`` as poles and its ``Dt`` as weights, then records the
        new instance's diagnostics.

        Returns:
            The last model instance (also stored in ``self.model_instance``).

        Raises:
            RuntimeError: if ``levi()`` has not been run yet.
        """
        # Explicit check (not assert): must survive `python -O` and must also
        # apply when iterations == 0.
        if self.model_instance is None:
            raise RuntimeError("Please run levi() first.")
        for i in range(self.iterations):
            self.poles = self.model_instance.next_poles
            self.weights = self.model_instance.Dt
            self.model_instance = self._build_model(weights=self.weights)
            if self.verbose:
                self._print_iteration(i)
            self._record_metrics()
        return self.model_instance

    def _build_model(self, **extra):
        """Instantiate ``self.model`` on the stored data, current poles and options."""
        return self.model(
            H=self.H,
            freqs=self.freqs,
            poles=self.poles,
            fit_constant=self.fit_constant,
            fit_proportional=self.fit_proportional,
            dc_enforce=self.dc_enforce,
            passivity_enforce=self.passivity_enforce,
            **extra,
        )

    def _print_iteration(self, i):
        """Print the state of the current model instance after iteration ``i``."""
        inst = self.model_instance
        print(f"Iteration {i+1}/{self.iterations}")
        print("A:", inst.A)
        print("B:", inst.B)
        print("C:", inst.C)
        print("D:", inst.D)
        print("next_poles:", inst.next_poles)
        print("Dt:", inst.Dt)
        print("Dt/Dt_1:", np.linalg.norm(inst.Dt_Dt_1))

    def _reset_metrics(self):
        """Start an empty history list for every tracked diagnostic."""
        for name, _label in self._METRICS:
            setattr(self, name, [])

    def _record_metrics(self):
        """Append the current model instance's diagnostics to their histories."""
        for name, _label in self._METRICS:
            getattr(self, name).append(getattr(self.model_instance, name))

    def plot_metrics(self, show: bool = True, save_path=None):
        """Plot the recorded diagnostic histories on a 4x2 grid.

        Args:
            show: Call ``plt.show()`` after drawing.
            save_path: Directory in which ``fitting_metrics.png`` is written;
                nothing is saved when ``None``.
        """
        fig = plt.figure(figsize=(16, 12))
        for position, (name, label) in enumerate(self._METRICS, start=1):
            plt.subplot(4, 2, position)
            plt.plot(getattr(self, name), label=label)
            plt.legend()
        # Save before showing: show() may close/clear the figure, which would
        # otherwise leave savefig writing a blank image.
        if save_path is not None:
            if self.verbose:
                print(f"Saving metrics plot to {save_path}/fitting_metrics.png")
            fig.savefig(f"{save_path}/fitting_metrics.png")
        if show:
            plt.show()

    def plot_model_responses(self, show: bool = True, save_path=None):
        """Plot data vs. fitted responses, one figure per port pair (i, j).

        Each figure shows magnitude, phase (degrees), real part and imaginary
        part of ``H[:, i, j]`` (markers) against the evaluated model (line).

        Args:
            show: Call ``plt.show()`` after each figure.
            save_path: Directory in which ``response_{i}_{j}.png`` files
                (1-based port indices) are written; nothing is saved when ``None``.

        Raises:
            RuntimeError: if ``get_model_responses()`` has not been run yet.
        """
        if self.model_responses_freqs is None or self.model_responses_H is None:
            raise RuntimeError("Please run get_model_responses() first.")
        # (transform applied to the complex samples, y-axis label) per panel.
        panels = (
            (np.abs, "Magnitude"),
            (lambda z: np.angle(z, deg=True), "Phase (deg)"),
            (np.real, "Real Part"),
            (np.imag, "Imag Part"),
        )
        for i in range(self.nports):
            for j in range(self.nports):
                fig = plt.figure(figsize=(12, 6))
                for position, (transform, ylabel) in enumerate(panels, start=1):
                    plt.subplot(2, 2, position)
                    plt.plot(self.freqs, transform(self.H[:, i, j]),
                             'o', ms=4, color='red', label='Input Samples')
                    plt.plot(self.model_responses_freqs,
                             transform(self.model_responses_H[:, i, j]),
                             '-', lw=2, color='k', label='Fit')
                    plt.title(f"Response i={i+1}, j={j+1}")
                    plt.ylabel(ylabel)
                    plt.legend(loc="best")
                fig.tight_layout()
                # Save before showing (see plot_metrics).
                if save_path is not None:
                    if self.verbose:
                        print(f"Saving response plot for port {i+1},{j+1} to {save_path}/response_{i+1}_{j+1}.png")
                    fig.savefig(f"{save_path}/response_{i+1}_{j+1}.png")
                if show:
                    plt.show()

    def get_model_responses(self, freqs):
        """Evaluate the current model at ``freqs`` and cache the result for plotting.

        Returns:
            Whatever ``self.model_instance.get_model_responses(freqs)`` returns
            (indexed ``[k, i, j]`` by ``plot_model_responses``).

        Raises:
            RuntimeError: if no model has been fitted yet.
        """
        if self.model_instance is None:
            raise RuntimeError("Please run levi() and sk_iteration() first.")
        self.model_responses_freqs = freqs
        self.model_responses_H = self.model_instance.get_model_responses(freqs)
        return self.model_responses_H