Transformers in Practice 03: Fine-Tuning with LoRA Using the PEFT Library - CSDN Blog (2024)

Contents

  • Introduction
    • PEFT
    • The LoRA method
    • Vision Transformer (ViT)
  • LoRA in practice
    • Choosing a model
      • google/vit-base-patch16-224-in21k
      • google/vit-base-patch16-224
    • Dataset
    • Model
    • PEFT configuration and model
    • Training
    • Inference

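Introduction

PEFT

PEFT (Parameter-Efficient Fine-Tuning) is a library for efficiently adapting large pretrained models to downstream applications without fine-tuning all of the model's parameters, which is computationally prohibitive. PEFT methods fine-tune only a small number of (extra) parameters, which significantly lowers compute and storage costs while yielding performance comparable to a fully fine-tuned model. This makes training and storing large language models (LLMs) on consumer hardware far more practical.

PEFT is integrated with the Transformers, Diffusers, and Accelerate libraries, providing a faster and simpler way to load, train, and run inference with large models.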

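The LoRA method

A popular way to train large models efficiently is to insert small trainable matrices into the attention blocks. These matrices are a low-rank decomposition of the weight-update (delta) matrix learned during fine-tuning. The original weight matrices of the pretrained model are frozen and only the small matrices are updated. This reduces the number of trainable parameters, and with it the memory use and training time, which can be very expensive for large models.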
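The low-rank update described above can be sketched in a few lines of NumPy. This is only an illustration, not PEFT's implementation: the layer size (768), the rank `r`, the scaling `alpha`, and the helper `lora_forward` are all assumptions chosen for the example.

```python
import numpy as np

rng = np.random.default_rng(0)

d_in, d_out, r, alpha = 768, 768, 16, 16    # illustrative sizes (assumed)

W = rng.normal(size=(d_out, d_in))          # frozen pretrained weight
A = rng.normal(scale=0.01, size=(r, d_in))  # trainable, small random init
B = np.zeros((d_out, r))                    # trainable, zero init so the update starts at 0

def lora_forward(x):
    """Base projection plus the scaled low-rank update (alpha / r) * B @ A."""
    return x @ W.T + (alpha / r) * (x @ A.T) @ B.T

x = rng.normal(size=(2, d_in))              # a batch of two input vectors
base_out = x @ W.T
lora_out = lora_forward(x)

full_params = W.size                        # parameters touched by full fine-tuning
lora_params = A.size + B.size               # parameters LoRA actually trains

print(np.allclose(base_out, lora_out))      # True: B is zero at initialization
print(full_params, lora_params)             # 589824 24576
```

With these assumed sizes, the two small matrices hold about 4% of the parameters of the full weight matrix, which is where the memory and time savings come from.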

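There are several ways to express a weight matrix as a low-rank decomposition, but the most common is Low-Rank Adaptation (LoRA). The PEFT library also supports several LoRA variants, such as Low-Rank Hadamard Product (LoHa), Low-Rank Kronecker Product (LoKr), and Adaptive Low-Rank Adaptation (AdaLoRA). The adapters guide in the PEFT documentation explains how these methods work conceptually. If you want to apply them to other tasks and use cases, such as semantic segmentation or token classification, see the PEFT notebook collection.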

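Vision Transformer (ViT)

The Vision Transformer (ViT) model was proposed by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, and Neil Houlsby in "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale". It was the first paper to successfully train a Transformer encoder on ImageNet and obtain very good results.

The abstract of the paper reads:

While the Transformer architecture has become the de-facto standard for natural language processing tasks, its applications to computer vision remain limited. In vision, attention is either applied in conjunction with convolutional networks, or used to replace certain components of convolutional networks while keeping their overall structure in place. We show that this reliance on CNNs is not necessary and a pure transformer applied directly to sequences of image patches can perform very well on image classification tasks. When pre-trained on large amounts of data and transferred to multiple mid-sized or small image recognition benchmarks (ImageNet, CIFAR-100, VTAB, etc.), Vision Transformer (ViT) attains excellent results compared to state-of-the-art convolutional networks while requiring substantially fewer computational resources to train.

For details on the model architecture, see: https://huggingface.co/docs/transformers/model_doc/vit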

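LoRA in practice

This guide shows how to quickly train an image classification model, using a low-rank decomposition method, to identify the food category shown in an image.
The example comes from the official documentation: https://huggingface.co/docs/peft/task_guides/lora_based_methods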

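Choosing a model

google/vit-base-patch16-224-in21k

This is a Transformer encoder model (BERT-like) pretrained in a supervised fashion on a large collection of images, namely ImageNet-21k, at a resolution of 224x224 pixels.

Images are presented to the model as a sequence of fixed-size patches (16x16 pixels each), which are linearly embedded. A [CLS] token is added at the start of the sequence for classification tasks, and absolute position embeddings are added before the sequence is fed to the layers of the Transformer encoder.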
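As a rough sketch of the input pipeline just described, the NumPy snippet below cuts a 224x224 image into 16x16 patches, embeds them linearly, prepends a [CLS] vector, and adds position embeddings. The random weights, the hidden size `D = 768`, and all variable names are assumptions for illustration; a real ViT learns these parameters.

```python
import numpy as np

rng = np.random.default_rng(0)

C, H, W, P, D = 3, 224, 224, 16, 768         # channels, height, width, patch size, hidden size

image = rng.normal(size=(C, H, W))           # stand-in for one preprocessed image

# Cut the image into non-overlapping P x P patches and flatten each one.
patches = image.reshape(C, H // P, P, W // P, P)
patches = patches.transpose(1, 3, 0, 2, 4).reshape(-1, C * P * P)    # (196, 768)

W_embed = rng.normal(scale=0.02, size=(C * P * P, D))                # linear patch embedding (random here)
cls_token = rng.normal(scale=0.02, size=(1, D))                      # learned [CLS] vector in a real model
pos_embed = rng.normal(scale=0.02, size=(patches.shape[0] + 1, D))   # absolute position embeddings

# Prepend [CLS], then add one position embedding per token.
tokens = np.concatenate([cls_token, patches @ W_embed], axis=0) + pos_embed

print(patches.shape)   # (196, 768)
print(tokens.shape)    # (197, 768)
```

The sequence length of 197 is 14 x 14 = 196 patches plus the single [CLS] token.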

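Note that this model does not provide any fine-tuned heads, because they were zeroed out by the Google researchers. It does, however, include the pretrained pooler, which can be used for downstream tasks such as image classification.

Through pretraining, the model learns an internal representation of images that can then be used to extract features useful for downstream tasks. For example, if you have a dataset of labeled images, you can train a standard classifier by placing a linear layer on top of the pretrained encoder. The linear layer is usually placed on top of the [CLS] token, because the last hidden state of that token can be treated as a representation of the whole image.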
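The idea of a linear layer on top of the [CLS] token can be sketched as follows. The encoder output is replaced by random numbers, and the batch size, the label count of 101, and the names `W_head` and `b_head` are assumptions for the example rather than anything taken from the Transformers library.

```python
import numpy as np

rng = np.random.default_rng(0)

batch, seq_len, hidden, num_labels = 4, 197, 768, 101   # 101 is an assumed label count

# Stand-in for the encoder output `last_hidden_state`.
last_hidden_state = rng.normal(size=(batch, seq_len, hidden))

cls_state = last_hidden_state[:, 0]           # final hidden state of the [CLS] token

W_head = rng.normal(scale=0.02, size=(hidden, num_labels))   # the linear classifier
b_head = np.zeros(num_labels)

logits = cls_state @ W_head + b_head
predictions = logits.argmax(axis=-1)          # one class index per image

print(logits.shape)        # (4, 101)
print(predictions.shape)   # (4,)
```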

    from transformers import ViTImageProcessor, FlaxViTModel
    from PIL import Image
    import requests

    # Download a sample image from the COCO validation set
    url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
    image = Image.open(requests.get(url, stream=True).raw)

    # Load the image processor and the bare encoder (no classification head)
    processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
    model = FlaxViTModel.from_pretrained('google/vit-base-patch16-224-in21k')

    # Preprocess to NumPy arrays, as the Flax model expects
    inputs = processor(images=image, return_tensors="np")
    outputs = model(**inputs)

    # Hidden states for the [CLS] token plus the 196 patches
    last_hidden_states = outputs.last_hidden_state
    print(last_hidden_states.shape)

This checkpoint contains no classification head and no label information; it returns only hidden states.

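google/vit-base-patch16-224

This is a Transformer encoder model (BERT-like) pretrained in a supervised fashion on a large collection of images, namely ImageNet-21k, at a resolution of 224x224 pixels. The model was then fine-tuned on ImageNet (also known as ILSVRC2012), a dataset of more than 1 million images and 1000 classes, also at a resolution of 224x224.

Images are presented to the model as a sequence of fixed-size patches (16x16 pixels each), which are linearly embedded. A [CLS] token is added at the start of the sequence for classification tasks, and absolute position embeddings are added before the sequence is fed to the layers of the Transformer encoder.

Through pretraining, the model learns an internal representation of images that can then be used to extract features useful for downstream tasks. For example, if you have a dataset of labeled images, you can train a standard classifier by placing a linear layer on top of the pretrained encoder. The linear layer is usually placed on top of the [CLS] token, because the last hidden state of that token can be treated as a representation of the whole image.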

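    from transformers import AutoImageProcessor, ViTForImageClassification
    from PIL import Image
    import requests
    import torch

    # Download a sample image from the COCO validation set
    url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
    image = Image.open(requests.get(url, stream=True).raw)

    # Load the processor and the model with its ImageNet classification head
    processor = AutoImageProcessor.from_pretrained('google/vit-base-patch16-224')
    model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

    # Preprocess to PyTorch tensors and inspect the result
    inputs = processor(images=image, return_tensors="pt")
    print(inputs)
    print(inputs["pixel_values"].shape)

    # Print the full id-to-label mapping (it appears in the output shown for this example)
    print(model.config.id2label)

    # Inference only, so run a single forward pass without gradient tracking
    with torch.no_grad():
        logits = model(**inputs).logits

    # Pick the highest-scoring class and map it to its label
    predicted_label = logits.argmax(-1).item()
    print(model.config.id2label[predicted_label])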

Output:

{'pixel_values': tensor([[[[ 0.1137, 0.1686, 0.1843, ..., -0.1922, -0.1843, -0.1843], [ 0.1373, 0.1686, 0.1843, ..., -0.1922, -0.1922, -0.2078], [ 0.1137, 0.1529, 0.1608, ..., -0.2314, -0.2235, -0.2157], ..., [ 0.8353, 0.7882, 0.7333, ..., 0.7020, 0.6471, 0.6157], [ 0.8275, 0.7961, 0.7725, ..., 0.5843, 0.4667, 0.3961], [ 0.8196, 0.7569, 0.7569, ..., 0.0745, -0.0510, -0.1922]], [[-0.8039, -0.8118, -0.8118, ..., -0.8902, -0.8902, -0.8980], [-0.7882, -0.7882, -0.7882, ..., -0.8745, -0.8745, -0.8824], [-0.8118, -0.8039, -0.7882, ..., -0.8902, -0.8902, -0.8902], ..., [-0.2706, -0.3176, -0.3647, ..., -0.4275, -0.4588, -0.4824], [-0.2706, -0.2941, -0.3412, ..., -0.4824, -0.5451, -0.5765], [-0.2784, -0.3412, -0.3490, ..., -0.7333, -0.7804, -0.8353]], [[-0.5451, -0.4667, -0.4824, ..., -0.7412, -0.6941, -0.7176], [-0.5529, -0.5137, -0.4902, ..., -0.7412, -0.7098, -0.7412], [-0.5216, -0.4824, -0.4667, ..., -0.7490, -0.7490, -0.7647], ..., [ 0.5686, 0.5529, 0.4510, ..., 0.4431, 0.3882, 0.3255], [ 0.5451, 0.4902, 0.5137, ..., 0.3020, 0.2078, 0.1294], [ 0.5686, 0.5608, 0.5137, ..., -0.2000, -0.4275, -0.5294]]]])}torch.Size([1, 3, 224, 224]){0: 'tench, Tinca tinca', 1: 'goldfish, Carassius auratus', 2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias', 3: 'tiger shark, Galeocerdo cuvieri', 4: 'hammerhead, hammerhead shark', 5: 'electric ray, crampfish, numbfish, torpedo', 6: 'stingray', 7: 'co*ck', 8: 'hen', 9: 'ostrich, Struthio camelus', 10: 'brambling, Fringilla montifringilla', 11: 'goldfinch, Carduelis carduelis', 12: 'house finch, linnet, Carpodacus mexicanus', 13: 'junco, snowbird', 14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea', 15: 'robin, American robin, Turdus migratorius', 16: 'bulbul', 17: 'jay', 18: 'magpie', 19: 'chickadee', 20: 'water ouzel, dipper', 21: 'kite', 22: 'bald eagle, American eagle, Haliaeetus leucocephalus', 23: 'vulture', 24: 'great grey owl, great gray owl, Strix nebulosa', 25: 'European fire 
salamander, Salamandra salamandra', 26: 'common newt, Triturus vulgaris', 27: 'eft', 28: 'spotted salamander, Ambystoma maculatum', 29: 'axolotl, mud puppy, Ambystoma mexicanum', 30: 'bullfrog, Rana catesbeiana', 31: 'tree frog, tree-frog', 32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui', 33: 'loggerhead, loggerhead turtle, Caretta caretta', 34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea', 35: 'mud turtle', 36: 'terrapin', 37: 'box turtle, box tortoise', 38: 'banded gecko', 39: 'common iguana, iguana, Iguana iguana', 40: 'American chameleon, anole, Anolis carolinensis', 41: 'whiptail, whiptail lizard', 42: 'agama', 43: 'frilled lizard, Chlamydosaurus kingi', 44: 'alligator lizard', 45: 'Gila monster, Heloderma suspectum', 46: 'green lizard, Lacerta viridis', 47: 'African chameleon, Chamaeleo chamaeleon', 48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis', 49: 'African crocodile, Nile crocodile, Crocodylus niloticus', 50: 'American alligator, Alligator mississipiensis', 51: 'triceratops', 52: 'thunder snake, worm snake, Carphophis amoenus', 53: 'ringneck snake, ring-necked snake, ring snake', 54: 'hognose snake, puff adder, sand viper', 55: 'green snake, grass snake', 56: 'king snake, kingsnake', 57: 'garter snake, grass snake', 58: 'water snake', 59: 'vine snake', 60: 'night snake, Hypsiglena torquata', 61: 'boa constrictor, Constrictor constrictor', 62: 'rock python, rock snake, Python sebae', 63: 'Indian cobra, Naja naja', 64: 'green mamba', 65: 'sea snake', 66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus', 67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus', 68: 'sidewinder, horned rattlesnake, Crotalus cerastes', 69: 'trilobite', 70: 'harvestman, daddy longlegs, Phalangium opilio', 71: 'scorpion', 72: 'black and gold garden spider, Argiope aurantia', 73: 'barn spider, Araneus cavaticus', 74: 'garden spider, Aranea diademata', 75: 'black widow, 
Latrodectus mactans', 76: 'tarantula', 77: 'wolf spider, hunting spider', 78: 'tick', 79: 'centipede', 80: 'black grouse', 81: 'ptarmigan', 82: 'ruffed grouse, partridge, Bonasa umbellus', 83: 'prairie chicken, prairie grouse, prairie fowl', 84: 'peaco*ck', 85: 'quail', 86: 'partridge', 87: 'African grey, African gray, Psittacus erithacus', 88: 'macaw', 89: 'sulphur-crested co*ckatoo, Kakatoe galerita, Cacatua galerita', 90: 'lorikeet', 91: 'coucal', 92: 'bee eater', 93: 'hornbill', 94: 'hummingbird', 95: 'jacamar', 96: 'toucan', 97: 'drake', 98: 'red-breasted merganser, Mergus serrator', 99: 'goose', 100: 'black swan, Cygnus atratus', 101: 'tusker', 102: 'echidna, spiny anteater, anteater', 103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus', 104: 'wallaby, brush kangaroo', 105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus', 106: 'wombat', 107: 'jellyfish', 108: 'sea anemone, anemone', 109: 'brain coral', 110: 'flatworm, platyhelminth', 111: 'nematode, nematode worm, roundworm', 112: 'conch', 113: 'snail', 114: 'slug', 115: 'sea slug, nudibranch', 116: 'chiton, coat-of-mail shell, sea cradle, polyplacophore', 117: 'chambered nautilus, pearly nautilus, nautilus', 118: 'Dungeness crab, Cancer magister', 119: 'rock crab, Cancer irroratus', 120: 'fiddler crab', 121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica', 122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus', 123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish', 124: 'crayfish, crawfish, crawdad, crawdaddy', 125: 'hermit crab', 126: 'isopod', 127: 'white stork, Ciconia ciconia', 128: 'black stork, Ciconia nigra', 129: 'spoonbill', 130: 'flamingo', 131: 'little blue heron, Egretta caerulea', 132: 'American egret, great white heron, Egretta albus', 133: 'bittern', 134: 'crane', 135: 'limpkin, Aramus pictus', 136: 'European gallinule, Porphyrio 
porphyrio', 137: 'American coot, marsh hen, mud hen, water hen, Fulica americana', 138: 'bustard', 139: 'ruddy turnstone, Arenaria interpres', 140: 'red-backed sandpiper, dunlin, Erolia alpina', 141: 'redshank, Tringa totanus', 142: 'dowitcher', 143: 'oystercatcher, oyster catcher', 144: 'pelican', 145: 'king penguin, Aptenodytes patagonica', 146: 'albatross, mollymawk', 147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus', 148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca', 149: 'dugong, Dugong dugon', 150: 'sea lion', 151: 'Chihuahua', 152: 'Japanese spaniel', 153: 'Maltese dog, Maltese terrier, Maltese', 154: 'Pekinese, Pekingese, Peke', 155: 'Shih-Tzu', 156: 'Blenheim spaniel', 157: 'papillon', 158: 'toy terrier', 159: 'Rhodesian ridgeback', 160: 'Afghan hound, Afghan', 161: 'basset, basset hound', 162: 'beagle', 163: 'bloodhound, sleuthhound', 164: 'bluetick', 165: 'black-and-tan coonhound', 166: 'Walker hound, Walker foxhound', 167: 'English foxhound', 168: 'redbone', 169: 'borzoi, Russian wolfhound', 170: 'Irish wolfhound', 171: 'Italian greyhound', 172: 'whippet', 173: 'Ibizan hound, Ibizan Podenco', 174: 'Norwegian elkhound, elkhound', 175: 'otterhound, otter hound', 176: 'Saluki, gazelle hound', 177: 'Scottish deerhound, deerhound', 178: 'Weimaraner', 179: 'Staffordshire bullterrier, Staffordshire bull terrier', 180: 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier', 181: 'Bedlington terrier', 182: 'Border terrier', 183: 'Kerry blue terrier', 184: 'Irish terrier', 185: 'Norfolk terrier', 186: 'Norwich terrier', 187: 'Yorkshire terrier', 188: 'wire-haired fox terrier', 189: 'Lakeland terrier', 190: 'Sealyham terrier, Sealyham', 191: 'Airedale, Airedale terrier', 192: 'cairn, cairn terrier', 193: 'Australian terrier', 194: 'Dandie Dinmont, Dandie Dinmont terrier', 195: 'Boston bull, Boston terrier', 196: 'miniature schnauzer', 197: 'giant schnauzer', 198: 
'standard schnauzer', 199: 'Scotch terrier, Scottish terrier, Scottie', 200: 'Tibetan terrier, chrysanthemum dog', 201: 'silky terrier, Sydney silky', 202: 'soft-coated wheaten terrier', 203: 'West Highland white terrier', 204: 'Lhasa, Lhasa apso', 205: 'flat-coated retriever', 206: 'curly-coated retriever', 207: 'golden retriever', 208: 'Labrador retriever', 209: 'Chesapeake Bay retriever', 210: 'German short-haired pointer', 211: 'vizsla, Hungarian pointer', 212: 'English setter', 213: 'Irish setter, red setter', 214: 'Gordon setter', 215: 'Brittany spaniel', 216: 'clumber, clumber spaniel', 217: 'English springer, English springer spaniel', 218: 'Welsh springer spaniel', 219: 'co*cker spaniel, English co*cker spaniel, co*cker', 220: 'Sussex spaniel', 221: 'Irish water spaniel', 222: 'kuvasz', 223: 'schipperke', 224: 'groenendael', 225: 'malinois', 226: 'briard', 227: 'kelpie', 228: 'komondor', 229: 'Old English sheepdog, bobtail', 230: 'Shetland sheepdog, Shetland sheep dog, Shetland', 231: 'collie', 232: 'Border collie', 233: 'Bouvier des Flandres, Bouviers des Flandres', 234: 'Rottweiler', 235: 'German shepherd, German shepherd dog, German police dog, alsatian', 236: 'Doberman, Doberman pinscher', 237: 'miniature pinscher', 238: 'Greater Swiss Mountain dog', 239: 'Bernese mountain dog', 240: 'Appenzeller', 241: 'EntleBucher', 242: 'boxer', 243: 'bull mastiff', 244: 'Tibetan mastiff', 245: 'French bulldog', 246: 'Great Dane', 247: 'Saint Bernard, St Bernard', 248: 'Eskimo dog, husky', 249: 'malamute, malemute, Alaskan malamute', 250: 'Siberian husky', 251: 'dalmatian, coach dog, carriage dog', 252: 'affenpinscher, monkey pinscher, monkey dog', 253: 'basenji', 254: 'pug, pug-dog', 255: 'Leonberg', 256: 'Newfoundland, Newfoundland dog', 257: 'Great Pyrenees', 258: 'Samoyed, Samoyede', 259: 'Pomeranian', 260: 'chow, chow chow', 261: 'keeshond', 262: 'Brabancon griffon', 263: 'Pembroke, Pembroke Welsh corgi', 264: 'Cardigan, Cardigan Welsh corgi', 265: 'toy 
poodle', 266: 'miniature poodle', 267: 'standard poodle', 268: 'Mexican hairless', 269: 'timber wolf, grey wolf, gray wolf, Canis lupus', 270: 'white wolf, Arctic wolf, Canis lupus tundrarum', 271: 'red wolf, maned wolf, Canis rufus, Canis niger', 272: 'coyote, prairie wolf, brush wolf, Canis latrans', 273: 'dingo, warrigal, warragal, Canis dingo', 274: 'dhole, Cuon alpinus', 275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus', 276: 'hyena, hyaena', 277: 'red fox, Vulpes vulpes', 278: 'kit fox, Vulpes macrotis', 279: 'Arctic fox, white fox, Alopex lagopus', 280: 'grey fox, gray fox, Urocyon cinereoargenteus', 281: 'tabby, tabby cat', 282: 'tiger cat', 283: 'Persian cat', 284: 'Siamese cat, Siamese', 285: 'Egyptian cat', 286: 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor', 287: 'lynx, catamount', 288: 'leopard, Panthera pardus', 289: 'snow leopard, ounce, Panthera uncia', 290: 'jaguar, panther, Panthera onca, Felis onca', 291: 'lion, king of beasts, Panthera leo', 292: 'tiger, Panthera tigris', 293: 'cheetah, chetah, Acinonyx jubatus', 294: 'brown bear, bruin, Ursus arctos', 295: 'American black bear, black bear, Ursus americanus, Euarctos americanus', 296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus', 297: 'sloth bear, Melursus ursinus, Ursus ursinus', 298: 'mongoose', 299: 'meerkat, mierkat', 300: 'tiger beetle', 301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle', 302: 'ground beetle, carabid beetle', 303: 'long-horned beetle, longicorn, longicorn beetle', 304: 'leaf beetle, chrysomelid', 305: 'dung beetle', 306: 'rhinoceros beetle', 307: 'weevil', 308: 'fly', 309: 'bee', 310: 'ant, emmet, pismire', 311: 'grasshopper, hopper', 312: 'cricket', 313: 'walking stick, walkingstick, stick insect', 314: 'co*ckroach, roach', 315: 'mantis, mantid', 316: 'cicada, cicala', 317: 'leafhopper', 318: 'lacewing, lacewing fly', 319: "dragonfly, darning needle, devil's darning needle, sewing needle, 
snake feeder, snake doctor, mosquito hawk, skeeter hawk", 320: 'damselfly', 321: 'admiral', 322: 'ringlet, ringlet butterfly', 323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus', 324: 'cabbage butterfly', 325: 'sulphur butterfly, sulfur butterfly', 326: 'lycaenid, lycaenid butterfly', 327: 'starfish, sea star', 328: 'sea urchin', 329: 'sea cucumber, holothurian', 330: 'wood rabbit, cottontail, cottontail rabbit', 331: 'hare', 332: 'Angora, Angora rabbit', 333: 'hamster', 334: 'porcupine, hedgehog', 335: 'fox squirrel, eastern fox squirrel, Sciurus niger', 336: 'marmot', 337: 'beaver', 338: 'guinea pig, Cavia cobaya', 339: 'sorrel', 340: 'zebra', 341: 'hog, pig, grunter, squealer, Sus scrofa', 342: 'wild boar, boar, Sus scrofa', 343: 'warthog', 344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius', 345: 'ox', 346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis', 347: 'bison', 348: 'ram, tup', 349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis', 350: 'ibex, Capra ibex', 351: 'hartebeest', 352: 'impala, Aepyceros melampus', 353: 'gazelle', 354: 'Arabian camel, dromedary, Camelus dromedarius', 355: 'llama', 356: 'weasel', 357: 'mink', 358: 'polecat, fitch, foulmart, foumart, Mustela putorius', 359: 'black-footed ferret, ferret, Mustela nigripes', 360: 'otter', 361: 'skunk, polecat, wood puss*', 362: 'badger', 363: 'armadillo', 364: 'three-toed sloth, ai, Bradypus tridactylus', 365: 'orangutan, orang, orangutang, Pongo pygmaeus', 366: 'gorilla, Gorilla gorilla', 367: 'chimpanzee, chimp, Pan troglodytes', 368: 'gibbon, Hylobates lar', 369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus', 370: 'guenon, guenon monkey', 371: 'patas, hussar monkey, Erythrocebus patas', 372: 'baboon', 373: 'macaque', 374: 'langur', 375: 'colobus, colobus monkey', 376: 'proboscis monkey, Nasalis larvatus', 377: 'marmoset', 378: 'capuchin, ringtail, Cebus capucinus', 379: 'howler monkey, 
howler', 380: 'titi, titi monkey', 381: 'spider monkey, Ateles geoffroyi', 382: 'squirrel monkey, Saimiri sciureus', 383: 'Madagascar cat, ring-tailed lemur, Lemur catta', 384: 'indri, indris, Indri indri, Indri brevicaudatus', 385: 'Indian elephant, Elephas maximus', 386: 'African elephant, Loxodonta africana', 387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens', 388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca', 389: 'barracouta, snoek', 390: 'eel', 391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch', 392: 'rock beauty, Holocanthus tricolor', 393: 'anemone fish', 394: 'sturgeon', 395: 'gar, garfish, garpike, billfish, Lepisosteus osseus', 396: 'lionfish', 397: 'puffer, pufferfish, blowfish, globefish', 398: 'abacus', 399: 'abaya', 400: "academic gown, academic robe, judge's robe", 401: 'accordion, piano accordion, squeeze box', 402: 'acoustic guitar', 403: 'aircraft carrier, carrier, flattop, attack aircraft carrier', 404: 'airliner', 405: 'airship, dirigible', 406: 'altar', 407: 'ambulance', 408: 'amphibian, amphibious vehicle', 409: 'analog clock', 410: 'apiary, bee house', 411: 'apron', 412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin', 413: 'assault rifle, assault gun', 414: 'backpack, back pack, knapsack, packsack, rucksack, haversack', 415: 'bakery, bakeshop, bakehouse', 416: 'balance beam, beam', 417: 'balloon', 418: 'ballpoint, ballpoint pen, ballpen, Biro', 419: 'Band Aid', 420: 'banjo', 421: 'bannister, banister, balustrade, balusters, handrail', 422: 'barbell', 423: 'barber chair', 424: 'barbershop', 425: 'barn', 426: 'barometer', 427: 'barrel, cask', 428: 'barrow, garden cart, lawn cart, wheelbarrow', 429: 'baseball', 430: 'basketball', 431: 'bassinet', 432: 'bassoon', 433: 'bathing cap, swimming cap', 434: 'bath towel', 435: 'bathtub, bathing tub, bath, tub', 436: 'beach wagon, station wagon, wagon, estate car, beach 
waggon, station waggon, waggon', 437: 'beacon, lighthouse, beacon light, pharos', 438: 'beaker', 439: 'bearskin, busby, shako', 440: 'beer bottle', 441: 'beer glass', 442: 'bell cote, bell cot', 443: 'bib', 444: 'bicycle-built-for-two, tandem bicycle, tandem', 445: 'bikini, two-piece', 446: 'binder, ring-binder', 447: 'binoculars, field glasses, opera glasses', 448: 'birdhouse', 449: 'boathouse', 450: 'bobsled, bobsleigh, bob', 451: 'bolo tie, bolo, bola tie, bola', 452: 'bonnet, poke bonnet', 453: 'bookcase', 454: 'bookshop, bookstore, bookstall', 455: 'bottlecap', 456: 'bow', 457: 'bow tie, bow-tie, bowtie', 458: 'brass, memorial tablet, plaque', 459: 'brassiere, bra, bandeau', 460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty', 461: 'breastplate, aegis, egis', 462: 'broom', 463: 'bucket, pail', 464: 'buckle', 465: 'bulletproof vest', 466: 'bullet train, bullet', 467: 'butcher shop, meat market', 468: 'cab, hack, taxi, taxicab', 469: 'caldron, cauldron', 470: 'candle, taper, wax light', 471: 'cannon', 472: 'canoe', 473: 'can opener, tin opener', 474: 'cardigan', 475: 'car mirror', 476: 'carousel, carrousel, merry-go-round, roundabout, whirligig', 477: "carpenter's kit, tool kit", 478: 'carton', 479: 'car wheel', 480: 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM', 481: 'cassette', 482: 'cassette player', 483: 'castle', 484: 'catamaran', 485: 'CD player', 486: 'cello, violoncello', 487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone', 488: 'chain', 489: 'chainlink fence', 490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour', 491: 'chain saw, chainsaw', 492: 'chest', 493: 'chiffonier, commode', 494: 'chime, bell, gong', 495: 'china cabinet, china closet', 496: 'Christmas stocking', 497: 'church, church building', 498: 'cinema, movie theater, movie theatre, movie house, picture palace', 499: 'cleaver, meat cleaver, chopper', 
500: 'cliff dwelling', 501: 'cloak', 502: 'clog, geta, patten, sabot', 503: 'co*cktail shaker', 504: 'coffee mug', 505: 'coffeepot', 506: 'coil, spiral, volute, whorl, helix', 507: 'combination lock', 508: 'computer keyboard, keypad', 509: 'confectionery, confectionary, candy store', 510: 'container ship, containership, container vessel', 511: 'convertible', 512: 'corkscrew, bottle screw', 513: 'cornet, horn, trumpet, trump', 514: 'cowboy boot', 515: 'cowboy hat, ten-gallon hat', 516: 'cradle', 517: 'crane', 518: 'crash helmet', 519: 'crate', 520: 'crib, cot', 521: 'Crock Pot', 522: 'croquet ball', 523: 'crutch', 524: 'cuirass', 525: 'dam, dike, dyke', 526: 'desk', 527: 'desktop computer', 528: 'dial telephone, dial phone', 529: 'diaper, nappy, napkin', 530: 'digital clock', 531: 'digital watch', 532: 'dining table, board', 533: 'dishrag, dishcloth', 534: 'dishwasher, dish washer, dishwashing machine', 535: 'disk brake, disc brake', 536: 'dock, dockage, docking facility', 537: 'dogsled, dog sled, dog sleigh', 538: 'dome', 539: 'doormat, welcome mat', 540: 'drilling platform, offshore rig', 541: 'drum, membranophone, tympan', 542: 'drumstick', 543: 'dumbbell', 544: 'Dutch oven', 545: 'electric fan, blower', 546: 'electric guitar', 547: 'electric locomotive', 548: 'entertainment center', 549: 'envelope', 550: 'espresso maker', 551: 'face powder', 552: 'feather boa, boa', 553: 'file, file cabinet, filing cabinet', 554: 'fireboat', 555: 'fire engine, fire truck', 556: 'fire screen, fireguard', 557: 'flagpole, flagstaff', 558: 'flute, transverse flute', 559: 'folding chair', 560: 'football helmet', 561: 'forklift', 562: 'fountain', 563: 'fountain pen', 564: 'four-poster', 565: 'freight car', 566: 'French horn, horn', 567: 'frying pan, frypan, skillet', 568: 'fur coat', 569: 'garbage truck, dustcart', 570: 'gasmask, respirator, gas helmet', 571: 'gas pump, gasoline pump, petrol pump, island dispenser', 572: 'goblet', 573: 'go-kart', 574: 'golf ball', 575: 'golfcart, golf 
cart', 576: 'gondola', 577: 'gong, tam-tam', 578: 'gown', 579: 'grand piano, grand', 580: 'greenhouse, nursery, glasshouse', 581: 'grille, radiator grille', 582: 'grocery store, grocery, food market, market', 583: 'guillotine', 584: 'hair slide', 585: 'hair spray', 586: 'half track', 587: 'hammer', 588: 'hamper', 589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier', 590: 'hand-held computer, hand-held microcomputer', 591: 'handkerchief, hankie, hanky, hankey', 592: 'hard disc, hard disk, fixed disk', 593: 'harmonica, mouth organ, harp, mouth harp', 594: 'harp', 595: 'harvester, reaper', 596: 'hatchet', 597: 'holster', 598: 'home theater, home theatre', 599: 'honeycomb', 600: 'hook, claw', 601: 'hoopskirt, crinoline', 602: 'horizontal bar, high bar', 603: 'horse cart, horse-cart', 604: 'hourglass', 605: 'iPod', 606: 'iron, smoothing iron', 607: "jack-o'-lantern", 608: 'jean, blue jean, denim', 609: 'jeep, landrover', 610: 'jersey, T-shirt, tee shirt', 611: 'jigsaw puzzle', 612: 'jinrikisha, ricksha, rickshaw', 613: 'joystick', 614: 'kimono', 615: 'knee pad', 616: 'knot', 617: 'lab coat, laboratory coat', 618: 'ladle', 619: 'lampshade, lamp shade', 620: 'laptop, laptop computer', 621: 'lawn mower, mower', 622: 'lens cap, lens cover', 623: 'letter opener, paper knife, paperknife', 624: 'library', 625: 'lifeboat', 626: 'lighter, light, igniter, ignitor', 627: 'limousine, limo', 628: 'liner, ocean liner', 629: 'lipstick, lip rouge', 630: 'Loafer', 631: 'lotion', 632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system', 633: "loupe, jeweler's loupe", 634: 'lumbermill, sawmill', 635: 'magnetic compass', 636: 'mailbag, postbag', 637: 'mailbox, letter box', 638: 'maillot', 639: 'maillot, tank suit', 640: 'manhole cover', 641: 'maraca', 642: 'marimba, xylophone', 643: 'mask', 644: 'matchstick', 645: 'maypole', 646: 'maze, labyrinth', 647: 'measuring cup', 648: 'medicine chest, medicine cabinet', 649: 'megalith, megalithic structure', 650: 
'microphone, mike', 651: 'microwave, microwave oven', 652: 'military uniform', 653: 'milk can', 654: 'minibus', 655: 'miniskirt, mini', 656: 'minivan', 657: 'missile', 658: 'mitten', 659: 'mixing bowl', 660: 'mobile home, manufactured home', 661: 'Model T', 662: 'modem', 663: 'monastery', 664: 'monitor', 665: 'moped', 666: 'mortar', 667: 'mortarboard', 668: 'mosque', 669: 'mosquito net', 670: 'motor scooter, scooter', 671: 'mountain bike, all-terrain bike, off-roader', 672: 'mountain tent', 673: 'mouse, computer mouse', 674: 'mousetrap', 675: 'moving van', 676: 'muzzle', 677: 'nail', 678: 'neck brace', 679: 'necklace', 680: 'nipple', 681: 'notebook, notebook computer', 682: 'obelisk', 683: 'oboe, hautboy, hautbois', 684: 'ocarina, sweet potato', 685: 'odometer, hodometer, mileometer, milometer', 686: 'oil filter', 687: 'organ, pipe organ', 688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO', 689: 'overskirt', 690: 'oxcart', 691: 'oxygen mask', 692: 'packet', 693: 'paddle, boat paddle', 694: 'paddlewheel, paddle wheel', 695: 'padlock', 696: 'paintbrush', 697: "pajama, pyjama, pj's, jammies", 698: 'palace', 699: 'panpipe, pandean pipe, syrinx', 700: 'paper towel', 701: 'parachute, chute', 702: 'parallel bars, bars', 703: 'park bench', 704: 'parking meter', 705: 'passenger car, coach, carriage', 706: 'patio, terrace', 707: 'pay-phone, pay-station', 708: 'pedestal, plinth, footstall', 709: 'pencil box, pencil case', 710: 'pencil sharpener', 711: 'perfume, essence', 712: 'Petri dish', 713: 'photocopier', 714: 'pick, plectrum, plectron', 715: 'pickelhaube', 716: 'picket fence, paling', 717: 'pickup, pickup truck', 718: 'pier', 719: 'piggy bank, penny bank', 720: 'pill bottle', 721: 'pillow', 722: 'ping-pong ball', 723: 'pinwheel', 724: 'pirate, pirate ship', 725: 'pitcher, ewer', 726: "plane, carpenter's plane, woodworking plane", 727: 'planetarium', 728: 'plastic bag', 729: 'plate rack', 730: 'plow, plough', 731: "plunger, plumber's helper", 732: 'Polaroid camera, 
Polaroid Land camera', 733: 'pole', 734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria', 735: 'poncho', 736: 'pool table, billiard table, snooker table', 737: 'pop bottle, soda bottle', 738: 'pot, flowerpot', 739: "potter's wheel", 740: 'power drill', 741: 'prayer rug, prayer mat', 742: 'printer', 743: 'prison, prison house', 744: 'projectile, missile', 745: 'projector', 746: 'puck, hockey puck', 747: 'punching bag, punch bag, punching ball, punchball', 748: 'purse', 749: 'quill, quill pen', 750: 'quilt, comforter, comfort, puff', 751: 'racer, race car, racing car', 752: 'racket, racquet', 753: 'radiator', 754: 'radio, wireless', 755: 'radio telescope, radio reflector', 756: 'rain barrel', 757: 'recreational vehicle, RV, R.V.', 758: 'reel', 759: 'reflex camera', 760: 'refrigerator, icebox', 761: 'remote control, remote', 762: 'restaurant, eating house, eating place, eatery', 763: 'revolver, six-gun, six-shooter', 764: 'rifle', 765: 'rocking chair, rocker', 766: 'rotisserie', 767: 'rubber eraser, rubber, pencil eraser', 768: 'rugby ball', 769: 'rule, ruler', 770: 'running shoe', 771: 'safe', 772: 'safety pin', 773: 'saltshaker, salt shaker', 774: 'sandal', 775: 'sarong', 776: 'sax, saxophone', 777: 'scabbard', 778: 'scale, weighing machine', 779: 'school bus', 780: 'schooner', 781: 'scoreboard', 782: 'screen, CRT screen', 783: 'screw', 784: 'screwdriver', 785: 'seat belt, seatbelt', 786: 'sewing machine', 787: 'shield, buckler', 788: 'shoe shop, shoe-shop, shoe store', 789: 'shoji', 790: 'shopping basket', 791: 'shopping cart', 792: 'shovel', 793: 'shower cap', 794: 'shower curtain', 795: 'ski', 796: 'ski mask', 797: 'sleeping bag', 798: 'slide rule, slipstick', 799: 'sliding door', 800: 'slot, one-armed bandit', 801: 'snorkel', 802: 'snowmobile', 803: 'snowplow, snowplough', 804: 'soap dispenser', 805: 'soccer ball', 806: 'sock', 807: 'solar dish, solar collector, solar furnace', 808: 'sombrero', 809: 'soup bowl', 810: 'space bar', 811: 
'space heater', 812: 'space shuttle', 813: 'spatula', 814: 'speedboat', 815: "spider web, spider's web", 816: 'spindle', 817: 'sports car, sport car', 818: 'spotlight, spot', 819: 'stage', 820: 'steam locomotive', 821: 'steel arch bridge', 822: 'steel drum', 823: 'stethoscope', 824: 'stole', 825: 'stone wall', 826: 'stopwatch, stop watch', 827: 'stove', 828: 'strainer', 829: 'streetcar, tram, tramcar, trolley, trolley car', 830: 'stretcher', 831: 'studio couch, day bed', 832: 'stupa, tope', 833: 'submarine, pigboat, sub, U-boat', 834: 'suit, suit of clothes', 835: 'sundial', 836: 'sunglass', 837: 'sunglasses, dark glasses, shades', 838: 'sunscreen, sunblock, sun blocker', 839: 'suspension bridge', 840: 'swab, swob, mop', 841: 'sweatshirt', 842: 'swimming trunks, bathing trunks', 843: 'swing', 844: 'switch, electric switch, electrical switch', 845: 'syringe', 846: 'table lamp', 847: 'tank, army tank, armored combat vehicle, armoured combat vehicle', 848: 'tape player', 849: 'teapot', 850: 'teddy, teddy bear', 851: 'television, television system', 852: 'tennis ball', 853: 'thatch, thatched roof', 854: 'theater curtain, theatre curtain', 855: 'thimble', 856: 'thresher, thrasher, threshing machine', 857: 'throne', 858: 'tile roof', 859: 'toaster', 860: 'tobacco shop, tobacconist shop, tobacconist', 861: 'toilet seat', 862: 'torch', 863: 'totem pole', 864: 'tow truck, tow car, wrecker', 865: 'toyshop', 866: 'tractor', 867: 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi', 868: 'tray', 869: 'trench coat', 870: 'tricycle, trike, velocipede', 871: 'trimaran', 872: 'tripod', 873: 'triumphal arch', 874: 'trolleybus, trolley coach, trackless trolley', 875: 'trombone', 876: 'tub, vat', 877: 'turnstile', 878: 'typewriter keyboard', 879: 'umbrella', 880: 'unicycle, monocycle', 881: 'upright, upright piano', 882: 'vacuum, vacuum cleaner', 883: 'vase', 884: 'vault', 885: 'velvet', 886: 'vending machine', 887: 'vestment', 888: 'viaduct', 889: 'violin, 
fiddle', 890: 'volleyball', 891: 'waffle iron', 892: 'wall clock', 893: 'wallet, billfold, notecase, pocketbook', 894: 'wardrobe, closet, press', 895: 'warplane, military plane', 896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin', 897: 'washer, automatic washer, washing machine', 898: 'water bottle', 899: 'water jug', 900: 'water tower', 901: 'whiskey jug', 902: 'whistle', 903: 'wig', 904: 'window screen', 905: 'window shade', 906: 'Windsor tie', 907: 'wine bottle', 908: 'wing', 909: 'wok', 910: 'wooden spoon', 911: 'wool, woolen, woollen', 912: 'worm fence, snake fence, snake-rail fence, Virginia fence', 913: 'wreck', 914: 'yawl', 915: 'yurt', 916: 'web site, website, internet site, site', 917: 'comic book', 918: 'crossword puzzle, crossword', 919: 'street sign', 920: 'traffic light, traffic signal, stoplight', 921: 'book jacket, dust cover, dust jacket, dust wrapper', 922: 'menu', 923: 'plate', 924: 'guacamole', 925: 'consomme', 926: 'hot pot, hotpot', 927: 'trifle', 928: 'ice cream, icecream', 929: 'ice lolly, lolly, lollipop, popsicle', 930: 'French loaf', 931: 'bagel, beigel', 932: 'pretzel', 933: 'cheeseburger', 934: 'hotdog, hot dog, red hot', 935: 'mashed potato', 936: 'head cabbage', 937: 'broccoli', 938: 'cauliflower', 939: 'zucchini, courgette', 940: 'spaghetti squash', 941: 'acorn squash', 942: 'butternut squash', 943: 'cucumber, cuke', 944: 'artichoke, globe artichoke', 945: 'bell pepper', 946: 'cardoon', 947: 'mushroom', 948: 'Granny Smith', 949: 'strawberry', 950: 'orange', 951: 'lemon', 952: 'fig', 953: 'pineapple, ananas', 954: 'banana', 955: 'jackfruit, jak, jack', 956: 'custard apple', 957: 'pomegranate', 958: 'hay', 959: 'carbonara', 960: 'chocolate sauce, chocolate syrup', 961: 'dough', 962: 'meat loaf, meatloaf', 963: 'pizza, pizza pie', 964: 'potpie', 965: 'burrito', 966: 'red wine', 967: 'espresso', 968: 'cup', 969: 'eggnog', 970: 'alp', 971: 'bubble', 972: 'cliff, drop, drop-off', 973: 'coral reef', 974: 'geyser', 975: 
'lakeside, lakeshore', 976: 'promontory, headland, head, foreland', 977: 'sandbar, sand bar', 978: 'seashore, coast, seacoast, sea-coast', 979: 'valley, vale', 980: 'volcano', 981: 'ballplayer, baseball player', 982: 'groom, bridegroom', 983: 'scuba diver', 984: 'rapeseed', 985: 'daisy', 986: "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", 987: 'corn', 988: 'acorn', 989: 'hip, rose hip, rosehip', 990: 'buckeye, horse chestnut, conker', 991: 'coral fungus', 992: 'agaric', 993: 'gyromitra', 994: 'stinkhorn, carrion fungus', 995: 'earthstar', 996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa', 997: 'bolete', 998: 'ear, spike, capitulum', 999: 'toilet tissue, toilet paper, bathroom tissue'}
Egyptian cat

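After preprocessing, the encoded result is a dict-like structure with a single key, pixel_values, that holds the pixel array with shape [batch, channels, height, width].
model.config.id2label shows that this checkpoint has 1,000 labels in total.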
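To make the shape concrete, here is a small sketch in plain Python that works out the pixel_values shape and the token sequence length the encoder sees. The helper name vit_input_summary is made up for illustration; the 224x224 resolution and 16x16 patch size are read off the checkpoint name, and 3 channels assumes an RGB input.

```python
def vit_input_summary(batch_size, image_size=224, patch_size=16, channels=3):
    """Illustrative helper (not part of transformers): shapes for a ViT input."""
    patches_per_side = image_size // patch_size      # 224 // 16 = 14
    num_patches = patches_per_side ** 2              # 14 * 14 patches per image
    return {
        # [batch, channels, height, width]
        "pixel_values_shape": (batch_size, channels, image_size, image_size),
        "num_patches": num_patches,
        # one extra position for the [CLS] token prepended to the sequence
        "sequence_length": num_patches + 1,
    }


summary = vit_input_summary(batch_size=1)
print(summary)
```

Because the image is square, swapping height and width would go unnoticed here, which is why the axis order is worth stating explicitly.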

Dataset

food101 contains images from many food categories. Dataset page: https://huggingface.co/datasets/ethz/food101

from datasets import load_dataset

# Load every split of Food-101
ds = load_dataset("food101")
print("Dataset", ds)

# Load only the training split
ds = load_dataset("food101", split="train")
print("Training split", ds)
print("First example", ds[0])

# Collect all label names
labels = ds.features["label"].names
print(labels)
print(len(labels))

Output:

Dataset DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 75750
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 25250
    })
})
Training split Dataset({
    features: ['image', 'label'],
    num_rows: 75750
})
First example {'image': <PIL.Image.Image image mode=RGB size=384x512 at 0x7A1FCE415750>, 'label': 6}
['apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare', 'beet_salad', 'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito', 'bruschetta', 'caesar_salad', 'cannoli', 'caprese_salad', 'carrot_cake', 'ceviche', 'cheesecake', 'cheese_plate', 'chicken_curry', 'chicken_quesadilla', 'chicken_wings', 'chocolate_cake', 'chocolate_mousse', 'churros', 'clam_chowder', 'club_sandwich', 'crab_cakes', 'creme_brulee', 'croque_madame', 'cup_cakes', 'deviled_eggs', 'donuts', 'dumplings', 'edamame', 'eggs_benedict', 'escargots', 'falafel', 'filet_mignon', 'fish_and_chips', 'foie_gras', 'french_fries', 'french_onion_soup', 'french_toast', 'fried_calamari', 'fried_rice', 'frozen_yogurt', 'garlic_bread', 'gnocchi', 'greek_salad', 'grilled_cheese_sandwich', 'grilled_salmon', 'guacamole', 'gyoza', 'hamburger', 'hot_and_sour_soup', 'hot_dog', 'huevos_rancheros', 'hummus', 'ice_cream', 'lasagna', 'lobster_bisque', 'lobster_roll_sandwich', 'macaroni_and_cheese', 'macarons', 'miso_soup', 'mussels', 'nachos', 'omelette', 'onion_rings', 'oysters', 'pad_thai', 'paella', 'pancakes', 'panna_cotta', 'peking_duck', 'pho', 'pizza', 'pork_chop', 'poutine', 'prime_rib', 'pulled_pork_sandwich', 'ramen', 'ravioli', 'red_velvet_cake', 'risotto', 'samosa', 'sashimi', 'scallops', 'seaweed_salad', 'shrimp_and_grits', 'spaghetti_bolognese', 'spaghetti_carbonara', 'spring_rolls', 'steak', 'strawberry_shortcake', 'sushi', 'tacos', 'takoyaki', 'tiramisu', 'tuna_tartare', 'waffles']
101

There are 101 food categories in total.
Display the second image and its label:

import matplotlib.pyplot as plt

plt.imshow(ds[1]["image"])
plt.axis("off")  # hide the axes
plt.show()
print(labels[ds[1]["label"]])

The image is displayed, and the printed label shows that the second example has label=6, i.e. beignets.

Next, build the dictionaries that map label names to ids and ids back to label names.

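# ds was rebound to the training split earlier, so reload the full
# DatasetDict before indexing it by split name.
ds = load_dataset("food101")
labels = ds["train"].features["label"].names
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = i
    id2label[i] = label
print(id2label)
print(label2id)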

Output:

{0: 'apple_pie', 1: 'baby_back_ribs', 2: 'baklava', 3: 'beef_carpaccio', 4: 'beef_tartare', 5: 'beet_salad', 6: 'beignets', 7: 'bibimbap', 8: 'bread_pudding', 9: 'breakfast_burrito', 10: 'bruschetta', 11: 'caesar_salad', 12: 'cannoli', 13: 'caprese_salad', 14: 'carrot_cake', 15: 'ceviche', 16: 'cheesecake', 17: 'cheese_plate', 18: 'chicken_curry', 19: 'chicken_quesadilla', 20: 'chicken_wings', 21: 'chocolate_cake', 22: 'chocolate_mousse', 23: 'churros', 24: 'clam_chowder', 25: 'club_sandwich', 26: 'crab_cakes', 27: 'creme_brulee', 28: 'croque_madame', 29: 'cup_cakes', 30: 'deviled_eggs', 31: 'donuts', 32: 'dumplings', 33: 'edamame', 34: 'eggs_benedict', 35: 'escargots', 36: 'falafel', 37: 'filet_mignon', 38: 'fish_and_chips', 39: 'foie_gras', 40: 'french_fries', 41: 'french_onion_soup', 42: 'french_toast', 43: 'fried_calamari', 44: 'fried_rice', 45: 'frozen_yogurt', 46: 'garlic_bread', 47: 'gnocchi', 48: 'greek_salad', 49: 'grilled_cheese_sandwich', 50: 'grilled_salmon', 51: 'guacamole', 52: 'gyoza', 53: 'hamburger', 54: 'hot_and_sour_soup', 55: 'hot_dog', 56: 'huevos_rancheros', 57: 'hummus', 58: 'ice_cream', 59: 'lasagna', 60: 'lobster_bisque', 61: 'lobster_roll_sandwich', 62: 'macaroni_and_cheese', 63: 'macarons', 64: 'miso_soup', 65: 'mussels', 66: 'nachos', 67: 'omelette', 68: 'onion_rings', 69: 'oysters', 70: 'pad_thai', 71: 'paella', 72: 'pancakes', 73: 'panna_cotta', 74: 'peking_duck', 75: 'pho', 76: 'pizza', 77: 'pork_chop', 78: 'poutine', 79: 'prime_rib', 80: 'pulled_pork_sandwich', 81: 'ramen', 82: 'ravioli', 83: 'red_velvet_cake', 84: 'risotto', 85: 'samosa', 86: 'sashimi', 87: 'scallops', 88: 'seaweed_salad', 89: 'shrimp_and_grits', 90: 'spaghetti_bolognese', 91: 'spaghetti_carbonara', 92: 'spring_rolls', 93: 'steak', 94: 'strawberry_shortcake', 95: 'sushi', 96: 'tacos', 97: 'takoyaki', 98: 'tiramisu', 99: 'tuna_tartare', 100: 'waffles'}
{'apple_pie': 0, 'baby_back_ribs': 1, 'baklava': 2, 'beef_carpaccio': 3, 'beef_tartare': 4, 'beet_salad': 5, 'beignets': 6, 'bibimbap': 7, 'bread_pudding': 8, 'breakfast_burrito': 9, 'bruschetta': 10, 'caesar_salad': 11, 'cannoli': 12, 'caprese_salad': 13, 'carrot_cake': 14, 'ceviche': 15, 'cheesecake': 16, 'cheese_plate': 17, 'chicken_curry': 18, 'chicken_quesadilla': 19, 'chicken_wings': 20, 'chocolate_cake': 21, 'chocolate_mousse': 22, 'churros': 23, 'clam_chowder': 24, 'club_sandwich': 25, 'crab_cakes': 26, 'creme_brulee': 27, 'croque_madame': 28, 'cup_cakes': 29, 'deviled_eggs': 30, 'donuts': 31, 'dumplings': 32, 'edamame': 33, 'eggs_benedict': 34, 'escargots': 35, 'falafel': 36, 'filet_mignon': 37, 'fish_and_chips': 38, 'foie_gras': 39, 'french_fries': 40, 'french_onion_soup': 41, 'french_toast': 42, 'fried_calamari': 43, 'fried_rice': 44, 'frozen_yogurt': 45, 'garlic_bread': 46, 'gnocchi': 47, 'greek_salad': 48, 'grilled_cheese_sandwich': 49, 'grilled_salmon': 50, 'guacamole': 51, 'gyoza': 52, 'hamburger': 53, 'hot_and_sour_soup': 54, 'hot_dog': 55, 'huevos_rancheros': 56, 'hummus': 57, 'ice_cream': 58, 'lasagna': 59, 'lobster_bisque': 60, 'lobster_roll_sandwich': 61, 'macaroni_and_cheese': 62, 'macarons': 63, 'miso_soup': 64, 'mussels': 65, 'nachos': 66, 'omelette': 67, 'onion_rings': 68, 'oysters': 69, 'pad_thai': 70, 'paella': 71, 'pancakes': 72, 'panna_cotta': 73, 'peking_duck': 74, 'pho': 75, 'pizza': 76, 'pork_chop': 77, 'poutine': 78, 'prime_rib': 79, 'pulled_pork_sandwich': 80, 'ramen': 81, 'ravioli': 82, 'red_velvet_cake': 83, 'risotto': 84, 'samosa': 85, 'sashimi': 86, 'scallops': 87, 'seaweed_salad': 88, 'shrimp_and_grits': 89, 'spaghetti_bolognese': 90, 'spaghetti_carbonara': 91, 'spring_rolls': 92, 'steak': 93, 'strawberry_shortcake': 94, 'sushi': 95, 'tacos': 96, 'takoyaki': 97, 'tiramisu': 98, 'tuna_tartare': 99, 'waffles': 100}
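The same two mappings can be written more compactly with dict comprehensions. The sketch below uses a three-item sample list and sample_-prefixed names (both chosen here for illustration) so that it does not overwrite the real label2id and id2label, and it checks that the two dictionaries are exact inverses.

```python
# Hypothetical three-label subset, used only to keep the example short.
sample_labels = ["apple_pie", "baby_back_ribs", "baklava"]

# Name -> integer id, in list order.
sample_label2id = {label: i for i, label in enumerate(sample_labels)}
# Integer id -> name, built by inverting the first mapping.
sample_id2label = {i: label for label, i in sample_label2id.items()}

# Round trip: every name must map to an id that maps back to the same name.
round_trip_ok = all(
    sample_id2label[sample_label2id[name]] == name for name in sample_labels
)
print(sample_label2id, sample_id2label, round_trip_ok)
```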

Load an image processor so that the training and evaluation images are resized correctly and their pixel values are normalized.

from transformers import AutoImageProcessor

image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

You can also use the image processor's settings to build transform functions for data augmentation and pixel scaling.

from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

# Reuse the mean/std the checkpoint was pretrained with
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

# Training: random crop and flip for augmentation
train_transforms = Compose(
    [
        RandomResizedCrop(image_processor.size["height"]),
        RandomHorizontalFlip(),
        ToTensor(),
        normalize,
    ]
)

# Validation: deterministic resize and center crop
val_transforms = Compose(
    [
        Resize(image_processor.size["height"]),
        CenterCrop(image_processor.size["height"]),
        ToTensor(),
        normalize,
    ]
)


def preprocess_train(example_batch):
    example_batch["pixel_values"] = [train_transforms(image.convert("RGB")) for image in example_batch["image"]]
    return example_batch


def preprocess_val(example_batch):
    example_batch["pixel_values"] = [val_transforms(image.convert("RGB")) for image in example_batch["image"]]
    return example_batch
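As a rough illustration of what the last two steps of each pipeline do to a single pixel, the sketch below reproduces the arithmetic in plain Python: ToTensor scales 0..255 to 0..1, and Normalize then computes (x - mean) / std per channel. The function name is made up, and mean=0.5, std=0.5 is an assumption about this checkpoint's processor config; check image_processor.image_mean and image_processor.image_std for the actual values.

```python
def to_tensor_then_normalize(pixel, mean=0.5, std=0.5):
    """Illustrative per-pixel arithmetic; mean/std values are assumed."""
    scaled = pixel / 255.0          # what ToTensor does to an 8-bit value
    return (scaled - mean) / std    # what Normalize does per channel


darkest = to_tensor_then_normalize(0)
brightest = to_tensor_then_normalize(255)
print(darkest, brightest)
```

Under these assumed statistics, 8-bit pixel values end up in the range [-1, 1].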

Define the training and validation datasets, and use set_transform to apply the transforms on the fly.

train_ds = ds["train"]
val_ds = ds["validation"]
train_ds.set_transform(preprocess_train)
val_ds.set_transform(preprocess_val)

Finally, you need a data collator that builds batches of training and evaluation data and converts the labels to torch.tensor objects.

import torch


def collate_fn(examples):
    # Stack the per-image tensors into one [batch, channels, height, width] tensor
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

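Model

Now load a pretrained model as the base model. This guide uses google/vit-base-patch16-224-in21k, but you can use any image classification model you like. Pass the label2id and id2label dictionaries to the model so that it knows how to map integer labels to class names, and optionally pass ignore_mismatched_sizes=True if you are fine-tuning a checkpoint that has already been fine-tuned.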

from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

model = AutoModelForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,
)

PEFT configuration and model

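Every PEFT method needs a configuration that holds all the parameters specifying how the method should be applied. Once the configuration is set up, pass it to get_peft_model() together with the base model to create a trainable PeftModel.

Call print_trainable_parameters() to compare the number of trainable parameters in the PeftModel with the total number of parameters in the base model.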

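LoRA decomposes the weight update matrix into two smaller matrices. The size of these low-rank matrices is determined by the rank, r. A higher rank means more trainable parameters, but also more learning capacity. You also need to specify target_modules, which determines where the smaller matrices are inserted; this guide targets the query and value matrices of the attention blocks. Other important parameters are lora_alpha (the scaling factor), bias (which bias parameters to train: none, all, or only those of the LoRA layers), and modules_to_save (modules other than the LoRA layers that are trained and saved). All of these parameters, and more, are described in LoraConfig.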
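The decomposition can be sketched with tiny matrices in plain Python. This is a toy illustration under stated assumptions, not PEFT's implementation: the helper names are made up, the matrices are 2x3 with r=1, and it assumes the common convention that the effective weight is W + (lora_alpha / r) * B @ A with B initialized to zero, so that the adapted model starts out identical to the base model.

```python
def matmul(a, b):
    """Plain-Python matrix product of two lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def lora_effective_weight(W, A, B, r, lora_alpha):
    """Toy sketch: frozen W plus the scaled low-rank update B @ A."""
    scale = lora_alpha / r
    delta = matmul(B, A)  # (d_out x r) @ (r x d_in) -> (d_out x d_in)
    return [[w + scale * d for w, d in zip(w_row, d_row)]
            for w_row, d_row in zip(W, delta)]


W = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]   # frozen weight, d_out=2, d_in=3
A = [[0.1, 0.2, 0.3]]                    # trainable, r x d_in with r=1
B_init = [[0.0], [0.0]]                  # trainable, d_out x r, zero at start

W_init = lora_effective_weight(W, A, B_init, r=1, lora_alpha=1)

B_trained = [[1.0], [2.0]]               # pretend values after some training
W_trained = lora_effective_weight(W, A, B_trained, r=1, lora_alpha=1)
print(W_init)
print(W_trained)
```

Only A and B would receive gradients; here that is r * (d_in + d_out) = 5 numbers instead of the 6 in W, and the saving grows quickly for real layer sizes. With r=16 and lora_alpha=16 as in this guide, the scale factor works out to 1.0.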

from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["query", "value"],
    lora_dropout=0.1,
    bias="none",
    modules_to_save=["classifier"],
)
model = get_peft_model(model, config)
model.print_trainable_parameters()

Output: "trainable params: 667,493 || all params: 86,543,818 || trainable%: 0.7712775047664294"
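The trainable-parameter count can be reproduced by hand. The sketch below assumes the standard ViT-Base dimensions (hidden size 768, 12 encoder layers), one LoRA pair on each query and value projection, and a linear classifier with bias over the 101 Food-101 classes; the total parameter count is copied from the printed output rather than derived.

```python
hidden_size = 768        # ViT-Base hidden dimension (assumed standard config)
num_layers = 12          # ViT-Base encoder layers (assumed standard config)
r = 16                   # LoRA rank from the LoraConfig above
targets_per_layer = 2    # "query" and "value"
num_classes = 101        # Food-101 categories

# Each adapted projection adds A (r x hidden) and B (hidden x r).
lora_params = num_layers * targets_per_layer * r * (hidden_size + hidden_size)
# modules_to_save=["classifier"]: weight matrix plus bias vector.
classifier_params = hidden_size * num_classes + num_classes

trainable = lora_params + classifier_params
all_params = 86_543_818  # taken from the printed output
trainable_pct = 100 * trainable / all_params

print(f"LoRA: {lora_params:,}  classifier: {classifier_params:,}")
print(f"trainable: {trainable:,}  ({trainable_pct:.4f}%)")
```

The sum, 589,824 LoRA parameters plus 77,669 classifier parameters, matches the 667,493 reported by print_trainable_parameters().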

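LoRA is often applied only to the query and value matrices, leaving the key matrices untouched, to keep the adapter small. This reduces the number of trainable parameters and the compute spent on them, while still giving the model enough capacity to adapt to the new task.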

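Training

For training, use the Trainer class from Transformers. Trainer contains a PyTorch training loop; when you are ready, call train() to start training. To customize the run, configure the training hyperparameters in the TrainingArguments class. With LoRA-like methods you can afford a larger batch size and learning rate.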

batch_size = 128

args = TrainingArguments(
    # peft_model_id,
    output_dir="/kaggle/working",
    remove_unused_columns=False,
    evaluation_strategy="epoch",  # newer transformers releases rename this to eval_strategy
    save_strategy="epoch",
    learning_rate=5e-3,
    report_to="none",
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=batch_size,
    fp16=True,
    num_train_epochs=5,
    logging_steps=10,
    load_best_model_at_end=True,
    label_names=["labels"],
)

The TrainingArguments parameters used here are:

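  • output_dir: the directory where checkpoints and logs are written during training.
  • remove_unused_columns: whether to drop dataset columns that the model's forward method does not use. It is set to False here so that the image column is still available to the set_transform preprocessing.
  • evaluation_strategy: when to run evaluation; "epoch" evaluates at the end of every epoch.
  • save_strategy: when to save checkpoints; "epoch" saves at the end of every epoch.
  • learning_rate: the learning rate, set to 5e-3 (0.005).
  • report_to: where metrics are reported; "none" disables all reporting integrations.
  • per_device_train_batch_size: the training batch size on each device.
  • gradient_accumulation_steps: the number of batches whose gradients are accumulated before each optimizer step.
  • per_device_eval_batch_size: the evaluation batch size on each device.
  • fp16: whether to use mixed-precision (fp16) training.
  • num_train_epochs: the total number of training epochs.
  • logging_steps: how many steps pass between log entries.
  • load_best_model_at_end: whether to load the best checkpoint when training ends.
  • label_names: the keys in the input dictionary that hold the labels.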
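To see how the batch-size settings interact, the sketch below estimates the effective batch size and the number of optimizer steps per epoch. It assumes two GPUs in a single process (as with Kaggle's T4 x2) and a common rounding scheme: ceil for batches per epoch, floor for accumulation groups. The exact rounding may differ between transformers versions, so treat the result as an estimate.

```python
import math

num_train_examples = 75_750          # Food-101 training split size
per_device_train_batch_size = 128
gradient_accumulation_steps = 4
num_gpus = 2                         # assumption: Kaggle T4 x2

# Examples consumed per optimizer step across all devices.
effective_batch_size = (
    per_device_train_batch_size * num_gpus * gradient_accumulation_steps
)

# Batches the dataloader yields per epoch (last partial batch included).
batches_per_epoch = math.ceil(
    num_train_examples / (per_device_train_batch_size * num_gpus)
)
# Optimizer updates per epoch after gradient accumulation.
optimizer_steps_per_epoch = max(batches_per_epoch // gradient_accumulation_steps, 1)

print(effective_batch_size, batches_per_epoch, optimizer_steps_per_epoch)
```

Under these assumptions one epoch is 74 optimizer steps, which is consistent with the checkpoint-74 directory name used in the prediction code.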

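Together these parameters configure the training run: the training and evaluation batch sizes, the learning rate, the training length, and so on.
Start training: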

trainer = Trainer(
    model,
    args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=image_processor,  # newer transformers releases call this argument processing_class
    data_collator=collate_fn,
)
trainer.train()

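I trained on Kaggle's free T4 x2 GPUs (quota usage counts double), with GPU utilization close to 100%. Kaggle gives 30 hours of free GPU time per week, and a session lasts at most 12 hours, so faster is better: if the session disconnects, the training is lost. To save time I ran only 2 epochs with batch_size=128 for a quick look at the results, which took roughly one hour. You can also use save session to let it run in the background.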
Screenshots (not reproduced here): the training speed, the completion of the first epoch, and the output files.
To keep the result, click the three dots next to the output to download all the files and save the model locally, then under Input click Upload and choose New Model.
Enter a model name, select Private, and click Create Model.
Set the framework to Transformers and click Add new variation.
Give the attachment a name, choose a license, and click Create at the bottom right.
After closing the dialog, click Return to notebook. The model now appears under Input, and you can copy its path with the button on the right.

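Data under Input is persistent and will not be lost, whereas Output data disappears once the page is closed and the session ends, so save it as soon as possible or upload it to Hugging Face.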

Prediction

Switch to a P100 for validation (usage is counted by the minute, so this is cheaper).
Code:

from peft import PeftConfig, PeftModel
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import requests
import torch
from datasets import load_dataset

model_name = "/kaggle/input/image-classifity/transformers/version1/1/checkpoint-74"  # path copied from Input

# Rebuild the label maps so the classifier head has the right size and names
ds = load_dataset("food101")
labels = ds["train"].features["label"].names
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = i
    id2label[i] = label

# Load the base model named in the adapter config, then attach the LoRA weights
config = PeftConfig.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(
    config.base_model_name_or_path,  # google/vit-base-patch16-224-in21k
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,
)
model = PeftModel.from_pretrained(model, model_name)

url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/beignets.jpeg"
image = Image.open(requests.get(url, stream=True).raw)
print(image)

image_processor = AutoImageProcessor.from_pretrained(config.base_model_name_or_path)
encoding = image_processor(image.convert("RGB"), return_tensors="pt")

with torch.no_grad():
    outputs = model(**encoding)
    logits = outputs.logits

predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])

Output: beignets
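The final step, turning logits into a class name, can be illustrated without a model. The sketch below uses made-up logits and a made-up three-class mapping (not the real Food-101 ids) to show that argmax picks the largest logit, and that a softmax, which the prediction code does not compute, would give a confidence for that choice.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a plain list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


example_logits = [0.2, 3.1, -1.0]                                   # hypothetical values
example_id2label = {0: "apple_pie", 1: "beignets", 2: "waffles"}   # hypothetical mapping

probs = softmax(example_logits)
# Same choice as logits.argmax(-1): softmax preserves the ordering.
predicted_idx = max(range(len(probs)), key=probs.__getitem__)
print("Predicted class:", example_id2label[predicted_idx])
print("Confidence:", round(probs[predicted_idx], 3))
```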
