You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

143 lines
3.7 KiB
Python

2 years ago
import matplotlib.pyplot as plt
import numpy as np
import time
from models import MLP
import torch
from torch import nn
import mpldatacursor
time_point = 1
low=0.82
high=0.83
model = MLP.MLP()
model.load_state_dict(torch.load('lstm_model.pt'))
# 创建初始的折线图数据
learning_rate = 0.01
# learning_rate=1
criterion = nn.MSELoss()
# criterion=nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
x = np.array([1])
y = np.array([model.fc0.weight.tolist()[0][0]])
# 绘制初始的折线图
fig, ax = plt.subplots()
line, = ax.plot(x, y, '-o')
# 显示初始的折线图
plt.show(block=False)
# scatter = ax.scatter(x, y)
# # 定义标注文本
# def formatter(**kwargs):
# x, y = kwargs['x'], kwargs['y']
# return f"x={x:.2f}\ny={y:.2f}"
# # 配置悬浮窗
# cursor = mpldatacursor.datacursor(
# scatter,
# formatter=formatter,
# display='multiple',
# draggable=True,
# bbox=dict(fc='white', alpha=0.9),
# arrowprops=dict(arrowstyle='->', connectionstyle='arc3'),
# )
# 定义自定义事件类型
class AddPointEvent:
def __init__(self, x, y):
self.x = x
self.y = y
# 定义事件处理函数,用于处理自定义事件
def on_add_point(event):
# 添加新的点
print(f"x:{event.x},y:{event.y}")
new_x = np.array([event.x])
new_y = np.array([event.y])
# 使用set_xdata()和set_ydata()方法更新折线图的数据
line.set_xdata(np.append(line.get_xdata(), new_x))
line.set_ydata(np.append(line.get_ydata(), new_y))
# 更新折线图的坐标轴范围
ax.relim()
ax.autoscale_view()
ax.annotate("hi", xy=(event.x, event.y), xytext=(event.x, event.y+0.5), ha='center', va='bottom')
# 更新折线图
fig.canvas.draw()
# 连接自定义事件到事件处理函数上
cid = fig.canvas.mpl_connect(AddPointEvent, on_add_point)
# 定义计时器回调函数,用于触发自定义事件
def timer_callback():
# 生成一个随机的点,并触发自定义事件
global time_point
time_point=time_point+1
y=np.random.uniform(low=50, high=101)
y_hat = y*np.random.uniform(low=low, high=high)
batch_sequences = np.array([[y,y_hat]])
# print(batch_sequences.shape)
inputs = torch.tensor(batch_sequences[:, :-1], dtype=torch.float32)
targets = torch.tensor(batch_sequences[:,-1], dtype=torch.float32)
# 前向传播
outputs = model(inputs)
# print(f"targets {targets}")
# print(f"input:{inputs}")
# print(f"outputs {outputs}")
outputs=outputs.reshape(targets.shape)
# print(outputs.shape)
# print(f"output:{outputs.shape}")
# print(f"target:{targets.shape}")
loss = criterion(outputs, targets)
# print(f"loss:{loss}")
# loss=loss*1000
# 反向传播和优化
loss.backward()
optimizer.step()
optimizer.zero_grad()
new_y = model.fc0.weight.tolist()[0][0]
new_event = AddPointEvent(time_point, new_y)
fig.canvas.callbacks.process(AddPointEvent, new_event)
# 重新启动计时器
timer = fig.canvas.new_timer(interval=1000)
timer.add_callback(timer_callback)
timer.start()
# 启动计时器
timer = fig.canvas.new_timer(interval=1000)
timer.add_callback(timer_callback)
timer.start()
def on_hover(event):
# 检查是否有Artist在鼠标位置
if event.inaxes == ax:
# 获取鼠标位置
x, y = event.xdata, event.ydata
# 显示自定义信息
ax.annotate(f'({x:.2f}, {y:.2f})', (x, y))
# 将on_hover函数绑定到Figure的鼠标移动事件上
# fig.canvas.mpl_connect('motion_notify_event', on_hover)
# 显示动态更新的折线图,并保持程序运行
plt.show()