You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

143 lines
3.7 KiB
Python

import matplotlib.pyplot as plt
import numpy as np
import time
from models import MLP
import torch
from torch import nn
import mpldatacursor
time_point = 1
low=0.82
high=0.83
model = MLP.MLP()
model.load_state_dict(torch.load('lstm_model.pt'))
# 创建初始的折线图数据
learning_rate = 0.01
# learning_rate=1
criterion = nn.MSELoss()
# criterion=nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
x = np.array([1])
y = np.array([model.fc0.weight.tolist()[0][0]])
# 绘制初始的折线图
fig, ax = plt.subplots()
line, = ax.plot(x, y, '-o')
# 显示初始的折线图
plt.show(block=False)
# scatter = ax.scatter(x, y)
# # 定义标注文本
# def formatter(**kwargs):
# x, y = kwargs['x'], kwargs['y']
# return f"x={x:.2f}\ny={y:.2f}"
# # 配置悬浮窗
# cursor = mpldatacursor.datacursor(
# scatter,
# formatter=formatter,
# display='multiple',
# draggable=True,
# bbox=dict(fc='white', alpha=0.9),
# arrowprops=dict(arrowstyle='->', connectionstyle='arc3'),
# )
# 定义自定义事件类型
class AddPointEvent:
def __init__(self, x, y):
self.x = x
self.y = y
# 定义事件处理函数,用于处理自定义事件
def on_add_point(event):
# 添加新的点
print(f"x:{event.x},y:{event.y}")
new_x = np.array([event.x])
new_y = np.array([event.y])
# 使用set_xdata()和set_ydata()方法更新折线图的数据
line.set_xdata(np.append(line.get_xdata(), new_x))
line.set_ydata(np.append(line.get_ydata(), new_y))
# 更新折线图的坐标轴范围
ax.relim()
ax.autoscale_view()
ax.annotate("hi", xy=(event.x, event.y), xytext=(event.x, event.y+0.5), ha='center', va='bottom')
# 更新折线图
fig.canvas.draw()
# 连接自定义事件到事件处理函数上
cid = fig.canvas.mpl_connect(AddPointEvent, on_add_point)
# 定义计时器回调函数,用于触发自定义事件
def timer_callback():
# 生成一个随机的点,并触发自定义事件
global time_point
time_point=time_point+1
y=np.random.uniform(low=50, high=101)
y_hat = y*np.random.uniform(low=low, high=high)
batch_sequences = np.array([[y,y_hat]])
# print(batch_sequences.shape)
inputs = torch.tensor(batch_sequences[:, :-1], dtype=torch.float32)
targets = torch.tensor(batch_sequences[:,-1], dtype=torch.float32)
# 前向传播
outputs = model(inputs)
# print(f"targets {targets}")
# print(f"input:{inputs}")
# print(f"outputs {outputs}")
outputs=outputs.reshape(targets.shape)
# print(outputs.shape)
# print(f"output:{outputs.shape}")
# print(f"target:{targets.shape}")
loss = criterion(outputs, targets)
# print(f"loss:{loss}")
# loss=loss*1000
# 反向传播和优化
loss.backward()
optimizer.step()
optimizer.zero_grad()
new_y = model.fc0.weight.tolist()[0][0]
new_event = AddPointEvent(time_point, new_y)
fig.canvas.callbacks.process(AddPointEvent, new_event)
# 重新启动计时器
timer = fig.canvas.new_timer(interval=1000)
timer.add_callback(timer_callback)
timer.start()
# 启动计时器
timer = fig.canvas.new_timer(interval=1000)
timer.add_callback(timer_callback)
timer.start()
def on_hover(event):
# 检查是否有Artist在鼠标位置
if event.inaxes == ax:
# 获取鼠标位置
x, y = event.xdata, event.ydata
# 显示自定义信息
ax.annotate(f'({x:.2f}, {y:.2f})', (x, y))
# 将on_hover函数绑定到Figure的鼠标移动事件上
# fig.canvas.mpl_connect('motion_notify_event', on_hover)
# 显示动态更新的折线图,并保持程序运行
plt.show()