gwx
parent
77c72f7cd9
commit
55e1ac7653
@ -1 +1,3 @@
|
||||
{}
|
||||
{
|
||||
"attachmentFolderPath": "./"
|
||||
}
|
File diff suppressed because one or more lines are too long
@ -0,0 +1,142 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import time
|
||||
from models import MLP
|
||||
import torch
|
||||
from torch import nn
|
||||
import mpldatacursor
|
||||
|
||||
time_point = 1
|
||||
low=0.82
|
||||
high=0.83
|
||||
model = MLP.MLP()
|
||||
model.load_state_dict(torch.load('lstm_model.pt'))
|
||||
# 创建初始的折线图数据
|
||||
learning_rate = 0.01
|
||||
# learning_rate=1
|
||||
criterion = nn.MSELoss()
|
||||
# criterion=nn.L1Loss()
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
|
||||
|
||||
|
||||
x = np.array([1])
|
||||
y = np.array([model.fc0.weight.tolist()[0][0]])
|
||||
|
||||
# 绘制初始的折线图
|
||||
fig, ax = plt.subplots()
|
||||
line, = ax.plot(x, y, '-o')
|
||||
|
||||
|
||||
|
||||
# 显示初始的折线图
|
||||
plt.show(block=False)
|
||||
|
||||
# scatter = ax.scatter(x, y)
|
||||
# # 定义标注文本
|
||||
# def formatter(**kwargs):
|
||||
# x, y = kwargs['x'], kwargs['y']
|
||||
# return f"x={x:.2f}\ny={y:.2f}"
|
||||
# # 配置悬浮窗
|
||||
# cursor = mpldatacursor.datacursor(
|
||||
# scatter,
|
||||
# formatter=formatter,
|
||||
# display='multiple',
|
||||
# draggable=True,
|
||||
# bbox=dict(fc='white', alpha=0.9),
|
||||
# arrowprops=dict(arrowstyle='->', connectionstyle='arc3'),
|
||||
# )
|
||||
|
||||
# 定义自定义事件类型
|
||||
class AddPointEvent:
|
||||
def __init__(self, x, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
# 定义事件处理函数,用于处理自定义事件
|
||||
def on_add_point(event):
|
||||
# 添加新的点
|
||||
print(f"x:{event.x},y:{event.y}")
|
||||
new_x = np.array([event.x])
|
||||
new_y = np.array([event.y])
|
||||
|
||||
# 使用set_xdata()和set_ydata()方法更新折线图的数据
|
||||
line.set_xdata(np.append(line.get_xdata(), new_x))
|
||||
line.set_ydata(np.append(line.get_ydata(), new_y))
|
||||
|
||||
# 更新折线图的坐标轴范围
|
||||
ax.relim()
|
||||
ax.autoscale_view()
|
||||
ax.annotate("hi", xy=(event.x, event.y), xytext=(event.x, event.y+0.5), ha='center', va='bottom')
|
||||
|
||||
# 更新折线图
|
||||
fig.canvas.draw()
|
||||
|
||||
# 连接自定义事件到事件处理函数上
|
||||
cid = fig.canvas.mpl_connect(AddPointEvent, on_add_point)
|
||||
|
||||
# 定义计时器回调函数,用于触发自定义事件
|
||||
def timer_callback():
|
||||
# 生成一个随机的点,并触发自定义事件
|
||||
global time_point
|
||||
time_point=time_point+1
|
||||
y=np.random.uniform(low=50, high=101)
|
||||
y_hat = y*np.random.uniform(low=low, high=high)
|
||||
batch_sequences = np.array([[y,y_hat]])
|
||||
# print(batch_sequences.shape)
|
||||
inputs = torch.tensor(batch_sequences[:, :-1], dtype=torch.float32)
|
||||
|
||||
targets = torch.tensor(batch_sequences[:,-1], dtype=torch.float32)
|
||||
|
||||
|
||||
# 前向传播
|
||||
outputs = model(inputs)
|
||||
# print(f"targets {targets}")
|
||||
|
||||
# print(f"input:{inputs}")
|
||||
|
||||
# print(f"outputs {outputs}")
|
||||
|
||||
outputs=outputs.reshape(targets.shape)
|
||||
# print(outputs.shape)
|
||||
# print(f"output:{outputs.shape}")
|
||||
# print(f"target:{targets.shape}")
|
||||
loss = criterion(outputs, targets)
|
||||
# print(f"loss:{loss}")
|
||||
# loss=loss*1000
|
||||
# 反向传播和优化
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
|
||||
new_y = model.fc0.weight.tolist()[0][0]
|
||||
|
||||
|
||||
new_event = AddPointEvent(time_point, new_y)
|
||||
fig.canvas.callbacks.process(AddPointEvent, new_event)
|
||||
|
||||
# 重新启动计时器
|
||||
timer = fig.canvas.new_timer(interval=1000)
|
||||
timer.add_callback(timer_callback)
|
||||
timer.start()
|
||||
|
||||
# 启动计时器
|
||||
timer = fig.canvas.new_timer(interval=1000)
|
||||
timer.add_callback(timer_callback)
|
||||
timer.start()
|
||||
|
||||
|
||||
def on_hover(event):
|
||||
# 检查是否有Artist在鼠标位置
|
||||
if event.inaxes == ax:
|
||||
# 获取鼠标位置
|
||||
x, y = event.xdata, event.ydata
|
||||
# 显示自定义信息
|
||||
ax.annotate(f'({x:.2f}, {y:.2f})', (x, y))
|
||||
|
||||
# 将on_hover函数绑定到Figure的鼠标移动事件上
|
||||
# fig.canvas.mpl_connect('motion_notify_event', on_hover)
|
||||
|
||||
# 显示动态更新的折线图,并保持程序运行
|
||||
plt.show()
|
||||
|
Binary file not shown.
@ -0,0 +1,34 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from sklearn.preprocessing import MinMaxScaler
|
||||
|
||||
|
||||
input_size = 1
|
||||
hidden_size = 1
|
||||
num_layers = 2
|
||||
output_size = 1
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, input_size=1, hidden_size=1, num_layers=2, output_size=1):
|
||||
super(MLP, self).__init__()
|
||||
# self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
|
||||
# self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
|
||||
# self.lstm=nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
|
||||
self.fc0=nn.Linear(input_size,hidden_size,bias=False)
|
||||
self.relu = nn.ReLU()
|
||||
self.fc = nn.Linear(hidden_size, output_size)
|
||||
|
||||
nn.init.constant_(self.fc0.weight,0.78)
|
||||
nn.init.xavier_uniform_(self.fc.weight)
|
||||
nn.init.zeros_(self.fc.bias)
|
||||
|
||||
def forward(self, x):
|
||||
# out, _ = self.lstm(x)
|
||||
out0=self.fc0(x)
|
||||
# out1=self.relu(out0)
|
||||
# print(f"out{out[0]}")
|
||||
# print(f"out{out[-1]}")
|
||||
# out2 = self.fc(out1)
|
||||
return out0
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue