Meta-Learning数学原理

news/2024/10/11 17:07:17/文章来源:https://blog.csdn.net/qq_60735796/article/details/142342201

文章目录

  • 什么是元学习
  • 元学习的目标
  • 元学习的类型
  • 数学推导
    • 1. 传统机器学习的数学表述
    • 2. 元学习的基本思想
    • 3. MAML 算法推导
      • 3.1 元任务设置
      • 3.2 内层优化:任务级别学习
      • 3.3 外层优化:元级别学习
      • 3.4 元梯度计算
      • 3.5 最终更新规则
    • 4. 算法合并
    • 5. 理解 MAML 的优化
  • 图例
  • MAML 的优势
  • 其他元学习方法
  • 总结
  • 手写笔记

🍃作者介绍:双非本科大四网络工程专业在读,阿里云专家博主,专注于Java领域学习,擅长web应用开发,目前开始人工智能领域相关知识的学习
🦅个人主页:@逐梦苍穹
📕所属专栏:人工智能
🌻gitee地址:xzl的人工智能代码仓库
✈ 您的一键三连,是我创作的最大动力🌹

之前介绍过元学习的内容:https://xzl-tech.blog.csdn.net/article/details/142025393
这篇文章讲一下Meta-Learning的数学原理。

什么是元学习

元学习(Meta-Learning),也称为“学习如何学习”,是一种机器学习方法,其目的是通过学习算法的经验和结构特性,提升算法在新任务上的学习效率。

换句话说,元学习试图学习一种更有效的学习方法,使得模型能够快速适应新的任务或环境。


传统的机器学习算法通常需要大量的数据来训练模型,并且当数据分布发生变化或者遇到一个新任务时,模型往往需要重新训练才能保持良好的性能。

而元学习则不同,它通过 从多个相关任务中学习,从而在面对新任务时更快速地进行学习。

元学习的核心思想是利用“学习的经验”来提高学习的速度和质量。

在元学习的框架中,有两个层次的学习过程:

  1. 元学习者(Meta-Learner): 负责从多个任务中提取经验和知识,用于更新学习策略或模型参数。
  2. 基础学习者(Base Learner): 在每个具体任务上执行实际的学习过程。

元学习的目标

元学习的目标是解决以下问题:

  • 快速适应: 当模型面临新任务时,能够基于已有的经验快速适应,而无需大量的数据和计算资源。
  • 跨任务泛化: 提高模型从多个任务中学习到的知识在新任务上的泛化能力。
  • 提高数据效率: 减少模型在新任务上所需的数据量,尤其是在数据稀缺或高昂的情况下。

元学习的类型

元学习可以按照不同的方式分类,以下是三种主要类型:

  1. 基于模型的元学习(Model-Based Meta-Learning):
    • 这种方法通过直接设计一种能够快速适应新任务的模型架构,通常是通过某种特殊的神经网络结构来实现的。例如,基于记忆的神经网络(如 LSTM 或 Memory-Augmented Neural Networks)被设计成能有效地记住过去的任务信息,并在新任务上进行快速调整。
    • 例子: MANN(Memory-Augmented Neural Networks),SNAIL(Simple Neural Attentive Meta-Learner)。
  2. 基于优化的元学习(Optimization-Based Meta-Learning):
    • 这种方法的核心是通过改进优化过程本身来实现快速学习。其代表算法是 MAML(Model-Agnostic Meta-Learning),它通过在所有任务上共享一个初始模型参数,使得初始模型在每个任务上进行少量梯度下降更新后能够快速适应新任务。
    • 例子: MAML(Model-Agnostic Meta-Learning),Reptile。
  3. 基于记忆的元学习(Memory-Based Meta-Learning):
    • 这类方法直接存储并检索训练过程中的经验数据。当遇到新任务时,通过查找与之相似的旧任务,并利用这些旧任务的数据和经验来快速学习。k-NN(k-近邻)方法是最基本的例子,而更复杂的方法可能使用深度记忆网络。
    • 例子: Meta Networks,Prototypical Networks。

数学推导

1. 传统机器学习的数学表述

在传统的机器学习中,我们通常试图找到一个函数 f θ f_\theta fθ来最小化给定数据集 D D D的损失函数:
θ ∗ = arg ⁡ min ⁡ θ L ( f θ , D ) \theta^* = \arg\min_{\theta} L(f_\theta, D) θ=argminθL(fθ,D)
其中:

  • θ \theta θ是模型的参数。
  • L ( f θ , D ) L(f_\theta, D) L(fθ,D)是损失函数,例如交叉熵损失。
  • 通过梯度下降等优化方法,我们不断更新参数 θ \theta θ以最小化损失。

2. 元学习的基本思想

元学习的目标是找到一种元算法 F ϕ F_\phi Fϕ,使得它可以快速学习新任务。这里的关键是学习一种 学习算法。换句话说,元学习希望找到一组元参数 ϕ \phi ϕ,从而在给定一个新任务 T i T_i Ti时,使用少量数据和梯度更新就可以迅速找到特定任务的参数 θ i \theta_i θi

3. MAML 算法推导

MAML 的目标是学习一个初始模型参数 θ \theta θ,使得它可以通过少量的梯度更新快速适应新任务。

3.1 元任务设置

假设有一组任务 { T 1 , T 2 , … , T N } \{T_1, T_2, \dots, T_N\} {T1,T2,,TN},每个任务 T i T_i Ti有自己的训练数据 D i train D_i^{\text{train}} Ditrain和测试数据 D i test D_i^{\text{test}} Ditest

3.2 内层优化:任务级别学习

对于每个任务 T i T_i Ti,我们首先使用任务的训练数据 D i train D_i^{\text{train}} Ditrain和当前的模型参数 θ \theta θ进行一次或多次梯度更新,得到任务特定的参数 θ i ′ \theta_i' θi
θ i ′ = θ − α ∇ θ L T i ( f θ , D i train ) \theta_i' = \theta - \alpha \nabla_\theta L_{T_i}(f_\theta, D_i^{\text{train}}) θi=θαθLTi(fθ,Ditrain)
其中:

  • α \alpha α是学习率。
  • L T i ( f θ , D i train ) L_{T_i}(f_\theta, D_i^{\text{train}}) LTi(fθ,Ditrain)是任务 T i T_i Ti的损失函数,例如对于分类任务可以是交叉熵损失。

3.3 外层优化:元级别学习

在每个任务的测试数据上评估更新后的模型参数 θ i ′ \theta_i' θi,计算其损失,并在所有任务上最小化测试损失的总和:
min ⁡ θ ∑ i = 1 N L T i ( f θ i ′ , D i test ) \min_{\theta} \sum_{i=1}^N L_{T_i}(f_{\theta_i'}, D_i^{\text{test}}) minθi=1NLTi(fθi,Ditest)
θ i ′ \theta_i' θi展开,这个目标实际上是关于初始参数 θ \theta θ的优化问题:
min ⁡ θ ∑ i = 1 N L T i ( f θ − α ∇ θ L T i ( f θ , D i train ) , D i test ) \min_{\theta} \sum_{i=1}^N L_{T_i}(f_{\theta - \alpha \nabla_\theta L_{T_i}(f_\theta, D_i^{\text{train}})}, D_i^{\text{test}}) minθi=1NLTi(fθαθLTi(fθ,Ditrain),Ditest)

3.4 元梯度计算

为了优化这个目标,我们需要对 θ \theta θ求梯度。这里涉及二阶梯度,因为 θ i ′ \theta_i' θi是通过内层优化得到的:
θ ← θ − β ∑ i = 1 N ∇ θ L T i ( f θ i ′ , D i test ) \theta \leftarrow \theta - \beta \sum_{i=1}^N \nabla_\theta L_{T_i}(f_{\theta_i'}, D_i^{\text{test}}) θθβi=1NθLTi(fθi,Ditest)
其中 β \beta β是元学习的学习率。

  • 这个更新包含了二阶导数项: ∇ θ θ i ′ = ∇ θ ( θ − α ∇ θ L T i ( f θ , D i train ) ) \nabla_\theta \theta_i' = \nabla_\theta \left(\theta - \alpha \nabla_\theta L_{T_i}(f_\theta, D_i^{\text{train}})\right) θθi=θ(θαθLTi(fθ,Ditrain))

3.5 最终更新规则

最终的元学习更新规则可以写为:
θ ← θ − β ∑ i = 1 N ∇ θ L T i ( f θ − α ∇ θ L T i ( f θ , D i train ) , D i test ) \theta \leftarrow \theta - \beta \sum_{i=1}^N \nabla_\theta L_{T_i}\left(f_{\theta - \alpha \nabla_\theta L_{T_i}(f_\theta, D_i^{\text{train}})}, D_i^{\text{test}}\right) θθβi=1NθLTi(fθαθLTi(fθ,Ditrain),Ditest)

4. 算法合并

将内层优化 θ i ′ \theta_i' θi代入外层优化的公式中,外层优化的梯度 ∇ θ L T i ( f θ i ′ , D i test ) \nabla_\theta L_{T_i}(f_{\theta_i'}, D_i^{\text{test}}) θLTi(fθi,Ditest)需要应用链式法则:
∇ θ L T i ( f θ i ′ , D i test ) = ∇ θ L T i ( f θ − α ∇ θ L T i ( f θ , D i train ) , D i test ) \nabla_\theta L_{T_i}(f_{\theta_i'}, D_i^{\text{test}}) = \nabla_\theta L_{T_i}\left(f_{\theta - \alpha \nabla_\theta L_{T_i}(f_\theta, D_i^{\text{train}})}, D_i^{\text{test}}\right) θLTi(fθi,Ditest)=θLTi(fθαθLTi(fθ,Ditrain),Ditest)
通过链式法则,展开这个公式:
∇ θ L T i ( f θ i ′ , D i test ) = ∇ θ i ′ L T i ( f θ i ′ , D i test ) ⋅ ∇ θ θ i ′ \nabla_\theta L_{T_i}(f_{\theta_i'}, D_i^{\text{test}}) = \nabla_{\theta_i'} L_{T_i}(f_{\theta_i'}, D_i^{\text{test}}) \cdot \nabla_\theta \theta_i' θLTi(fθi,Ditest)=θiLTi(fθi,Ditest)θθi
其中 ∇ θ θ i ′ \nabla_\theta \theta_i' θθi的形式为:
∇ θ θ i ′ = I − α ∇ θ 2 L T i ( f θ , D i train ) \nabla_\theta \theta_i' = I - \alpha \nabla^2_\theta L_{T_i}(f_\theta, D_i^{\text{train}}) θθi=Iαθ2LTi(fθ,Ditrain)
I I I是单位矩阵, ∇ θ 2 L T i ( f θ , D i train ) \nabla^2_\theta L_{T_i}(f_\theta, D_i^{\text{train}}) θ2LTi(fθ,Ditrain)是损失函数关于 θ \theta θ的二阶导数(Hessian 矩阵)。


最终的公式:

将这些部分合并在一起,得到 MAML 的最终更新公式为:
θ ← θ − β ∑ i = 1 N ∇ θ i ′ L T i ( f θ − α ∇ θ L T i ( f θ , D i train ) , D i test ) ⋅ ( I − α ∇ θ 2 L T i ( f θ , D i train ) ) \theta \leftarrow \theta - \beta \sum_{i=1}^N \nabla_{\theta_i'} L_{T_i}\left(f_{\theta - \alpha \nabla_\theta L_{T_i}(f_\theta, D_i^{\text{train}})}, D_i^{\text{test}}\right) \cdot \left(I - \alpha \nabla^2_\theta L_{T_i}(f_\theta, D_i^{\text{train}})\right) θθβi=1NθiLTi(fθαθLTi(fθ,Ditrain),Ditest)(Iαθ2LTi(fθ,Ditrain))


解释:

  • 内层优化:第一部分 θ i ′ = θ − α ∇ θ L T i ( f θ , D i train ) \theta_i' = \theta - \alpha \nabla_\theta L_{T_i}(f_\theta, D_i^{\text{train}}) θi=θαθLTi(fθ,Ditrain)表示在每个任务上用梯度下降更新 θ \theta θ,得到特定于任务的参数 θ i ′ \theta_i' θi
  • 外层优化:外层优化考虑测试集上的损失,并通过链式法则计算对 θ \theta θ的梯度。这部分的关键是包含了内层更新的二阶导数 ∇ θ θ i ′ \nabla_\theta \theta_i' θθi
  • 合并公式:最终的更新公式同时结合了内层和外层优化的过程,充分考虑了内层更新对外层优化的影响。

简化(在某些情况下):

在实际应用中,计算二阶导数(Hessian 矩阵)非常昂贵。因此,有时会使用近似方法来简化计算,例如“一次近似 MAML (First-Order MAML, FOMAML)”,忽略二阶项,仅使用一阶导数进行更新。简化后的更新公式为:
θ ← θ − β ∑ i = 1 N ∇ θ i ′ L T i ( f θ i ′ , D i test ) \theta \leftarrow \theta - \beta \sum_{i=1}^N \nabla_{\theta_i'} L_{T_i}(f_{\theta_i'}, D_i^{\text{test}}) θθβi=1NθiLTi(fθi,Ditest)

这个简化版本去除了 ∇ θ θ i ′ \nabla_\theta \theta_i' θθi中的二阶导数计算。

5. 理解 MAML 的优化

通过上面的推导,MAML 的优化分为两个阶段:

  1. 内层优化:在每个任务上利用任务的训练数据对模型进行一次或多次更新,以获得任务特定的模型参数。
  2. 外层优化:在所有任务的测试数据上评估内层优化后的模型,并利用这个评估结果更新模型的初始参数。

图例

MAML 的优势

MAML 的一个关键优势在于,它学习了一个初始参数 θ \theta θ,使得它可以通过少量梯度更新快速适应新任务。这使得它非常适合少样本学习场景,如几次样本分类。

其他元学习方法

除了 MAML,文件中还提到其他元学习方法,如基于优化器的元学习、网络架构搜索(NAS)等。这些方法都在不同程度上优化了元学习的过程,使得模型能够在少量数据的情况下快速学习。

总结

元学习的数学推导核心在于通过多个任务的训练,学习到一个通用的学习算法(或模型初始化),使得模型可以快速适应新任务。MAML 是元学习的一个经典方法,通过在元任务上进行二阶优化,使模型获得更好的泛化能力。

手写笔记

最后放几张今天的手写笔记,主要是方便查阅。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.ldbm.cn/p/443781.html

如若内容造成侵权/违法违规/事实不符,请联系编程新知网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Kubernetes 集群内 DNS

DNS 简介 在互联网早期,随着连接设备数量的增加,IP 地址的管理与记忆变得越来越复杂。为了简化网络资源的访问,DNS(Domain Name System)应运而生。DNS 的核心作用是将用户可读的域名(如 www.example.com&a…

C++伟大发明--模版

C起初是不受外界关注的,别人觉得他和C语言没有本质上的区别,只是方便些,直到祖师爷发明了模版,开始和C语言有了根本的区别。 我们通过一个小小的例子来搞清楚什么是模版,模版的作用到底有多大,平时我们想要…

Flask-JWT-Extended登录验证

1. 介绍 """安装:pip install Flask-JWT-Extended创建对象 初始化与app绑定jwt JWTManager(app) # 初始化JWTManager设置 Cookie 的选项:除了设置 cookie 的名称和值之外,你还可以指定其他的选项,例如:过期时间 (max_age)&…

【楚怡杯】职业院校技能大赛 “云计算应用” 赛项样题四

某企业根据自身业务需求,实施数字化转型,规划和建设数字化平台,平台聚焦“DevOps开发运维一体化”和“数据驱动产品开发”,拟采用开源OpenStack搭建企业内部私有云平台,开源Kubernetes搭建云原生服务平台,选…

简单题21 - 合并两个有序链表(Java)20240917

问题描述: java代码: /*** Definition for singly-linked list.* public class ListNode {* int val;* ListNode next;* ListNode() {}* ListNode(int val) { this.val val; }* ListNode(int val, ListNode next) { this.val val…

C语言 ——— 编写函数,判断一个整数是否是回文整数

目录 题目要求 代码实现 题目要求 编写一个函数,用来判断一个整数是否是回文整数,如果是回文整数就返回 true ,如果不是就返回 false 举例说明: 输入:121 输出:true 输入:1321 输出&#xf…

VBS学习1 - 语法、内置函数、内置对象

文章目录 概述执行脚本语法转义字符文本弹框msgbx定义变量dim(普通类型)定义接收对象set字符拼接&用户自定义输入框inputbox以及输入判断ifelse数组(参数表最大索引,非数组容量)有容量无元素基于元素确定容量 循环…

828华为云征文|部署在线文件管理器 Spacedrive

828华为云征文|部署在线文件管理器 Spacedrive 一、Flexus云服务器X实例介绍1.1 云服务器介绍1.2 产品优势1.3 计费模式 二、Flexus云服务器X实例配置2.1 重置密码2.2 服务器连接2.3 安全组配置 三、部署 Spacedrive3.1 Spacedrive 介绍3.2 Docker 环境搭建3.3 Spac…

反转字符串 II--力扣541

反转字符串 II 题目思路代码 题目 思路 本题的关键在于理解每隔 2k 个字符的前 k 个字符进行反转,剩余字符小于 2k 但大于或等于 k 个,则反转前 k 个字符。并且剩余字符少于 k 个,则将剩余字符全部反转。 让i每次跳2k,成为每一次…

考软考的信息安全工程师,有什么诀窍在一个月内通过吗?

一般是至少是2个月时间拿来备考的,低于2个月的话,时间肯定是比较赶的。虽然一个月时间相对紧张,但通过合理规划和高效利用时间,也是有可能成功通过考试的。以下是一份详细的备考策略,旨在帮助大家在有限的时间内最大化…

AI客服机器人开启企业客户服务新纪元

随着人工智能(AI)技术的迅猛发展,使得AI客服机器人走进了我们的视野,成为提高客户满意度和业务效率的不二法宝。这些智能机器人不仅能够处理海量信息,还能为客户提供个性化的服务体验。 一、AI客服机器人的基本原理 AI客服机器人是基于人工智…

TCP并发服务器的实现

一请求一线程 问题 当客户端数量较多时,使用单独线程为每个客户端处理请求可能导致系统资源的消耗过大和性能瓶颈。 资源消耗: 线程创建和管理开销:每个线程都有其创建和销毁的开销,特别是在高并发环境中,这种开销…

MySQL高阶1777-每家商店的产品价格

题目 找出每种产品在各个商店中的价格。 可以以 任何顺序 输出结果。 准备数据 create database csdn; use csdn;Create table If Not Exists Products (product_id int, store ENUM(store1, store2, store3), price int); Truncate table Products; insert into Products …

ZYNQ FPGA自学笔记~点亮LED

一 ZYNQ FPGA简介 ZYNQ FPGA主要特点是包含了完整的ARM处理系统,内部包含了内存控制器和大量的外设,且可独立于可编程逻辑单元,下图中的ARM内核为 ARM Cortex™-A9,ZYNQ FPGA包含两大功能块,处理系统Processing System…

十款主流的供应链管理系统盘点,优缺点一目了然!

本文将盘点十款供应链管理系统,为企业选型提供参考! 想象一下,一家企业在生产和销售产品的过程中,原材料供应不及时、库存积压严重、物流配送混乱。这时,供应链管理系统就如同一位高效的指挥家,将各个环节紧…

Hazel 2024

不喜欢游戏的人也可以做引擎,比如 cherno 引擎的作用主要是有两点: 将数据可视化交互 当然有些引擎的功能也包含有制作数据文件,称之为资产 assets 不做窗口类的应用栈,可能要花一年才能做一个能实际使用的应用,只需…

web知识

sql注入的万能密码:1’ or true#如果页面没有什么东西可见,首先可以用diresearch看看有没有什么隐藏的目录,或者检查源代码,如果这些都没成功可以用 dirsearch如果没有找到东西,可能需要调低线程 dirsearch.py -u url -e * --ti…

OpenHarmony(鸿蒙南向开发)——标准系统方案之瑞芯微RK3568移植案例(上)

往期知识点记录: 鸿蒙(HarmonyOS)应用层开发(北向)知识点汇总 鸿蒙(OpenHarmony)南向开发保姆级知识点汇总~ OpenHarmony(鸿蒙南向开发)——轻量和小型系统三方库移植指南…

模型部署基础

神经网络的模型部署是将训练好的神经网络模型应用到实际系统中,以实现预测、分类、推荐等任务的过程。下图展示了模型从训练到部署的整个流程: 1.模型部署的平台 在线服务器端部署 在线服务器端部署适用于处理大模型、需要精度优先的应用场景&#xff…

[C++进阶[六]]list的相关接口模拟实现

1.前言 本章重点 在list模拟实现的过程中&#xff0c;主要是感受list的迭代器的相关实现&#xff0c;这是本节的重点和难点。 2.list接口的大致框架 list是一个双向循环链表&#xff0c;所以在实现list之前&#xff0c;要先构建一个节点类 template <class T> struct L…