You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
143 lines
3.7 KiB
Python
143 lines
3.7 KiB
Python
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import time
|
|
from models import MLP
|
|
import torch
|
|
from torch import nn
|
|
import mpldatacursor
|
|
|
|
time_point = 1
|
|
low=0.82
|
|
high=0.83
|
|
model = MLP.MLP()
|
|
model.load_state_dict(torch.load('lstm_model.pt'))
|
|
# 创建初始的折线图数据
|
|
learning_rate = 0.01
|
|
# learning_rate=1
|
|
criterion = nn.MSELoss()
|
|
# criterion=nn.L1Loss()
|
|
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
|
|
|
|
|
|
x = np.array([1])
|
|
y = np.array([model.fc0.weight.tolist()[0][0]])
|
|
|
|
# 绘制初始的折线图
|
|
fig, ax = plt.subplots()
|
|
line, = ax.plot(x, y, '-o')
|
|
|
|
|
|
|
|
# 显示初始的折线图
|
|
plt.show(block=False)
|
|
|
|
# scatter = ax.scatter(x, y)
|
|
# # 定义标注文本
|
|
# def formatter(**kwargs):
|
|
# x, y = kwargs['x'], kwargs['y']
|
|
# return f"x={x:.2f}\ny={y:.2f}"
|
|
# # 配置悬浮窗
|
|
# cursor = mpldatacursor.datacursor(
|
|
# scatter,
|
|
# formatter=formatter,
|
|
# display='multiple',
|
|
# draggable=True,
|
|
# bbox=dict(fc='white', alpha=0.9),
|
|
# arrowprops=dict(arrowstyle='->', connectionstyle='arc3'),
|
|
# )
|
|
|
|
# 定义自定义事件类型
|
|
class AddPointEvent:
|
|
def __init__(self, x, y):
|
|
self.x = x
|
|
self.y = y
|
|
|
|
# 定义事件处理函数,用于处理自定义事件
|
|
def on_add_point(event):
|
|
# 添加新的点
|
|
print(f"x:{event.x},y:{event.y}")
|
|
new_x = np.array([event.x])
|
|
new_y = np.array([event.y])
|
|
|
|
# 使用set_xdata()和set_ydata()方法更新折线图的数据
|
|
line.set_xdata(np.append(line.get_xdata(), new_x))
|
|
line.set_ydata(np.append(line.get_ydata(), new_y))
|
|
|
|
# 更新折线图的坐标轴范围
|
|
ax.relim()
|
|
ax.autoscale_view()
|
|
ax.annotate("hi", xy=(event.x, event.y), xytext=(event.x, event.y+0.5), ha='center', va='bottom')
|
|
|
|
# 更新折线图
|
|
fig.canvas.draw()
|
|
|
|
# 连接自定义事件到事件处理函数上
|
|
cid = fig.canvas.mpl_connect(AddPointEvent, on_add_point)
|
|
|
|
# 定义计时器回调函数,用于触发自定义事件
|
|
def timer_callback():
|
|
# 生成一个随机的点,并触发自定义事件
|
|
global time_point
|
|
time_point=time_point+1
|
|
y=np.random.uniform(low=50, high=101)
|
|
y_hat = y*np.random.uniform(low=low, high=high)
|
|
batch_sequences = np.array([[y,y_hat]])
|
|
# print(batch_sequences.shape)
|
|
inputs = torch.tensor(batch_sequences[:, :-1], dtype=torch.float32)
|
|
|
|
targets = torch.tensor(batch_sequences[:,-1], dtype=torch.float32)
|
|
|
|
|
|
# 前向传播
|
|
outputs = model(inputs)
|
|
# print(f"targets {targets}")
|
|
|
|
# print(f"input:{inputs}")
|
|
|
|
# print(f"outputs {outputs}")
|
|
|
|
outputs=outputs.reshape(targets.shape)
|
|
# print(outputs.shape)
|
|
# print(f"output:{outputs.shape}")
|
|
# print(f"target:{targets.shape}")
|
|
loss = criterion(outputs, targets)
|
|
# print(f"loss:{loss}")
|
|
# loss=loss*1000
|
|
# 反向传播和优化
|
|
loss.backward()
|
|
optimizer.step()
|
|
optimizer.zero_grad()
|
|
|
|
|
|
new_y = model.fc0.weight.tolist()[0][0]
|
|
|
|
|
|
new_event = AddPointEvent(time_point, new_y)
|
|
fig.canvas.callbacks.process(AddPointEvent, new_event)
|
|
|
|
# 重新启动计时器
|
|
timer = fig.canvas.new_timer(interval=1000)
|
|
timer.add_callback(timer_callback)
|
|
timer.start()
|
|
|
|
# 启动计时器
|
|
timer = fig.canvas.new_timer(interval=1000)
|
|
timer.add_callback(timer_callback)
|
|
timer.start()
|
|
|
|
|
|
def on_hover(event):
|
|
# 检查是否有Artist在鼠标位置
|
|
if event.inaxes == ax:
|
|
# 获取鼠标位置
|
|
x, y = event.xdata, event.ydata
|
|
# 显示自定义信息
|
|
ax.annotate(f'({x:.2f}, {y:.2f})', (x, y))
|
|
|
|
# 将on_hover函数绑定到Figure的鼠标移动事件上
|
|
# fig.canvas.mpl_connect('motion_notify_event', on_hover)
|
|
|
|
# 显示动态更新的折线图,并保持程序运行
|
|
plt.show()
|
|
|