import matplotlib.pyplot as plt import numpy as np import time from models import MLP import torch from torch import nn import mpldatacursor time_point = 1 low=0.82 high=0.83 model = MLP.MLP() model.load_state_dict(torch.load('lstm_model.pt')) # 创建初始的折线图数据 learning_rate = 0.01 # learning_rate=1 criterion = nn.MSELoss() # criterion=nn.L1Loss() optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) x = np.array([1]) y = np.array([model.fc0.weight.tolist()[0][0]]) # 绘制初始的折线图 fig, ax = plt.subplots() line, = ax.plot(x, y, '-o') # 显示初始的折线图 plt.show(block=False) # scatter = ax.scatter(x, y) # # 定义标注文本 # def formatter(**kwargs): # x, y = kwargs['x'], kwargs['y'] # return f"x={x:.2f}\ny={y:.2f}" # # 配置悬浮窗 # cursor = mpldatacursor.datacursor( # scatter, # formatter=formatter, # display='multiple', # draggable=True, # bbox=dict(fc='white', alpha=0.9), # arrowprops=dict(arrowstyle='->', connectionstyle='arc3'), # ) # 定义自定义事件类型 class AddPointEvent: def __init__(self, x, y): self.x = x self.y = y # 定义事件处理函数,用于处理自定义事件 def on_add_point(event): # 添加新的点 print(f"x:{event.x},y:{event.y}") new_x = np.array([event.x]) new_y = np.array([event.y]) # 使用set_xdata()和set_ydata()方法更新折线图的数据 line.set_xdata(np.append(line.get_xdata(), new_x)) line.set_ydata(np.append(line.get_ydata(), new_y)) # 更新折线图的坐标轴范围 ax.relim() ax.autoscale_view() ax.annotate("hi", xy=(event.x, event.y), xytext=(event.x, event.y+0.5), ha='center', va='bottom') # 更新折线图 fig.canvas.draw() # 连接自定义事件到事件处理函数上 cid = fig.canvas.mpl_connect(AddPointEvent, on_add_point) # 定义计时器回调函数,用于触发自定义事件 def timer_callback(): # 生成一个随机的点,并触发自定义事件 global time_point time_point=time_point+1 y=np.random.uniform(low=50, high=101) y_hat = y*np.random.uniform(low=low, high=high) batch_sequences = np.array([[y,y_hat]]) # print(batch_sequences.shape) inputs = torch.tensor(batch_sequences[:, :-1], dtype=torch.float32) targets = torch.tensor(batch_sequences[:,-1], dtype=torch.float32) # 前向传播 outputs = model(inputs) # print(f"targets {targets}") # print(f"input:{inputs}") # print(f"outputs {outputs}") outputs=outputs.reshape(targets.shape) # print(outputs.shape) # print(f"output:{outputs.shape}") # print(f"target:{targets.shape}") loss = criterion(outputs, targets) # print(f"loss:{loss}") # loss=loss*1000 # 反向传播和优化 loss.backward() optimizer.step() optimizer.zero_grad() new_y = model.fc0.weight.tolist()[0][0] new_event = AddPointEvent(time_point, new_y) fig.canvas.callbacks.process(AddPointEvent, new_event) # 重新启动计时器 timer = fig.canvas.new_timer(interval=1000) timer.add_callback(timer_callback) timer.start() # 启动计时器 timer = fig.canvas.new_timer(interval=1000) timer.add_callback(timer_callback) timer.start() def on_hover(event): # 检查是否有Artist在鼠标位置 if event.inaxes == ax: # 获取鼠标位置 x, y = event.xdata, event.ydata # 显示自定义信息 ax.annotate(f'({x:.2f}, {y:.2f})', (x, y)) # 将on_hover函数绑定到Figure的鼠标移动事件上 # fig.canvas.mpl_connect('motion_notify_event', on_hover) # 显示动态更新的折线图,并保持程序运行 plt.show()