【Python深度学习系列】网格搜索神经网络超参数:权重初始化方法(案例+源码)

这是我的第262篇原创文章。

一、引言

图片

在深度学习中,超参数是指在训练模型时需要手动设置的参数,它们通常不能通过训练数据自动学习得到。超参数的选择对于模型的性能至关重要,因此在进行深度学习实验时,超参数调优通常是一个重要的步骤。常见的超参数包括:

  • model.add()

    • neurons(隐含层神经元数量)

    • init_mode(初始权重方法)

    • activation(激活函数)

    • dropout(丢弃率)

  • model.compile()

    • loss(损失函数)

    • optimizer(优化器)

      • learning rate(学习率)

      • momentum(动量)

      • weight decay(权重衰减系数)

  • model.fit()

    • batch size(批量大小)

    • epochs(迭代次数)

一般来说,可以通过手动调优、网格搜索(Grid Search)、随机搜索(Random Search)、自动调参算法方式进行超参数调优,本文采用网格搜索选择神经网络权重初始化方法。

二、实现过程

2.1 准备数据

dataset:

dataset = pd.read_csv("data.csv", header=None)
dataset = pd.DataFrame(dataset)
print(dataset)

图片

2.2 数据划分

# 切分数据为输入 X 和输出 Y
X = dataset.iloc[:,0:8]
Y = dataset.iloc[:,8]
# 为了复现,设置随机种子
seed = 7
np.random.seed(seed)
random.set_seed(seed)

2.3 创建模型

需要定义个网格的架构函数create_model,create_model里面的参数要在KerasClassifier这个对象里面存在而且参数名要一致。

def create_model(neurons_1):
    # 创建模型
    model = Sequential()
    model.add(Dense(neurons_1, input_shape=(8, ), kernel_initializer='uniform', activation='relu'))
    model.add(Dropout(0.2))
    model.add(Dense(1, kernel_initializer='uniform', activation='sigmoid'))

    # 编译模型
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model

model = KerasClassifier(model=create_model, epochs=100, batch_size=80, verbose=0, init_mode='uniform')
这里使用了scikeras库的KerasClassifier类来定义一个分类器,这里由于KerasClassifier没有定义初

始化权重的参数,需要自定义一个表示初始化权重的参数init_mode,并赋默认值为'uniform'。

2.4 定义网格搜索参数

param_grid = {'init_mode': ['uniform', 'lecun_uniform', 'normal', 'zero', 'glorot_normal','glorot_uniform', 'he_normal', 'he_uniform']}

param_grid是一个字典,key是超参数名称,这里的名称必须要在KerasClassifier这个对象里面存在而且参数名要一致。value是key可取的值,也就是要尝试的方案。

2.5 进行参数搜索

from sklearn.model_selection import GridSearchCV
grid = GridSearchCV(estimator=model,  param_grid=param_grid)
grid_result = grid.fit(X, Y)

使用sklearn里面的GridSearchCV类进行参数搜索,传入模型和网格参数。

2.6 总结搜索结果

print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
means = grid_result.cv_results_['mean_test_score']
stds = grid_result.cv_results_['std_test_score']
params = grid_result.cv_results_['params']
for mean, stdev, param in zip(means, stds, params):
    print("%f (%f) with: %r" % (mean, stdev, param))

结果:

图片

经过网格搜索,各层权重初始化最优的方法是normal。

作者简介:

读研期间发表6篇SCI数据挖掘相关论文,现在某研究院从事数据算法相关科研工作,结合自身科研实践经历不定期分享关于Python、机器学习、深度学习、人工智能系列基础知识与应用案例。致力于只做原创,以最简单的方式理解和学习,关注我一起交流成长。需要数据集和源码的小伙伴可以关注底部公众号添加作者微信。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/553692.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

2024全新快递平台系统独立版小程序源码|带cps推广营销流量主+前端

本文来自:2024全新快递平台系统独立版小程序源码|带cps推广营销流量主前端 - 源码1688​​​​​ 应用介绍 快递代发快递代寄寄件小程序可以对接易达云洋一级总代快递小程序,接入云洋/易达物流接口,支持选择快递公司,三通一达&am…

【leetcode面试经典150题】57. 环形链表(C++)

【leetcode面试经典150题】专栏系列将为准备暑期实习生以及秋招的同学们提高在面试时的经典面试算法题的思路和想法。本专栏将以一题多解和精简算法思路为主,题解使用C语言。(若有使用其他语言的同学也可了解题解思路,本质上语法内容一致&…

电动车违停智能监测摄像机

电动车的普及带来了便利,但也衍生了一些问题,其中最常见的之一就是电动车的违停。电动车的违停不仅会影响交通秩序,还可能对周围环境和行人安全造成影响。为了监测和管理电动车的违停情况,可以使用电动车违停智能监测摄像机。这种…

退市危机袭来,环保行业能否逆境崛起?|中联环保圈

近年来,环保行业风波持续不断,众多环保大公司风险频出。博天环境的退市危机令人感慨,深圳星源因涉嫌信息披露违法违规而被警告退市,更是引发业界震动。 最近三年,证监会办理的上市公司信息披露违法案件多达 397 件&…

Linux内核之virt_to_page实现与用法实例(五十)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 优质专栏:多媒…

Python 使用 pip 安装 matplotlib 模块(精华版)

pip 安装 matplotlib 模块 1.使用pip安装matplotlib(五步实现):2.使用下载的matplotlib画图: 1.使用pip安装matplotlib(五步实现): 长话短说:本人下载 matplotlib 花了大概三个半小时屡屡碰壁,险些暴走。为了不让新来的小伙伴走我的弯路,特意…

IPAguard--iOS代码混淆工具(免费)

IPAguard是一款为iOS开发者设计的代码混淆工具,旨在为开发者提供方便制作和分析马甲包的解决方案。通过高效的匹配算法,IPAguard可以在保证代码混淆的同时,保证编译后的代码质量,减少了因混淆引起的bug,使得开发者能够…

写后端项目的分页查询时,解决分页不更新

写基于VueSpringBoot项目,实现分页查询功能时,改完代码后,发现页数不更新: 更改处如下: 显示如图: 发现页数没有变化,两条数据还是显示在同一页,而且每页都10条。且重启项目也没有更…

代码随想录算法训练营第一天 | 704. 二分查找 | 27. 移除元素

704. 二分查找 int search(int* nums, int numsSize, int target) {int left 0, right numsSize, mid;while (left < right) {mid left (right -left) / 2;if (nums[mid] < target) {left mid 1;} else if (nums[mid] > target) {right mid;} else {return mid…

民兵档案管理系统-退伍军人档案管理全流程追踪

民兵档案管理系统&#xff08;智档案DW-S403&#xff09;是依托互3D技术、云计算、大数据、RFID技术、数据库技术、AI、视频分析技术对RFID智能仓库进行统一管理、分析的信息化、智能化、规范化的系统。 RFID档案管理系统是以先进的RFID技术为基础&#xff0c;结合数据库技术、…

压缩感知的概述梳理(1)

参考文献 An efficient visually meaningful image compression and encryption scheme based on compressive sensing and dynamic LSB embedding 基本内容 基本关系梳理 压缩感知核心元素 信号 x 长度&#xff1a;N动态稀疏或可用变换表示&#xff1a;x &#x1d74d;s …

AI实践与学习4_大模型之检索增强生成RAG实践

背景 针对AI解题业务场景&#xff0c;靠着ToT、CoT等提示词规则去引导模型的输出答案&#xff0c;一定程度相比Zero-shot解答质量更高&#xff08;正确率、格式&#xff09;等。但是针对某些测试CASE&#xff0c;LLM仍然不能输出期望的正确结果&#xff0c;将AI解题应用生产仍…

「不羁联盟/XDefiant」4月20号开启服务器测试,游戏预下载安装教程

XDefiant》开启Alpha测试&#xff0c;这是一款免费游玩的快节奏 FPS 竞技游戏&#xff0c;可选择特色阵营&#xff0c;搭配个性化的装备&#xff0c;体验 6v6 对抗或是线性游戏模式。高品质射击竞技端游XDefiant以6v6双边对抗为核心&#xff0c;对局模式分为区域与线性两大类&a…

LeetCode108:讲有序数组转换为平衡二叉搜索树

题目描述 给你一个整数数组 nums &#xff0c;其中元素已经按 升序 排列&#xff0c;请你将其转换为一棵 平衡二叉搜索树。 代码 class Solution { public:TreeNode* traversal(vector<int>& nums, int left, int right) {if (left > right) return nullptr;int …

单片机学习笔记——LED点阵

代码如下&#xff0c;注意管脚和扫描所用的hc595_write_data函数 #include "reg51.h"typedef unsigned int u16; //对系统默认数据类型进行重定义 typedef unsigned char u8;//定义74HC595控制管脚 sbit SRCLKP3^6; //移位寄存器时钟输入 sbit RCLKP3^5; //存储寄存…

Java | Leetcode Java题解之第36题有效的数独

题目&#xff1a; 题解&#xff1a; class Solution {public boolean isValidSudoku(char[][] board) {int[][] rows new int[9][9];int[][] columns new int[9][9];int[][][] subboxes new int[3][3][9];for (int i 0; i < 9; i) {for (int j 0; j < 9; j) {char …

内网代理技术总结

代理技术就是解决外网和内网的通信问题&#xff0c;例如&#xff0c;我的一个外网主机想要找到另外一个网段下的一个内网主机&#xff0c;理论上是无法找到的。如果我们想要进行通信的话就要使用代理技术。我们可以找到一个与目标内网主机在容易网段下可以通信的外网主机&#…

Android 12 如何加载 native 原生库

在 Android 7.0 及更高版本中&#xff0c;系统库与应用库是分开的。 图1. 原生库的命名空间 原生库的命名空间可防止应用使用私有平台的原生 API&#xff08;例如使用 OpenSSL&#xff09;。该命名空间还可以避免应用意外使用平台库&#xff08;而非它们自己的库&#xff09;的…

LangChain入门:22.使用 arXiv 工具开发科研助理

有一些工具&#xff0c;比如 SerpAPI&#xff0c;你已经用过了&#xff0c;这里我们再来用一下 arXiv 工具。arXiv 本身就是一个论文研究的利器&#xff0c;里面的论文数量比 AI 顶会还早、还多、还全。那么把它以工具的形式集成到 LangChain 中&#xff0c;能让你在研究学术最…

高精度PWM脉宽调制信号转模拟信号隔离变送器1Hz-10KHz转0-5V/0-10V/1-5V,0-10mA/0-20mA/4-20mA

主要特性: >>精度等级&#xff1a;0.1级。产品出厂前已检验校正&#xff0c;用户可以直接使用 >>辅助电源&#xff1a;8-32V 宽范围供电 >>PWM脉宽调制信号输入: 1Hz~10KHz >>输出标准信号&#xff1a;0-5V/0-10V/1-5V,0-10mA/0-20mA/4-20mA等&…
最新文章