167 lines
5.2 KiB
Python
167 lines
5.2 KiB
Python
# metrics_logger.py (with colorlog console)
|
||
import json
|
||
import logging
|
||
import os
|
||
import re
|
||
from logging.handlers import RotatingFileHandler
|
||
from datetime import datetime, timezone
|
||
|
||
# --- optional colorlog ---
|
||
try:
|
||
from colorlog import ColoredFormatter
|
||
HAVE_COLORLOG = True
|
||
except Exception:
|
||
HAVE_COLORLOG = False
|
||
|
||
|
||
# ---------- JSON 格式化器 ----------
|
||
class JSONMetricsFormatter(logging.Formatter):
    """Format a log record as a single line of JSON (JSONL).

    The metric fields are read from attributes attached to the record
    (typically via the ``extra=`` argument of the logging call); a missing
    attribute is emitted as ``null``.
    """

    def format(self, record: logging.LogRecord) -> str:
        """Return the JSON text for *record*.

        Always-present keys: ``ts``, ``level``, ``event``, ``dataset``
        (holding ``W`` and ``L``), ``rms``, ``rms_l1`` and ``rms_lmax``.
        ``step`` is added when it is not ``None``; ``tag`` and ``msg`` are
        added when they are non-empty.
        """
        payload = {
            # Use the record's own creation time rather than "now": the
            # timestamp must describe when the event was logged, and stay
            # identical no matter when or how often the record is formatted.
            "ts": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
            "level": record.levelname,
            "event": getattr(record, "event", "metrics"),
            "dataset": {"W": getattr(record, "W", None), "L": getattr(record, "L", None)},
            "rms": getattr(record, "rms", None),
            "rms_l1": getattr(record, "rms_l1", None),
            "rms_lmax": getattr(record, "rms_lmax", None),
        }

        # step == 0 is a legitimate value, so test against None, not truthiness.
        step = getattr(record, "step", None)
        if step is not None:
            payload["step"] = step

        tag = getattr(record, "tag", None)
        if tag:
            payload["tag"] = tag

        # getMessage() applies any %-style arguments carried by the record.
        if record.msg:
            payload["msg"] = record.getMessage()

        # ensure_ascii=False keeps non-ASCII text readable in the output file.
        return json.dumps(payload, ensure_ascii=False)
|
||
|
||
|
||
def _build_human_message(
|
||
W: float, L: float, rms: float, rms_l1: float, rms_lmax: float,
|
||
step: int | None = None, tag: str | None = None
|
||
) -> str:
|
||
parts = [
|
||
f"Dataset[W={W}, L={L}]",
|
||
f"RMS={rms:.12e}",
|
||
f"L1={rms_l1:.12e}",
|
||
f"Lmax={rms_lmax:.12e}",
|
||
]
|
||
if step is not None:
|
||
parts.append(f"step={step}")
|
||
if tag:
|
||
parts.append(f"tag={tag}")
|
||
return " | ".join(parts)
|
||
|
||
|
||
# ---------- 日志器工厂 ----------
|
||
def get_metrics_logger(
    name: str = "metrics",
    log_dir: str = "./logs",
    json_filename: str = "metrics.jsonl",
    level: int = logging.INFO,
    max_bytes: int = 10 * 1024 * 1024,
    backup_count: int = 3,
    use_colorlog: bool = True,
) -> logging.Logger:
    """Return a logger that writes to the console and to a JSONL file.

    - Console: a ``StreamHandler`` (default stream), coloured when colorlog
      was importable and ``use_colorlog`` is true, plain otherwise.
    - File: ``log_dir/json_filename``, one JSON object per line, rotated
      after ``max_bytes`` with ``backup_count`` backups kept.

    The logger's level and ``propagate = False`` are applied on every call.
    Handlers are only created when the logger has none yet, so calling this
    again with the same ``name`` does not add duplicate handlers.
    """
    metrics_logger = logging.getLogger(name)
    metrics_logger.setLevel(level)
    metrics_logger.propagate = False

    # Already configured (by an earlier call or by the caller): reuse as is.
    if metrics_logger.handlers:
        return metrics_logger

    os.makedirs(log_dir, exist_ok=True)

    # ---- console handler ----
    console = logging.StreamHandler()
    console.setLevel(level)
    if HAVE_COLORLOG and use_colorlog:
        console_formatter = ColoredFormatter(
            fmt="%(asctime)s | %(log_color)s%(levelname)-5s%(reset)s | %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
            log_colors={
                "DEBUG": "cyan",
                "INFO": "green",
                "WARNING": "yellow",
                "ERROR": "red",
                "CRITICAL": "bold_red",
            },
            secondary_log_colors={},  # available for colouring parts of the message
            style="%",
        )
    else:
        console_formatter = logging.Formatter(
            fmt="%(asctime)s | %(levelname)-5s | %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    console.setFormatter(console_formatter)
    metrics_logger.addHandler(console)

    # ---- rotating JSONL file handler ----
    jsonl_path = os.path.join(log_dir, json_filename)
    jsonl_file = RotatingFileHandler(
        filename=jsonl_path,
        maxBytes=max_bytes,
        backupCount=backup_count,
        encoding="utf-8",
    )
    jsonl_file.setLevel(level)
    jsonl_file.setFormatter(JSONMetricsFormatter())
    metrics_logger.addHandler(jsonl_file)

    return metrics_logger
|
||
|
||
|
||
# ---------- 便捷记录函数 ----------
|
||
def log_metrics(
|
||
logger: logging.Logger,
|
||
W: float,
|
||
L: float,
|
||
rms: float,
|
||
rms_l1: float,
|
||
rms_lmax: float,
|
||
*,
|
||
step: int | None = None,
|
||
tag: str | None = None,
|
||
msg: str | None = None,
|
||
) -> None:
|
||
"""
|
||
以结构化方式记录一次度量。
|
||
- 终端(可彩色)显示:人类可读的一行
|
||
- 文件:结构化 JSONL(携带全部字段)
|
||
"""
|
||
human_msg = msg or _build_human_message(W, L, rms, rms_l1, rms_lmax, step, tag)
|
||
extra = {
|
||
"event": "metrics",
|
||
"W": float(W),
|
||
"L": float(L),
|
||
"rms": float(rms),
|
||
"rms_l1": float(rms_l1),
|
||
"rms_lmax": float(rms_lmax),
|
||
"step": step,
|
||
"tag": tag,
|
||
}
|
||
logger.info(human_msg, extra=extra)
|
||
|
||
|
||
# ---------- 可选:解析旧日志行(你给的样例) ----------
|
||
LEGACY_LINE_RE = re.compile(
|
||
r"Dataset:\s*\{[^}]*['\"]?W['\"]?\s*:\s*([0-9.+\-eE]+)\s*,\s*['\"]?L['\"]?\s*:\s*([0-9.+\-eE]+)\s*\}\s*,\s*"
|
||
r"RMS Error:\s*([0-9.+\-eE]+)\s*,\s*RMS Error L1:\s*([0-9.+\-eE]+)\s*,\s*RMS Error Lmax:\s*([0-9.+\-eE]+)"
|
||
)
|
||
|
||
def parse_legacy_line(line: str):
|
||
m = LEGACY_LINE_RE.search(line)
|
||
if not m:
|
||
return None
|
||
W, L, rms, rms_l1, rms_lmax = map(float, m.groups())
|
||
return {"W": W, "L": L, "rms": rms, "rms_l1": rms_l1, "rms_lmax": rms_lmax}
|
||
|
||
def ingest_legacy_line(logger: logging.Logger, line: str, **kwargs):
    """Parse a legacy log line and re-log it through ``log_metrics``.

    Lines that ``parse_legacy_line`` cannot parse are skipped silently.
    Extra keyword arguments are forwarded to ``log_metrics`` unchanged.
    """
    fields = parse_legacy_line(line)
    if not fields:
        return
    log_metrics(logger, **fields, **kwargs)
|