keras建模的3种方式详解
目录
- keras建模的3种方式
- 第一种序列模型:
- 第二种函数模型
- 第三种子类模型
- 常用的损失函数:
keras建模的3种方式
keras是google公司2016年发布的tensorflow为后端的深度学习网络的高级接口。
三种建模方式:
- 序列模型
- 函数模型
- 子类模型
第一种序列模型:
import numpy as np from tensorflow.examples.tutorials.mnist import input_data from keras.models import Sequential from keras.models import load_model from keras.layers import Dense #加载数据 def read_data(path): mnist=input_data.read_data_sets(path,one_hot=True) train_x,train_y=mnist.train.images,mnist.train.labels, valid_x,valid_y=mnist.validation.images,mnist.validation.labels, test_x,test_y=mnist.test.images,mnist.test.labels return train_x,train_y,valid_x,valid_y,test_x,test_y #序列模型 def DNN(train_x,train_y,valid_x,valid_y): #創建模型 model=Sequential() model.add(Dense(64,input_dim=784,activation='relu')) model.add(Dense(128,activation='relu')) model.add(Dense(10,activation='softmax')) #查看网络模型 model.summary() #编译模型 model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy']) #训练模型 model.fit(train_x,train_y,BATch_size=500,nb_epoch=100,verbose=1,validation_data=(valid_x,valid_y)) #保存模型 model.save('sequential.h5') train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data') DNN编程客栈(train_x,train_y,valid_x,valid_y) model=load_model('sequential.h5') #下载模型 pre=model.predict(test_x) #测试验证 #计算验证集精度 a=np.argmax(pre,1) b=np.argmax(test_y,1) t=(a==b).astype(int) acc=np.sum(t)/len(a) print(acc)
第二种函数模型
import numpy as np from tensorflow.examples.tutorials.mnist import input_data from keras.models import Model from keras.models import编程 load_model from keras.layers import Input,Dense #加载数据 def read_data(path): mnist=input_data.read_data_sets(path,one_hot=True) train_x,train_y=mnist.train.images,mnist.train.labels, valid_x,valid_y=mnist.validation.images,mnist.validation.labels, test_x,test_y=mnist.test.images,mnist.test.labels return train_x,train_y,valid_x,valid_y,test_x,test_y #函数模型 def DNN(train_x,train_y,valid_x,valid_y): #创建模型 inputs=Input(shape=(784,)) x=Dense(64,activation='relu')(inputs) x=Dense(128,activation='relu')(x) output=Dense(10,activation='softmax')(x) model=Model(input=inputs,output=output) #查看网络结构 model.summary() #编译模型 model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy']) #训练模型 model.fit(train_x,train_y,batch_size=500,nb_epoch=100,verbose=1,validation_data=(valid_x,valid_y)) #保存模型 model.save('fun_model.h5') train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data') DNN(train_x,train_y,valid_x,valid_y) model=load_model('fun_model.h5') #下载模型 pre=model.predict(test_x) #验证数据js集 #验证数据集准确度 a=np.argmax(pre,1) b=np.argmax(test_y,1) t=(a==b).astype(int) acc=np.sum(t)/len(a) print(acc)
第三种子类模型
imphttp://www.devze.comort numpy as np from tensorflow.examples.tutorials.mnist import input_data from keras.models import Model from keras.layers import Dense #加载数据 def read_data(path): mnist=input_data.read_data_sets(path,one_hot=True) train_x,train_y=mnist.train.images,mnist.train.labels, valid_x,valid_y=mnist.validation.images,mnist.validation.labels, test_x,test_y=mnist.test.images,mnist.test.labels return train_x,train_y,valid_x,valid_y,test_x,test_y #子类模型 class DNN(Model): def __init__(self,train_x,train_y,valid_x,valid_y): super(DNN,self).__init__() #初始化网络模型 self.dense1=Dense(64,input_dim=784,activation='relu') self.dense2=Dense(128,activation='relu') self.dense3=Dense(10,activation='softmax') def call(self,inputs): #回调順序 x=self.dense1(inputs) x=self.dense2(x) x=self.dense3(x) return x train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data') model=DNN(train_x,train_y,valid_x,valid_y) #编译模型(学习率、损失函数、模型评估) model.compile(optimizer='adam(lr=0.001)',loss='categorical_crossentropy',metrics=['accuracy']) #训练模型 model.fit(train_x,train_y,batch_size=500,nb_epoch=100,verbose=1,validation_data=(valid_x,valid_y)) #查看网络结构 model.summary() pre=model.predict(test_x) #验证数据集编程客栈 #计算验证数据集的准确度 a=np.argmax(pre,1) b=np.argmax(test_y,1) t=(a==b).astype(int) acc=np.sum(t)/len(a) print(acc)
常用的损失函数:
mse #均方差(回归)
mae #绝对误差(回归)
binary_crossentropy #二值交叉熵(二分类,逻辑回归)
categorical_crossentropy #交叉熵(多分类)
到此这篇关于keras建模的3种方式详解的文章就介绍到这了,更多相关keras建模方式内容请搜索编程客栈(www.devze.com)以前的文章或继续浏览下面的相关文章希望大家以后多多支持编程客栈(www.devze.com)!
精彩评论