init: 验证robust算法
This commit is contained in:
149
models/basic.py
Normal file
149
models/basic.py
Normal file
@@ -0,0 +1,149 @@
|
||||
from abc import abstractmethod
|
||||
from skrf import Network
|
||||
from schemas.paramer import SimulationRequestUnit, UuidResponseUnit, SimulationResponseUnit
|
||||
from typing import List, Literal, Union, Dict
|
||||
import requests
|
||||
import time
|
||||
from utils import send_get_request
|
||||
from pydantic import BaseModel, Field
|
||||
import itertools
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
class ModelBasicParametersUnit(BaseModel):
|
||||
name: str
|
||||
type: Literal["number","integer","string","boolean"]
|
||||
range: List[Union[float,int,str,bool]]
|
||||
|
||||
class ModelBasicInfo(BaseModel):
|
||||
base_url = "http://localhost:8105/api/v1"
|
||||
nports: int = Field(default=2)
|
||||
cell_name: str
|
||||
template_name: str
|
||||
user_id: int = Field(default=0)
|
||||
template_version: str = Field(default="")
|
||||
|
||||
class ModelBasicDatasetUnit(BaseModel):
|
||||
nports: int = Field(default=2)
|
||||
parameters: Dict[str, Union[float, int, str, bool]]
|
||||
id: int = Field(default=0)
|
||||
result_dir: str = Field(default="")
|
||||
|
||||
@property
|
||||
def network(self) -> Network:
|
||||
try:
|
||||
network = Network(f"{self.result_dir}/{self.id}.s{self.nports}p")
|
||||
return network
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error loading network from {self.result_dir}: {e}")
|
||||
|
||||
@property
|
||||
def s_params(self) -> np.ndarray:
|
||||
return self.network.s
|
||||
|
||||
@property
|
||||
def y_params(self) -> np.ndarray:
|
||||
return self.network.y
|
||||
|
||||
@property
|
||||
def z_params(self) -> np.ndarray:
|
||||
return self.network.z
|
||||
|
||||
@property
|
||||
def freqs(self) -> np.ndarray:
|
||||
return self.network.f
|
||||
|
||||
|
||||
class ModelBasic:
|
||||
def __init__(self):
|
||||
self._dataset:List[ModelBasicDatasetUnit] = []
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def info(self)->ModelBasicInfo:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def parameters(self)->List[ModelBasicParametersUnit]:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def settings(self)->dict:
|
||||
pass
|
||||
|
||||
def sweep(self):
|
||||
parameters_list = []
|
||||
lst = [res.range for res in self.parameters]
|
||||
parameters_name = [res.name for res in self.parameters]
|
||||
result = [list(item) for item in itertools.product(*lst)]
|
||||
for res in result:
|
||||
parameters_list.append({parameters_name[i]:res[i] for i in range(len(res))})
|
||||
|
||||
for res in parameters_list:
|
||||
request_unit = SimulationRequestUnit(
|
||||
user_id=self.info.user_id,
|
||||
template_name=self.info.template_name,
|
||||
template_version=self.info.template_version,
|
||||
cell_name=self.info.cell_name,
|
||||
parameters=res,
|
||||
settings=self.settings
|
||||
)
|
||||
response = self.simulate(request_unit)
|
||||
|
||||
self._dataset.append(ModelBasicDatasetUnit(nports=self.info.nports,parameters=res, id=response.id, result_dir=response.result_path))
|
||||
|
||||
@property
|
||||
def results(self)->List[ModelBasicDatasetUnit]:
|
||||
return self._dataset
|
||||
|
||||
def export(self, path:str|None):
|
||||
if path is None:
|
||||
path = f"{self.info.cell_name}_dataset.json"
|
||||
with open(path,"w") as f:
|
||||
json.dump([dict(item) for item in self.results],f,indent=4)
|
||||
|
||||
def load(self, path:str|None):
|
||||
if path is None:
|
||||
path = f"{self.info.cell_name}_dataset.json"
|
||||
with open(path,"r") as f:
|
||||
data = json.load(f)
|
||||
self._dataset += [ModelBasicDatasetUnit(**item) for item in data]
|
||||
|
||||
def clear(self):
|
||||
self._dataset = []
|
||||
|
||||
def simulate(self,simulation_request:SimulationRequestUnit)->UuidResponseUnit:
|
||||
def send_simulate_request(url, data:list)->list[dict]:
|
||||
response = requests.post(url, json = data)
|
||||
if response.status_code not in [200,201,202]:
|
||||
raise RuntimeError(f"send_simulate_request: {response.status_code}, {response.text}")
|
||||
return response.json()
|
||||
|
||||
response = send_simulate_request(f"{self.info.base_url}/simulations/create", [dict(simulation_request)])
|
||||
|
||||
simulation_response_model:SimulationResponseUnit = SimulationResponseUnit(**(response[0]))
|
||||
|
||||
time.sleep(0.01)
|
||||
|
||||
response = send_get_request(f"{self.info.base_url}/simulations/input_hash/{simulation_response_model.input_hash}")
|
||||
assert isinstance(response, dict), "Response is not a dictionary."
|
||||
uuid_model = UuidResponseUnit(**response)
|
||||
|
||||
|
||||
time.sleep(0.01) # Wait for 2 seconds before checking the status again
|
||||
|
||||
status = uuid_model.status
|
||||
while status != "completed" and status != "failed":
|
||||
time.sleep(0.01) # Wait for 2 seconds before checking again
|
||||
response = send_get_request(f"{self.info.base_url}/simulations/input_hash/{simulation_response_model.input_hash}")
|
||||
assert isinstance(response, dict), "Response is not a dictionary."
|
||||
uuid_model = UuidResponseUnit(**response)
|
||||
assert response is not None, "No response received from the server."
|
||||
status = uuid_model.status
|
||||
|
||||
if status == "failed":
|
||||
raise RuntimeError(f"Simulation failed: {uuid_model.error_message}")
|
||||
else:
|
||||
return uuid_model
|
||||
Reference in New Issue
Block a user