gwx
parent
77c72f7cd9
commit
55e1ac7653
@ -1 +1,3 @@
|
|||||||
{}
|
{
|
||||||
|
"attachmentFolderPath": "./"
|
||||||
|
}
|
File diff suppressed because one or more lines are too long
@ -0,0 +1,142 @@
|
|||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import time
|
||||||
|
from models import MLP
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import mpldatacursor
|
||||||
|
|
||||||
|
time_point = 1
|
||||||
|
low=0.82
|
||||||
|
high=0.83
|
||||||
|
model = MLP.MLP()
|
||||||
|
model.load_state_dict(torch.load('lstm_model.pt'))
|
||||||
|
# 创建初始的折线图数据
|
||||||
|
learning_rate = 0.01
|
||||||
|
# learning_rate=1
|
||||||
|
criterion = nn.MSELoss()
|
||||||
|
# criterion=nn.L1Loss()
|
||||||
|
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
|
||||||
|
|
||||||
|
|
||||||
|
x = np.array([1])
|
||||||
|
y = np.array([model.fc0.weight.tolist()[0][0]])
|
||||||
|
|
||||||
|
# 绘制初始的折线图
|
||||||
|
fig, ax = plt.subplots()
|
||||||
|
line, = ax.plot(x, y, '-o')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# 显示初始的折线图
|
||||||
|
plt.show(block=False)
|
||||||
|
|
||||||
|
# scatter = ax.scatter(x, y)
|
||||||
|
# # 定义标注文本
|
||||||
|
# def formatter(**kwargs):
|
||||||
|
# x, y = kwargs['x'], kwargs['y']
|
||||||
|
# return f"x={x:.2f}\ny={y:.2f}"
|
||||||
|
# # 配置悬浮窗
|
||||||
|
# cursor = mpldatacursor.datacursor(
|
||||||
|
# scatter,
|
||||||
|
# formatter=formatter,
|
||||||
|
# display='multiple',
|
||||||
|
# draggable=True,
|
||||||
|
# bbox=dict(fc='white', alpha=0.9),
|
||||||
|
# arrowprops=dict(arrowstyle='->', connectionstyle='arc3'),
|
||||||
|
# )
|
||||||
|
|
||||||
|
# 定义自定义事件类型
|
||||||
|
class AddPointEvent:
|
||||||
|
def __init__(self, x, y):
|
||||||
|
self.x = x
|
||||||
|
self.y = y
|
||||||
|
|
||||||
|
# 定义事件处理函数,用于处理自定义事件
|
||||||
|
def on_add_point(event):
|
||||||
|
# 添加新的点
|
||||||
|
print(f"x:{event.x},y:{event.y}")
|
||||||
|
new_x = np.array([event.x])
|
||||||
|
new_y = np.array([event.y])
|
||||||
|
|
||||||
|
# 使用set_xdata()和set_ydata()方法更新折线图的数据
|
||||||
|
line.set_xdata(np.append(line.get_xdata(), new_x))
|
||||||
|
line.set_ydata(np.append(line.get_ydata(), new_y))
|
||||||
|
|
||||||
|
# 更新折线图的坐标轴范围
|
||||||
|
ax.relim()
|
||||||
|
ax.autoscale_view()
|
||||||
|
ax.annotate("hi", xy=(event.x, event.y), xytext=(event.x, event.y+0.5), ha='center', va='bottom')
|
||||||
|
|
||||||
|
# 更新折线图
|
||||||
|
fig.canvas.draw()
|
||||||
|
|
||||||
|
# 连接自定义事件到事件处理函数上
|
||||||
|
cid = fig.canvas.mpl_connect(AddPointEvent, on_add_point)
|
||||||
|
|
||||||
|
# 定义计时器回调函数,用于触发自定义事件
|
||||||
|
def timer_callback():
|
||||||
|
# 生成一个随机的点,并触发自定义事件
|
||||||
|
global time_point
|
||||||
|
time_point=time_point+1
|
||||||
|
y=np.random.uniform(low=50, high=101)
|
||||||
|
y_hat = y*np.random.uniform(low=low, high=high)
|
||||||
|
batch_sequences = np.array([[y,y_hat]])
|
||||||
|
# print(batch_sequences.shape)
|
||||||
|
inputs = torch.tensor(batch_sequences[:, :-1], dtype=torch.float32)
|
||||||
|
|
||||||
|
targets = torch.tensor(batch_sequences[:,-1], dtype=torch.float32)
|
||||||
|
|
||||||
|
|
||||||
|
# 前向传播
|
||||||
|
outputs = model(inputs)
|
||||||
|
# print(f"targets {targets}")
|
||||||
|
|
||||||
|
# print(f"input:{inputs}")
|
||||||
|
|
||||||
|
# print(f"outputs {outputs}")
|
||||||
|
|
||||||
|
outputs=outputs.reshape(targets.shape)
|
||||||
|
# print(outputs.shape)
|
||||||
|
# print(f"output:{outputs.shape}")
|
||||||
|
# print(f"target:{targets.shape}")
|
||||||
|
loss = criterion(outputs, targets)
|
||||||
|
# print(f"loss:{loss}")
|
||||||
|
# loss=loss*1000
|
||||||
|
# 反向传播和优化
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
|
||||||
|
new_y = model.fc0.weight.tolist()[0][0]
|
||||||
|
|
||||||
|
|
||||||
|
new_event = AddPointEvent(time_point, new_y)
|
||||||
|
fig.canvas.callbacks.process(AddPointEvent, new_event)
|
||||||
|
|
||||||
|
# 重新启动计时器
|
||||||
|
timer = fig.canvas.new_timer(interval=1000)
|
||||||
|
timer.add_callback(timer_callback)
|
||||||
|
timer.start()
|
||||||
|
|
||||||
|
# 启动计时器
|
||||||
|
timer = fig.canvas.new_timer(interval=1000)
|
||||||
|
timer.add_callback(timer_callback)
|
||||||
|
timer.start()
|
||||||
|
|
||||||
|
|
||||||
|
def on_hover(event):
|
||||||
|
# 检查是否有Artist在鼠标位置
|
||||||
|
if event.inaxes == ax:
|
||||||
|
# 获取鼠标位置
|
||||||
|
x, y = event.xdata, event.ydata
|
||||||
|
# 显示自定义信息
|
||||||
|
ax.annotate(f'({x:.2f}, {y:.2f})', (x, y))
|
||||||
|
|
||||||
|
# 将on_hover函数绑定到Figure的鼠标移动事件上
|
||||||
|
# fig.canvas.mpl_connect('motion_notify_event', on_hover)
|
||||||
|
|
||||||
|
# 显示动态更新的折线图,并保持程序运行
|
||||||
|
plt.show()
|
||||||
|
|
Binary file not shown.
@ -0,0 +1,34 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
from sklearn.preprocessing import MinMaxScaler
|
||||||
|
|
||||||
|
|
||||||
|
input_size = 1
|
||||||
|
hidden_size = 1
|
||||||
|
num_layers = 2
|
||||||
|
output_size = 1
|
||||||
|
|
||||||
|
class MLP(nn.Module):
|
||||||
|
def __init__(self, input_size=1, hidden_size=1, num_layers=2, output_size=1):
|
||||||
|
super(MLP, self).__init__()
|
||||||
|
# self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
|
||||||
|
# self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
|
||||||
|
# self.lstm=nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
|
||||||
|
self.fc0=nn.Linear(input_size,hidden_size,bias=False)
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
self.fc = nn.Linear(hidden_size, output_size)
|
||||||
|
|
||||||
|
nn.init.constant_(self.fc0.weight,0.78)
|
||||||
|
nn.init.xavier_uniform_(self.fc.weight)
|
||||||
|
nn.init.zeros_(self.fc.bias)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# out, _ = self.lstm(x)
|
||||||
|
out0=self.fc0(x)
|
||||||
|
# out1=self.relu(out0)
|
||||||
|
# print(f"out{out[0]}")
|
||||||
|
# print(f"out{out[-1]}")
|
||||||
|
# out2 = self.fc(out1)
|
||||||
|
return out0
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue