diff --git a/.github/workflows/pages.yml b/.github/workflows/pages.yml new file mode 100644 index 0000000..0a72ccb --- /dev/null +++ b/.github/workflows/pages.yml @@ -0,0 +1,50 @@ +# Sample workflow for building and deploying a Jekyll site to GitHub Pages +name: Deploy Jekyll with GitHub Pages dependencies preinstalled + +on: + # Runs on pushes targeting the default branch + push: + branches: ["main"] + + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +# Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages +permissions: + contents: read + pages: write + id-token: write + +# Allow one concurrent deployment +concurrency: + group: "pages" + cancel-in-progress: true + +jobs: + # Build job + build: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v3 + - name: Setup Pages + uses: actions/configure-pages@v2 + - name: Build with Jekyll + uses: actions/jekyll-build-pages@v1 + with: + source: ./ + destination: ./_site + - name: Upload artifact + uses: actions/upload-pages-artifact@v1 + + # Deployment job + deploy: + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + runs-on: ubuntu-latest + needs: build + steps: + - name: Deploy to GitHub Pages + id: deployment + uses: actions/deploy-pages@v1 \ No newline at end of file diff --git a/LICENSE b/LICENSE deleted file mode 100644 index 261eeb9..0000000 --- a/LICENSE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/README.md b/README.md deleted file mode 100644 index 7797047..0000000 --- a/README.md +++ /dev/null @@ -1,181 +0,0 @@ -# 📱🦙 MobiLlama: Towards Accurate and Lightweight Fully Transparent GPT - - - -

- Oryx MobiLLama -

- -

- Oryx MobiLLama -

- -

- license -

- -#### [Omkar Thawakar](https://scholar.google.com/citations?user=flvl5YQAAAAJ&hl=en), [Ashmal Vayani](https://www.linkedin.com/in/ashmal-vayani/), [Salman Khan](https://salman-h-khan.github.io/), [Hisham Cholakkal](https://scholar.google.com/citations?hl=en&user=bZ3YBRcAAAAJ), [Rao Muhammad Anwer](https://scholar.google.com/citations?hl=en&authuser=1&user=_KlvMVoAAAAJ), [Michael Felsberg](https://scholar.google.com/citations?user=lkWfR08AAAAJ&hl=en), [Timothy Baldwin](https://scholar.google.com/citations?user=wjBD1dkAAAAJ&hl=en), [Eric Xing](https://scholar.google.com/citations?user=5pKTRxEAAAAJ&hl=en) and [Fahad Khan](https://sites.google.com/view/fahadkhans/home) - -#### **Mohamed Bin Zayed University of Artificial Intelligence (MBZUAI), UAE** - -[![paper](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/) -🤗 [![HuggingFace](https://img.shields.io/badge/HuggingFace-Page-F9D371)](https://huggingface.co/collections/MBZUAI/mobillama-65dd4182d588c91e8230332e) -[![Demo](https://img.shields.io/badge/Gradio-Demo-red)](https://845b645234785da51b.gradio.live/) - ---- - -## 📢 Latest Updates -- **Feb-26-24**- Arxiv Preprint is released! -- **Feb-25-24**- Code (Training and Evaluation scripts) is released! -- **Feb-25-24**- Final pre-trained models (including intermediate checkpoints) and chat version along with online demo links released! - - -## Overview - -`Bigger the better` has been the predominant trend in recent Large Language Models (LLMs) development. -However, LLMs do not suit well for scenarios that require on-device processing, energy efficiency, low memory footprint, and response efficiency. These requisites are crucial for privacy, security, and sustainable deployment. -This paper explores the `less is more` paradigm by addressing the challenge of designing accurate yet efficient Small Language Models (SLMs) for resource constrained devices. -Our primary contribution is the introduction of an accurate and fully transparent open-source 0.5 billion (0.5B) parameter SLM, named `MobiLlama`, catering to the specific needs of resource-constrained computing with an emphasis on enhanced performance with reduced resource demands. -`MobiLlama` is a SLM design that initiates from a larger model and applies a careful parameter sharing scheme to reduce both the pre-training and the deployment cost. - -## ⚡ Model Download - -| Model Name | Link Download | -|-----------------------------------------------------|----------------------------------------------------------------------| -| MobiLlama-05B | [HuggingFace](https://huggingface.co/MBZUAI/MobiLlama-05B) | -| MobiLlama-08B | [HuggingFace](https://huggingface.co/MBZUAI/MobiLlama-08B) | -| MobiLlama-1B | [HuggingFace](https://huggingface.co/MBZUAI/MobiLlama-1B) | -| MobiLlama-05B-Chat | [HuggingFace](https://huggingface.co/MBZUAI/MobiLlama-05B-Chat) | -| MobiLlama-1B-Chat | [HuggingFace](https://huggingface.co/MBZUAI/MobiLlama-1B-Chat) | - - -## Generation with MobiLlama - -

- -

- -## Model Description - -- **Model type:** Language model designed using the architecture of LLaMA-7B -- **Language(s) (NLP):** English -- **License:** Apache 2.0 -- **Resources for more information:** - - [Training Code](https://github.com/mbzuai-oryx/MobiLlama) - - [Data Preparation](https://github.com/LLM360/amber-data-prep) - - [Metrics]() - - [Fully processed Amber pretraining data](https://huggingface.co/datasets/LLM360/AmberDatasets) - - -# Loading MobiLlama - -```python -from .model_utils.modeling_mobillama import LlamaTokenizer, LlamaForCausalLM - -tokenizer = LlamaTokenizer.from_pretrained("MBZUAI/MobiLlama-05B") -model = LlamaForCausalLM.from_pretrained("MBZUAI/MobiLlama-05B") - -input_text = "translate English to German: How old are you?" -input_ids = tokenizer(input_text, return_tensors="pt").input_ids - -outputs = model.generate(input_ids) -print(tokenizer.decode(outputs[0])) -``` - -## Dataset - -Download the preprocessed Amber data from [huggingface](https://huggingface.co/datasets/LLM360/AmberDatasets). The entire training data has 360 chunks totalling the size of ~8 TB. Amber dataset contains total 1.2 Trillion tokens with gathered from different data sources shown below. - -| Subset | Tokens (Billion) | -| ----------- | ----------- | -| Arxiv | 30.00 | -| Book | 28.86 | -| C4 | 197.67 | -| Refined-Web | 665.01 | -| StarCoder | 291.92 | -| StackExchange | 21.75 | -| Wikipedia | 23.90 | -| Total | 1259.13 | - -## Installation - -First install [PyTorch](https://pytorch.org) according to the instructions specific to your operating system. - -To install from source (recommended for training/fine-tuning) run: - -```bash -conda create -n mobillama python=3.10 -conda activate mibillama -git clone https://github.com/mbzuai-oryx/MobiLlama.git -cd MobiLlama -pip install -r requirements.txt -``` - -## pretrain -For MobiLlama (using 20 nodes of A100 80GB GPUS) -```bash -sbatch pretrain.sh -``` -For `large-base` use main_largebase.py in L:11 of pretrain.sh - -## 🔎 Evaluation - -We used [Analysis-360](https://github.com/LLM360/Analysis360) to evaluate our model on different llm benchmarks. - - - -## 📊 Results - -| Model Name | #Params | HellaSwag | Truthfulqa | MMLU | Arc_C | CrowsPairs | piqa | race | siqa | winogrande | Average | -|--------------------|---------|-----------|------------|-------|-------|------------|-------|-------|-------|------------|---------| -| gpt-neo-125m | 0.15B | 30.26 | 45.58 | 25.97 | 22.95 | 61.55 | 62.46 | 27.56 | 40.33 | 51.78 | 40.93 | -| tiny-starcoder | 0.17B | 28.17 | 47.68 | 26.79 | 20.99 | 49.68 | 52.55 | 25.45 | 38.28 | 51.22 | 37.86 | -| cerebras-gpt-256m | 0.26B | 28.99 | 45.98 | 26.83 | 22.01 | 60.52 | 61.42 | 27.46 | 40.53 | 52.49 | 40.69 | -| opt-350m | 0.35B | 36.73 | 40.83 | 26.02 | 23.55 | 64.12 | 64.74 | 29.85 | 41.55 | 52.64 | 42.22 | -| megatron-gpt2-345m | 0.38B | 39.18 | 41.51 | 24.32 | 24.23 | 64.82 | 66.87 | 31.19 | 40.28 | 52.96 | 42.81 | -| LiteLlama | 0.46B | 38.47 | 41.59 | 26.17 | 24.91 | 62.90 | 67.73 | 28.42 | 40.27 | 49.88 | 42.26 | -| gpt-sw3-356m | 0.47B | 37.05 | 42.55 | 25.93 | 23.63 | 61.59 | 64.85 | 32.15 | 41.56 | 53.04 | 42.48 | -| pythia-410m | 0.51B | 40.85 | 41.22 | 27.25 | 26.19 | 64.20 | 67.19 | 30.71 | 41.40 | 53.12 | 43.57 | -| xglm-564m | 0.56B | 34.64 | 40.43 | 25.18 | 24.57 | 62.25 | 64.85 | 29.28 | 42.68 | 53.03 | 41.87 | -| Lamini-GPT-LM | 0.59B | 31.55 | 40.72 | 25.53 | 24.23 | 63.09 | 63.87 | 29.95 | 40.78 | 47.75 | 40.83 | -| **MobiLlama (Ours)** | **0.5B** | **52.52** | **38.05** | **26.45**| **29.52**| **64.03** | **72.03**| **33.68**| **40.22**| **57.53** | **46.00** | -| Lamini-GPT-LM | 0.77B | 43.83 | 40.25 | 26.24 | 27.55 | 66.12 | 69.31 | 37.12 | 42.47 | 56.59 | 45.49 | -| **MobiLlama (Ours)** | **0.8B** | **54.09** | **38.48** | **26.92** | **30.20** | **64.82** | **73.17** | **33.37** | **41.60** | **57.45** | **46.67** | - -`The table provides a comparative analysis of various models, including our MobiLlama, across several LLM benchmarks. It highlights MobiLlama's superior performance, particularly in its 0.5B and 0.8B configurations, showcasing its efficiency and effectiveness in processing complex language tasks. This comparison underscores MobiLlama's advancements in achieving higher accuracy and demonstrates its potential as a leading solution in the field of LLM.` - ---- - -| Model | #Params | HellaSwag | Truthfulqa | MMLU | Arc_C | CrowsPairs | piqa | race | siqa | winogrande | Average | -|---------------|---------|-----------|------------|------|-------|------------|------|------|------|------------|---------| -| Boomer | 1B | 31.62 | 39.42 | 25.42| 22.26 | 61.26 | 57.99| 28.99| 40.32| 50.98 | 39.80 | -| Pythia-Dedup | 1B | 49.63 | 38.92 | 24.29| 29.09 | 67.11 | 70.23| 32.44| 42.63| 53.98 | 45.36 | -| Falcon-RW | 1B | 63.12 | 35.96 | 25.36| 35.06 | 69.04 | 74.10| 36.07| 40.23| 61.88 | 48.98 | -| TinyLlama | 1.1B | 60.22 | 37.59 | 26.11| 33.61 | 70.60 | 73.28| 36.45| 41.65| 59.18 | 48.74 | -| OLMo | 1.2B | 62.50 | 32.94 | 25.86| 34.45 | 69.59 | 73.70| 36.74| 41.14| 58.90 | 48.42 | -| Cerebras-GPT | 1.3B | 38.51 | 42.70 | 26.66| 26.10 | 63.67 | 66.75| 30.33| 42.42| 53.59 | 43.41 | -| Lamini | 1.3B | 38.05 | 36.43 | 28.47| 26.62 | 64.62 | 67.89| 33.39| 43.19| 50.59 | 43.25 | -| OPT | 1.3B | 54.50 | 38.67 | 24.63| 29.60 | 70.70 | 72.47| 34.16| 42.47| 59.74 | 47.43 | -| GPT-NEO | 1.3B | 48.49 | 39.61 | 24.82| 31.31 | 65.67 | 71.05| 34.06| 41.81| 57.06 | 45.98 | -| Pythia-Deduped| 1.4B | 55.00 | 38.63 | 25.45| 32.59 | 67.33 | 72.68| 34.64| 42.68| 56.90 | 47.32 | -| **large-base**| **1.2B**| **62.99** | **35.90** | **24.79**| **34.55** | **68.49** | **75.57**| **35.31**| **41.96**| **62.03** | **49.06** | - -`Comprehensive comparisons with existing < 2B params fully open-source LLM models on 9 benchmarks. Our 1.2B "large-base" model pre-trained on 1.2T tokens achieves superior performance compared to both the recent OLMo 1.17B model and TinyLlama 1.1B model, which are pre-trained on a substantially larger data of 3T tokens.` - -## 📱 MobiLlama on Android - -To run our model on an android app, please download and install the APK from [here](https://mbzuaiac-my.sharepoint.com/:f:/g/personal/omkar_thawakar_mbzuai_ac_ae/EhRfGdmgFVVNvIRfy1EgLwEBjbk_eg3UmNg_zjz7PMTsmg?e=NBuJo8). - -## 🙏 Acknowledgements - -+ We thank [LLM-360](https://github.com/LLM360/amber-train) for fully transparent and open-source implementation of their language model. MobiLlama repo is built using [LLM-360](https://github.com/LLM360/amber-train). - - -## 📜 Citation -```bibtex -@misc{thawakar2024mobillama, - title={MobiLlama: Towards Accurate and Lightweight Fully Transparent GPT}, - author={Omkar Thawakar and Ashmal Vayani and Salman Khan and Hisham Cholakkal and Rao Muhammad Anwer and Michael Felsberg and Timothy Baldwin and Eric P. Xing and Fahad Shahbaz Khan}, - year={2024} -} -``` \ No newline at end of file diff --git a/_config.yml b/_config.yml new file mode 100644 index 0000000..3f4749e --- /dev/null +++ b/_config.yml @@ -0,0 +1,3 @@ +theme: jekyll-theme-cayman +title: MobiLlama +description: An end-to-end Open Source Model that is designed to cater accurate yet efficient language model for resource-constrained devices.
MobiLlama mitigates the redundancy in transformer blocks by proposing a shared FFN design for all the transformer blocks within the Small Language Model. diff --git a/_layouts/default.html b/_layouts/default.html new file mode 100644 index 0000000..00772af --- /dev/null +++ b/_layouts/default.html @@ -0,0 +1,60 @@ + + + + + +{% seo %} + + + + + + + + {% include head-custom.html %} + + + Skip to the content. + + + +
+ {{ content }} + + +
+ + diff --git a/_sass/variables.scss b/_sass/variables.scss new file mode 100644 index 0000000..e06ee3f --- /dev/null +++ b/_sass/variables.scss @@ -0,0 +1,6 @@ +--- +--- + +$header-bg-color-secondary: #159957; +$header-bg-color: #155799; +@import "{{ site.theme }}"; diff --git a/assets/css/style.scss b/assets/css/style.scss new file mode 100644 index 0000000..2de9c11 --- /dev/null +++ b/assets/css/style.scss @@ -0,0 +1,34 @@ +--- +--- + +@import "{{ site.theme }}"; +ul.sticky { + list-style-type: none; + margin: 0; + padding: 0; + overflow: hidden; + background-color: #333; +} + +ul.sticky li { + float: left; +} + +ul.sticky li a { + display: block; + color: white; + text-align: center; + padding: 14px 16px; + text-decoration: none; +} + +ul.sticky li a:hover:not(.active) { + background-color: #111; +} + +.active { + background-color: #04AA6D; } + +table, tr, td{ + border:none; +} diff --git a/docs/Hardware_Comparision.png b/docs/Hardware_Comparision.png new file mode 100644 index 0000000..b9d295d Binary files /dev/null and b/docs/Hardware_Comparision.png differ diff --git a/docs/MobiLLaMA_video.mov b/docs/MobiLLaMA_video.mov new file mode 100644 index 0000000..e7ede5a Binary files /dev/null and b/docs/MobiLLaMA_video.mov differ diff --git a/images/MobileLLaMa.png b/docs/MobileLLaMa.png similarity index 100% rename from images/MobileLLaMa.png rename to docs/MobileLLaMa.png diff --git a/docs/Mobillama_Examples.png b/docs/Mobillama_Examples.png new file mode 100644 index 0000000..46a68d0 Binary files /dev/null and b/docs/Mobillama_Examples.png differ diff --git a/docs/Model_Comparisions.png b/docs/Model_Comparisions.png new file mode 100644 index 0000000..fa525e5 Binary files /dev/null and b/docs/Model_Comparisions.png differ diff --git a/docs/Teaser_Video.mp4 b/docs/Teaser_Video.mp4 new file mode 100644 index 0000000..d3f5a12 --- /dev/null +++ b/docs/Teaser_Video.mp4 @@ -0,0 +1 @@ + diff --git a/docs/VLM_Example.png b/docs/VLM_Example.png new file mode 100644 index 0000000..0f5f2c7 Binary files /dev/null and b/docs/VLM_Example.png differ diff --git a/images/littlellama_logo.png b/docs/littlellama_logo.png similarity index 100% rename from images/littlellama_logo.png rename to docs/littlellama_logo.png diff --git a/images/mobillama_generation.gif b/docs/mobillama_generation.gif similarity index 100% rename from images/mobillama_generation.gif rename to docs/mobillama_generation.gif diff --git a/docs/radar_plot.png b/docs/radar_plot.png new file mode 100644 index 0000000..20e26aa Binary files /dev/null and b/docs/radar_plot.png differ diff --git a/eval_utils/eval_hf_main.py b/eval_utils/eval_hf_main.py deleted file mode 100644 index 6f630ac..0000000 --- a/eval_utils/eval_hf_main.py +++ /dev/null @@ -1,76 +0,0 @@ -import time -import fire -import os -import glob - -CONFIGS = { - 'arc': { - 'tasks': 'arc_challenge', - 'n_shots': 25, - 'metric_name': 'acc_norm' - }, - 'hellaswag': { - 'tasks': 'hellaswag', - 'n_shots': 10, - 'metric_name': 'acc_norm' - }, - 'truthfulqa': { - 'tasks': 'truthfulqa_mc', - 'n_shots': 0, - 'metric_name': 'mc2' - }, - 'mmlu': { - 'tasks': 'hendrycksTest-abstract_algebra,hendrycksTest-anatomy,hendrycksTest-astronomy,hendrycksTest-business_ethics,hendrycksTest-clinical_knowledge,hendrycksTest-college_biology,hendrycksTest-college_chemistry,hendrycksTest-college_computer_science,hendrycksTest-college_mathematics,hendrycksTest-college_medicine,hendrycksTest-college_physics,hendrycksTest-computer_security,hendrycksTest-conceptual_physics,hendrycksTest-econometrics,hendrycksTest-electrical_engineering,hendrycksTest-elementary_mathematics,hendrycksTest-formal_logic,hendrycksTest-global_facts,hendrycksTest-high_school_biology,hendrycksTest-high_school_chemistry,hendrycksTest-high_school_computer_science,hendrycksTest-high_school_european_history,hendrycksTest-high_school_geography,hendrycksTest-high_school_government_and_politics,hendrycksTest-high_school_macroeconomics,hendrycksTest-high_school_mathematics,hendrycksTest-high_school_microeconomics,hendrycksTest-high_school_physics,hendrycksTest-high_school_psychology,hendrycksTest-high_school_statistics,hendrycksTest-high_school_us_history,hendrycksTest-high_school_world_history,hendrycksTest-human_aging,hendrycksTest-human_sexuality,hendrycksTest-international_law,hendrycksTest-jurisprudence,hendrycksTest-logical_fallacies,hendrycksTest-machine_learning,hendrycksTest-management,hendrycksTest-marketing,hendrycksTest-medical_genetics,hendrycksTest-miscellaneous,hendrycksTest-moral_disputes,hendrycksTest-moral_scenarios,hendrycksTest-nutrition,hendrycksTest-philosophy,hendrycksTest-prehistory,hendrycksTest-professional_accounting,hendrycksTest-professional_law,hendrycksTest-professional_medicine,hendrycksTest-professional_psychology,hendrycksTest-public_relations,hendrycksTest-security_studies,hendrycksTest-sociology,hendrycksTest-us_foreign_policy,hendrycksTest-virology,hendrycksTest-world_religions', - 'n_shots': 5, - 'metric_name': 'acc' - }, -} -BATCH_SIZE = 32 - - -def evaluate(config_name, model_dir, output_path): - config = CONFIGS[config_name] - tasks, n_shots, metric_name = \ - config['tasks'], config['n_shots'], config['metric_name'] - - batch_size = BATCH_SIZE - - while True: - command = f'python eval_utils/harness.py '\ - f'--model=hf-causal '\ - f'--model_args=\"pretrained={model_dir}\" '\ - f'--tasks={tasks} '\ - f'--num_fewshot={n_shots} '\ - f'--batch_size={batch_size} '\ - f'--output_path={output_path} '\ - f'--no_cache' - - if os.system(command) == 0: - break - else: - print(f'COMMAND \"{command}\" failed. rerunning...') - if batch_size > 1: - batch_size = batch_size // 2 - - -def main(workdir='workdir_7b'): - while True: - ckpt_dirs = glob.glob(f'{workdir}/ckpt_*') - ckpt_dirs.sort(key=lambda s: int(s[len(f'{workdir}/ckpt_'):])) - - for model_dir in ckpt_dirs: - for config_name in CONFIGS.keys(): - output_path = f'{model_dir}/eval_{config_name}.json' - if not os.path.exists(output_path): - print(f'evaluating {config_name} for {model_dir}...') - print('running...', file=open(output_path, 'w'), flush=True) - - time.sleep(60) - evaluate( - config_name=config_name, - model_dir=model_dir, - output_path=output_path) - - -if __name__ == '__main__': - fire.Fire(main) \ No newline at end of file diff --git a/eval_utils/harness.py b/eval_utils/harness.py deleted file mode 100644 index de9d11e..0000000 --- a/eval_utils/harness.py +++ /dev/null @@ -1,93 +0,0 @@ -import argparse -import json -import logging -import os - -from lm_eval import tasks, evaluator, utils - -logging.getLogger("openai").setLevel(logging.WARNING) - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--model", required=True) - parser.add_argument("--model_args", default="") - parser.add_argument("--tasks", default=None, choices=utils.MultiChoice(tasks.ALL_TASKS)) - parser.add_argument("--provide_description", action="store_true") - parser.add_argument("--num_fewshot", type=int, default=0) - parser.add_argument("--batch_size", type=str, default=None) - parser.add_argument("--max_batch_size", type=int, default=None, - help="Maximal batch size to try with --batch_size auto") - parser.add_argument("--device", type=str, default=None) - parser.add_argument("--output_path", default=None) - parser.add_argument("--limit", type=float, default=None, - help="Limit the number of examples per task. " - "If <1, limit is a percentage of the total number of examples.") - parser.add_argument("--data_sampling", type=float, default=None) - parser.add_argument("--no_cache", action="store_true") - parser.add_argument("--decontamination_ngrams_path", default=None) - parser.add_argument("--description_dict_path", default=None) - parser.add_argument("--check_integrity", action="store_true") - parser.add_argument("--write_out", action="store_true", default=False) - parser.add_argument("--output_base_path", type=str, default=None) - - return parser.parse_args() - - -def main(): - args = parse_args() - - assert not args.provide_description # not implemented - - if args.limit: - print( - "WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT." - ) - - if args.tasks is None: - task_names = tasks.ALL_TASKS - else: - task_names = utils.pattern_match(args.tasks.split(","), tasks.ALL_TASKS) - - print(f"Selected Tasks: {task_names}") - - description_dict = {} - if args.description_dict_path: - with open(args.description_dict_path, "r") as f: - description_dict = json.load(f) - - results = evaluator.simple_evaluate( - model=args.model, - model_args=args.model_args, - tasks=task_names, - num_fewshot=args.num_fewshot, - batch_size=args.batch_size, - max_batch_size=args.max_batch_size, - device=args.device, - no_cache=args.no_cache, - limit=args.limit, - description_dict=description_dict, - decontamination_ngrams_path=args.decontamination_ngrams_path, - check_integrity=args.check_integrity, - write_out=args.write_out, - output_base_path=args.output_base_path, - ) - - dumped = json.dumps(results, indent=2) - print(dumped) - - if args.output_path: - os.makedirs(os.path.dirname(args.output_path), exist_ok=True) - with open(args.output_path, "w") as f: - f.write(dumped) - - batch_sizes = ",".join(map(str, results["config"]["batch_sizes"])) - print( - f"{args.model} ({args.model_args}), limit: {args.limit}, provide_description: {args.provide_description}, " - f"num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}" - ) - print(evaluator.make_table(results)) - - -if __name__ == "__main__": - main() diff --git a/index.md b/index.md new file mode 100644 index 0000000..2c7b6b1 --- /dev/null +++ b/index.md @@ -0,0 +1,62 @@ +
+ +
+ +
+ + +
+ +## Abstract +

+Bigger the better has been the predominant trend in recent Large Language Models (LLMs) development. However, LLMs do not suit well for scenarios that require on-device processing, energy efficiency, low memory footprint, and response efficiency. These requisites are crucial for privacy, security, and sustainable deployment. This paper explores the less is more paradigm by addressing the challenge of designing accurate yet efficient Small Language Models (SLMs) for resource constrained devices. Our primary contribution is the introduction of an accurate and fully transparent open-source 0.5 billion (0.5B) parameter SLM, named MobiLlama, catering to the specific needs of resource-constrained computing with an emphasis on enhanced performance with reduced resource demands. MobiLlama is a SLM design that initiates from a larger model and applies a careful parameter sharing scheme to reduce both the pre-training and the deployment cost.

+ + +## MobiLlama Architecture +![main figure](docs/mobillama_generation.gif) +

The proposed approach, MobiLlama, constructs a SLM of desired sizes (e.g., 0.5B model) by first initiating from a larger model size design, largebase. Then, we employ a careful parameter sharing scheme to reduce the model size to a pre-defined model configuration, thereby significantly reducing the training cost. Generally, both SLMs and LLMs typically utilize a dedicated multilayer perceptron (MLP) block comprising multiple feed forward network (FFN) layers within each transformer block. In such a configuration (e.g., large-base), the FFN layers account for a substantial 65% of the total trainable parameters, with attention mechanisms and heads contributing 30% and 5%, respectively. As a consequence, a significant number of parameters are concentrated within the FFN layers, thereby posing challenges during pre-training with respect to computational cost and the model’s ability to achieve faster convergence. To address these issues, we propose to use a sharing scheme where the FFN parameters are shared across all transformer layers within the SLM. This enables us to significantly reduce the overall trainable parameters by 60% in our MobiLlama, compared to the large-base. Such a significant parameter reduction also enables us to increase the model capacity in terms of number of layers and hidden dimension size without any substantial increase in the training cost.

+ + +## MobiLlama in comparison with existing <1B Models +![main figure](docs/Model_Comparisions.png) +

State-of-the-art comparisons with existing < 1B params models on nine benchmarks. In case of around 0.5B model series, our MobiLlama achieves a substantial gain of 2.4% in terms of average performance on nine benchmarks. Further, our MobiLlama 0.8B model achieves an average score of 46.67.

+ +
+

+

Comparison of our MobiLlama 0.5B and 0.8B models with recent OLMo-1.17B and TinyLlama-1.1B in terms of pre-training tokens, pre-training time and memory, model parameters, overall accuracy across nine benchmarks and on-device efficiency (average battery consumption and average token/second on a PC with RTX2080Ti). Our \emph{MobiLlama} achieves comparable accuracy while requiring significantly fewer pre-training data (1.2T tokens vs. 3T tokens), lesser pre-training time and GPU memory along with being efficient in terms of deployment on a resource constrained device.

+ + +## Qualitative Examples of MobiLlama +![main figure](docs/Mobillama_Examples.png) +

Example responses from our MobiLlama across a variety of tasks, including creative storytelling, coding exercises, economic analysis, and cooking instructions. The responses highlight the models’ ability to engage with both abstract concepts and practical, step-by-step processes, demonstrating its broad applicability.

+ + +## Qualitative Examples of MobiLlama-VLM +![main figure](docs/VLM_Example.png) +Example responses of MobiLlama-V in responding to visual stimuli across a range of scenarios. + + +## Hardware Platform Comparison +![main figure](docs/Hardware_Comparision.png) +

Comparison in terms of efficiency and resource consumption on different low-end hardware devices. We show the comparison on: a PC with RTX-2080Ti GPU, a laptop with i7 CPU and a smartphone with Snapdragon-685 processor. In addition to our large-base model, we also present the comparison with Llama2 7B and Phi2 2.7B. In case of CPU and smartphone, we use 4-bit GGUF format of the corresponding models, whereas the original models are deployed and tested on PC with RTX-2080Ti GPU. The different metrics measure the model’s operational efficiency, model’s footprint in the device’s RAM and the energy efficiency of processing 1,000 tokens. Our MobiLlama performs favorably in terms of efficiency on these low-end hardware devices. We note that both Phi2 and Llama2 are not fully transparent in that the complete data pipeline for pre-training is not publicly available.

+ + +## BibTeX +If you like our work, please consider citing us. +``` +@misc{thawakar2024mobillama, + title={MobiLlama: Towards Accurate and Lightweight Fully Transparent GPT}, + author={Omkar Thawakar and Ashmal Vayani and Salman Khan and Hisham Cholakkal and Rao Muhammad Anwer and Michael Felsberg and Timothy Baldwin and Eric P. Xing and Fahad Shahbaz Khan}, + year={2024}, + eprint={2402.16840}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} +``` diff --git a/main_largebase.py b/main_largebase.py deleted file mode 100644 index f147656..0000000 --- a/main_largebase.py +++ /dev/null @@ -1,206 +0,0 @@ -from datetime import datetime -from pytz import timezone -import time -from functools import partial -import wandb -import os -import fire -import tqdm -import torch -torch.cuda.empty_cache() -from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy -import lightning as L -from lightning.fabric.strategies import FSDPStrategy, XLAStrategy -from transformers import AutoConfig, AutoTokenizer -from transformers import AutoTokenizer, AutoModelForCausalLM - -from model_utils.modeling_llama import LlamaForCausalLM, LlamaDecoderLayer, LlamaAttention - -from main_utils import ( - load_jsonl_examples, - get_cosine_lr_decay_fn, - get_grad_norm, - save_checkpoint, - get_last_ckpt_idx) - - -TIMEZONE = timezone('EST') -DATE = str(datetime.now(tz=TIMEZONE)).split()[0] -MODEL_SIZE = '1b' -PROJECT_NAME = f'llama_{MODEL_SIZE}' -RUN_NAME = f'pretraining_mobillama_{MODEL_SIZE}_{DATE}' -HF_MODEL_NAME_OR_PATH = f'mobillama' -WORKDIR = f'mobillama_{MODEL_SIZE}' - -LEARNING_RATE = 3e-4 -LR_SCHEDULE_TYPE = 'cosine' -END_LEARNING_RATE = 3e-5 -WARMUP_GRAD_STEPS = 2000 -GRAD_NORM_CLIP = 1. -WEIGHT_DECAY = 0.1 -BETA1 = 0.9 -BETA2 = 0.95 -ACCELERATOR = 'cuda' -PRECISION = 'bf16-mixed' -RANDOM_SEED = 11111 - -TRAIN_DATA_DIR = './Amber_Dataset/train' -TRAIN_EXAMPLES_PER_CHUNK = 1706976 -N_CHUNKS = 360 - - -def collate_fn(examples, device): - token_ids = torch.tensor( - [example['token_ids'] for example in examples], device=device) - return {'input_ids': token_ids[:, :-1], 'labels': token_ids[:, 1:]} - - -def train_chunk(fabric, - tokenizer, - model, - optimizer, - lr_schedule_fn, - examples, - per_device_batch_size, - accumulate_grad_batches, - chunk_idx, - run_wandb): - step = chunk_idx * (len(examples) // per_device_batch_size) - - example_batch_idxes = tqdm.trange( - 0, len(examples), per_device_batch_size, - desc=f'Training chunk {chunk_idx} (global_micro_batch_size=' - f'{per_device_batch_size * fabric.world_size}, ' - f'accumulate_grad_batches={accumulate_grad_batches})') - for i in example_batch_idxes: - t0 = time.time() - - lr = lr_schedule_fn(step) - step += 1 - for param_group in optimizer.param_groups: - param_group["lr"] = lr - is_accumulating = (step % accumulate_grad_batches != 0) - - batch = collate_fn( - examples=examples[i:i+per_device_batch_size], device=fabric.device) - input_ids, labels = batch['input_ids'], batch['labels'] - with fabric.no_backward_sync(model, enabled=is_accumulating): - logits = model(input_ids).logits - loss = torch.nn.functional.cross_entropy( - logits.reshape((-1, logits.size(-1))), labels.reshape(-1)) - - fabric.backward(loss / accumulate_grad_batches) - - if not is_accumulating: - grad_norm = get_grad_norm(model=model) - fabric.clip_gradients(model, optimizer, max_norm=GRAD_NORM_CLIP) - optimizer.step() - optimizer.zero_grad() - - log = { - 'loss': loss.item(), - 'learning_rate': lr, - 'step': step, - 'speed(#tok/s/gpu)': int(input_ids.numel() / (time.time() - t0)) - } - if not is_accumulating: - log['grad_norm'] = grad_norm - - example_batch_idxes.set_postfix(log) - if run_wandb and fabric.global_rank == 0: - wandb.log(log) - - save_checkpoint( - fabric=fabric, - tokenizer=tokenizer, - model=model, - optimizer=optimizer, - save_dir=f'{WORKDIR}/ckpt_{chunk_idx}') - - -def main(n_nodes=1, - n_devices_per_node=8, - per_device_batch_size=8, - accumulate_grad_batches=1, - run_wandb=False): - fabric = L.Fabric( - accelerator=ACCELERATOR, - num_nodes=n_nodes, - devices=n_devices_per_node, - precision=PRECISION, - strategy=FSDPStrategy( - auto_wrap_policy=partial( - transformer_auto_wrap_policy, - transformer_layer_cls={LlamaDecoderLayer}), - activation_checkpointing_policy={LlamaDecoderLayer}, - cpu_offload=True, - limit_all_gathers=True)) - # strategy = XLAStrategy(sync_module_states=False)) - fabric.launch() - - if fabric.global_rank == 0: - os.makedirs(WORKDIR, exist_ok=True) - if run_wandb: - wandb.init(project=PROJECT_NAME, name=RUN_NAME) - - last_ckpt_idx = get_last_ckpt_idx(workdir=WORKDIR) - fabric.seed_everything(RANDOM_SEED + last_ckpt_idx + 1) - - tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_NAME_OR_PATH) - model = LlamaForCausalLM( - config=AutoConfig.from_pretrained(HF_MODEL_NAME_OR_PATH)) - - - # print(model) - print("="*50) - print("Model params : ", model.num_parameters()) - print("="*50) - - - optimizer = torch.optim.AdamW( - model.parameters(), - lr=LEARNING_RATE, - weight_decay=WEIGHT_DECAY, - betas=(BETA1, BETA2), - foreach=False) - - model, optimizer = fabric.setup(model, optimizer) - if last_ckpt_idx != -1: - fabric.load( - path=f'{WORKDIR}/ckpt_{last_ckpt_idx}/fabric_ckpt', - state={'model': model, 'optimizer': optimizer}) - - torch.cuda.empty_cache() - - global_micro_batch_size = per_device_batch_size * fabric.world_size - total_steps = TRAIN_EXAMPLES_PER_CHUNK // global_micro_batch_size * N_CHUNKS - lr_schedule_fn = get_cosine_lr_decay_fn( - total_steps=total_steps, - warmup_steps=WARMUP_GRAD_STEPS * accumulate_grad_batches, - learning_rate=LEARNING_RATE, - end_learning_rate=END_LEARNING_RATE) - - for chunk_idx in range(last_ckpt_idx + 1, N_CHUNKS): - examples = load_jsonl_examples( - filename=f'{TRAIN_DATA_DIR}/train_{chunk_idx:03d}.jsonl', - n_examples=TRAIN_EXAMPLES_PER_CHUNK, - shuffle=True, - global_micro_batch_size=global_micro_batch_size, - global_rank=fabric.global_rank, - world_size=fabric.world_size) - - train_chunk( - fabric=fabric, - tokenizer=tokenizer, - model=model, - optimizer=optimizer, - lr_schedule_fn=lr_schedule_fn, - examples=examples, - per_device_batch_size=per_device_batch_size, - accumulate_grad_batches=accumulate_grad_batches, - chunk_idx=chunk_idx, - run_wandb=run_wandb) - - -if __name__ == '__main__': - fire.Fire(main) \ No newline at end of file diff --git a/main_mobillama.py b/main_mobillama.py deleted file mode 100644 index 84af271..0000000 --- a/main_mobillama.py +++ /dev/null @@ -1,208 +0,0 @@ -from datetime import datetime -from pytz import timezone -import time -from functools import partial -import wandb -import os -import fire -import tqdm -import torch -torch.cuda.empty_cache() -from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy -import lightning as L -from lightning.fabric.strategies import FSDPStrategy -from transformers import AutoConfig, AutoTokenizer -from transformers import AutoTokenizer, AutoModelForCausalLM - -from model_utils.modeling_mobillama import LlamaForCausalLM, LlamaDecoderLayer, LlamaAttention - -from main_mobillama_utils import ( - load_jsonl_examples, - get_cosine_lr_decay_fn, - get_grad_norm, - save_checkpoint, - get_last_ckpt_idx) - - -TIMEZONE = timezone('EST') -DATE = str(datetime.now(tz=TIMEZONE)).split()[0] -MODEL_SIZE = '05b' -PROJECT_NAME = f'mobillama_{MODEL_SIZE}' -RUN_NAME = f'pretraining_mobillama-SFFN_{MODEL_SIZE}_{DATE}' -HF_MODEL_NAME_OR_PATH = f'mobillama' -WORKDIR = f'mobillama_{MODEL_SIZE}' - -LEARNING_RATE = 3e-5 -LR_SCHEDULE_TYPE = 'cosine' -END_LEARNING_RATE = 3e-6 -WARMUP_GRAD_STEPS = 2000 -GRAD_NORM_CLIP = 1. -WEIGHT_DECAY = 0.1 -BETA1 = 0.9 -BETA2 = 0.95 -ACCELERATOR = 'cuda' -PRECISION = 'bf16-mixed' -RANDOM_SEED = 11111 - -TRAIN_DATA_DIR = './Amber_Dataset/train' -TRAIN_EXAMPLES_PER_CHUNK = 1706976 -N_CHUNKS = 360 - - -def collate_fn(examples, device): - token_ids = torch.tensor( - [example['token_ids'] for example in examples], device=device) - return {'input_ids': token_ids[:, :-1], 'labels': token_ids[:, 1:]} - - -def train_chunk(fabric, - tokenizer, - model, - optimizer, - lr_schedule_fn, - examples, - per_device_batch_size, - accumulate_grad_batches, - chunk_idx, - run_wandb): - step = chunk_idx * (len(examples) // per_device_batch_size) - - example_batch_idxes = tqdm.trange( - 0, len(examples), per_device_batch_size, - desc=f'Training chunk {chunk_idx} (global_micro_batch_size=' - f'{per_device_batch_size * fabric.world_size}, ' - f'accumulate_grad_batches={accumulate_grad_batches})') - for i in example_batch_idxes: - t0 = time.time() - - lr = lr_schedule_fn(step) - step += 1 - for param_group in optimizer.param_groups: - param_group["lr"] = lr - is_accumulating = (step % accumulate_grad_batches != 0) - - batch = collate_fn( - examples=examples[i:i+per_device_batch_size], device=fabric.device) - input_ids, labels = batch['input_ids'], batch['labels'] - with fabric.no_backward_sync(model, enabled=is_accumulating): - logits = model(input_ids).logits - loss = torch.nn.functional.cross_entropy( - logits.reshape((-1, logits.size(-1))), labels.reshape(-1)) - - fabric.backward(loss / accumulate_grad_batches) - - if not is_accumulating: - grad_norm = get_grad_norm(model=model) - fabric.clip_gradients(model, optimizer, max_norm=GRAD_NORM_CLIP) - optimizer.step() - optimizer.zero_grad() - - log = { - 'loss': loss.item(), - 'learning_rate': lr, - 'step': step, - 'speed(#tok/s/gpu)': int(input_ids.numel() / (time.time() - t0)) - } - if not is_accumulating: - log['grad_norm'] = grad_norm - - example_batch_idxes.set_postfix(log) - if run_wandb and fabric.global_rank == 0: - wandb.log(log) - - save_checkpoint( - fabric=fabric, - tokenizer=tokenizer, - model=model, - optimizer=optimizer, - save_dir=f'{WORKDIR}/ckpt_{chunk_idx}') - - -def main(n_nodes=1, - n_devices_per_node=8, - per_device_batch_size=8, - accumulate_grad_batches=1, - run_wandb=False): - fabric = L.Fabric( - accelerator=ACCELERATOR, - num_nodes=n_nodes, - devices=n_devices_per_node, - precision=PRECISION, - strategy=FSDPStrategy( - auto_wrap_policy=partial( - transformer_auto_wrap_policy, - transformer_layer_cls={LlamaAttention}), - activation_checkpointing_policy={LlamaAttention}, - cpu_offload=True, - limit_all_gathers=True)) - fabric.launch() - - - if fabric.global_rank == 0: - os.makedirs(WORKDIR, exist_ok=True) - if run_wandb: - wandb.init(project=PROJECT_NAME, name=RUN_NAME) - - last_ckpt_idx = get_last_ckpt_idx(workdir=WORKDIR) - fabric.seed_everything(RANDOM_SEED + last_ckpt_idx + 1) - - tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_NAME_OR_PATH) - model = LlamaForCausalLM( - config=AutoConfig.from_pretrained(HF_MODEL_NAME_OR_PATH)) - - - # print(model) - print("="*50) - print(model) - print("="*50) - print("Model params : ", model.num_parameters()) - print("="*50) - - - optimizer = torch.optim.AdamW( - model.parameters(), - lr=LEARNING_RATE, - weight_decay=WEIGHT_DECAY, - betas=(BETA1, BETA2), - foreach=False) - - model, optimizer = fabric.setup(model, optimizer) - if last_ckpt_idx != -1: - fabric.load( - path=f'{WORKDIR}/ckpt_{last_ckpt_idx}/fabric_ckpt', - state={'model': model, 'optimizer': optimizer}) - - torch.cuda.empty_cache() - - global_micro_batch_size = per_device_batch_size * fabric.world_size - total_steps = TRAIN_EXAMPLES_PER_CHUNK // global_micro_batch_size * N_CHUNKS - lr_schedule_fn = get_cosine_lr_decay_fn( - total_steps=total_steps, - warmup_steps=WARMUP_GRAD_STEPS * accumulate_grad_batches, - learning_rate=LEARNING_RATE, - end_learning_rate=END_LEARNING_RATE) - - for chunk_idx in range(last_ckpt_idx + 1, N_CHUNKS): - examples = load_jsonl_examples( - filename=f'{TRAIN_DATA_DIR}/train_{chunk_idx:03d}.jsonl', - n_examples=TRAIN_EXAMPLES_PER_CHUNK, - shuffle=True, - global_micro_batch_size=global_micro_batch_size, - global_rank=fabric.global_rank, - world_size=fabric.world_size) - - train_chunk( - fabric=fabric, - tokenizer=tokenizer, - model=model, - optimizer=optimizer, - lr_schedule_fn=lr_schedule_fn, - examples=examples, - per_device_batch_size=per_device_batch_size, - accumulate_grad_batches=accumulate_grad_batches, - chunk_idx=chunk_idx, - run_wandb=run_wandb) - - -if __name__ == '__main__': - fire.Fire(main) \ No newline at end of file diff --git a/main_mobillama_utils.py b/main_mobillama_utils.py deleted file mode 100644 index 88092bd..0000000 --- a/main_mobillama_utils.py +++ /dev/null @@ -1,92 +0,0 @@ -import os -import glob -import json -import tqdm -import math -import numpy as np -from torch.distributed.fsdp import FullStateDictConfig -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp import StateDictType -from lightning.fabric.strategies import FSDPStrategy - -from model_utils.modeling_mobillama import LlamaForCausalLM - - -def load_jsonl_examples(filename, - n_examples, - shuffle, - global_micro_batch_size, - global_rank, - world_size): - example_idxes = np.random.permutation(n_examples) if shuffle \ - else np.arange(n_examples) - - n_examples = n_examples // global_micro_batch_size * global_micro_batch_size - example_idxes = example_idxes[global_rank:n_examples:world_size] - - examples = {idx: None for idx in example_idxes} - for example_idx, line in tqdm.tqdm( - enumerate(open(filename)), desc=f'loading {filename}'): - if example_idx in examples: - examples[example_idx] = json.loads(line) - - return [examples[idx] for idx in example_idxes] - - -def get_cosine_lr_decay_fn(total_steps, - warmup_steps, - learning_rate, - end_learning_rate): - def cosine_with_warmup_lr(step): - if step < warmup_steps: - return learning_rate * step / warmup_steps - elif step > total_steps: - return end_learning_rate - - decay_ratio = (step - warmup_steps) / (total_steps - warmup_steps) - assert 0 <= decay_ratio <= 1 - coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) - return end_learning_rate + coeff * (learning_rate - end_learning_rate) - - return cosine_with_warmup_lr - - -def get_grad_norm(model): - square_sum = 0. - for param in model.parameters(): - if param.grad is not None: - square_sum += param.grad.detach().data.norm(2).item() ** 2 - return square_sum ** 0.5 - - -def save_checkpoint(fabric, tokenizer, model, optimizer, save_dir): - assert isinstance(fabric.strategy, FSDPStrategy) - - save_policy = FullStateDictConfig( - offload_to_cpu=(fabric.world_size > 1), rank0_only=True) - with FSDP.state_dict_type( - model, - state_dict_type=StateDictType.FULL_STATE_DICT, - state_dict_config=save_policy): - state_dict = model._forward_module.state_dict() - - if fabric.global_rank == 0: - tokenizer.save_pretrained(save_dir) - assert isinstance(model.module, LlamaForCausalLM) - model.module.save_pretrained( - save_dir, state_dict=state_dict, safe_serialization=False) - - fabric.barrier() - fabric.save( - path=f'{save_dir}/fabric_ckpt', - state={'model': model, 'optimizer': optimizer}) - - -def get_last_ckpt_idx(workdir): - last_ckpt_idx = -1 - for ckpt_dir in glob.glob(f'{workdir}/ckpt_*'): - ckpt_idx = int(ckpt_dir.split('_')[-1]) - if ckpt_idx > last_ckpt_idx: - last_ckpt_idx = ckpt_idx - - return last_ckpt_idx \ No newline at end of file diff --git a/main_utils.py b/main_utils.py deleted file mode 100644 index 67734bf..0000000 --- a/main_utils.py +++ /dev/null @@ -1,92 +0,0 @@ -import os -import glob -import json -import tqdm -import math -import numpy as np -from torch.distributed.fsdp import FullStateDictConfig -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp import StateDictType -from lightning.fabric.strategies import FSDPStrategy - -from model_utils.modeling_llama import LlamaForCausalLM - - -def load_jsonl_examples(filename, - n_examples, - shuffle, - global_micro_batch_size, - global_rank, - world_size): - example_idxes = np.random.permutation(n_examples) if shuffle \ - else np.arange(n_examples) - - n_examples = n_examples // global_micro_batch_size * global_micro_batch_size - example_idxes = example_idxes[global_rank:n_examples:world_size] - - examples = {idx: None for idx in example_idxes} - for example_idx, line in tqdm.tqdm( - enumerate(open(filename)), desc=f'loading {filename}'): - if example_idx in examples: - examples[example_idx] = json.loads(line) - - return [examples[idx] for idx in example_idxes] - - -def get_cosine_lr_decay_fn(total_steps, - warmup_steps, - learning_rate, - end_learning_rate): - def cosine_with_warmup_lr(step): - if step < warmup_steps: - return learning_rate * step / warmup_steps - elif step > total_steps: - return end_learning_rate - - decay_ratio = (step - warmup_steps) / (total_steps - warmup_steps) - assert 0 <= decay_ratio <= 1 - coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) - return end_learning_rate + coeff * (learning_rate - end_learning_rate) - - return cosine_with_warmup_lr - - -def get_grad_norm(model): - square_sum = 0. - for param in model.parameters(): - if param.grad is not None: - square_sum += param.grad.detach().data.norm(2).item() ** 2 - return square_sum ** 0.5 - - -def save_checkpoint(fabric, tokenizer, model, optimizer, save_dir): - assert isinstance(fabric.strategy, FSDPStrategy) - - save_policy = FullStateDictConfig( - offload_to_cpu=(fabric.world_size > 1), rank0_only=True) - with FSDP.state_dict_type( - model, - state_dict_type=StateDictType.FULL_STATE_DICT, - state_dict_config=save_policy): - state_dict = model._forward_module.state_dict() - - if fabric.global_rank == 0: - tokenizer.save_pretrained(save_dir) - assert isinstance(model.module, LlamaForCausalLM) - model.module.save_pretrained( - save_dir, state_dict=state_dict, safe_serialization=False) - - fabric.barrier() - fabric.save( - path=f'{save_dir}/fabric_ckpt', - state={'model': model, 'optimizer': optimizer}) - - -def get_last_ckpt_idx(workdir): - last_ckpt_idx = -1 - for ckpt_dir in glob.glob(f'{workdir}/ckpt_*'): - ckpt_idx = int(ckpt_dir.split('_')[-1]) - if ckpt_idx > last_ckpt_idx: - last_ckpt_idx = ckpt_idx - - return last_ckpt_idx \ No newline at end of file diff --git a/mobillama/config.json b/mobillama/config.json deleted file mode 100644 index 04aef88..0000000 --- a/mobillama/config.json +++ /dev/null @@ -1,26 +0,0 @@ -{ - "architectures": [ - "LlamaForCausalLM" - ], - "bos_token_id": 1, - "eos_token_id": 2, - "hidden_act": "silu", - "hidden_size": 2048, - "initializer_range": 0.02, - "intermediate_size": 5632, - "max_position_embeddings": 2048, - "model_type": "llama", - "num_attention_heads": 32, - "num_hidden_layers": 8, - "num_key_value_heads": 4, - "pretraining_tp": 1, - "rms_norm_eps": 1e-05, - "rope_scaling": null, - "tie_word_embeddings": false, - "torch_dtype": "float32", - "transformers_version": "4.31.0.dev0", - "use_cache": true, - "vocab_size": 32000 - } - - \ No newline at end of file diff --git a/mobillama/config_08b.json b/mobillama/config_08b.json deleted file mode 100644 index 7001183..0000000 --- a/mobillama/config_08b.json +++ /dev/null @@ -1,27 +0,0 @@ -{ - "architectures": [ - "LlamaForCausalLM" - ], - "attention_bias": false, - "attention_dropout": 0.0, - "bos_token_id": 1, - "eos_token_id": 2, - "hidden_act": "silu", - "hidden_size": 2560, - "initializer_range": 0.02, - "intermediate_size": 10240, - "max_position_embeddings": 2560, - "model_type": "llama", - "num_attention_heads": 32, - "num_hidden_layers": 22, - "num_key_value_heads": 4, - "pretraining_tp": 1, - "rms_norm_eps": 1e-05, - "rope_scaling": null, - "rope_theta": 10000.0, - "tie_word_embeddings": false, - "torch_dtype": "float32", - "transformers_version": "4.36.1", - "use_cache": true, - "vocab_size": 32000 -} diff --git a/mobillama/config_largebase.json b/mobillama/config_largebase.json deleted file mode 100644 index 04aef88..0000000 --- a/mobillama/config_largebase.json +++ /dev/null @@ -1,26 +0,0 @@ -{ - "architectures": [ - "LlamaForCausalLM" - ], - "bos_token_id": 1, - "eos_token_id": 2, - "hidden_act": "silu", - "hidden_size": 2048, - "initializer_range": 0.02, - "intermediate_size": 5632, - "max_position_embeddings": 2048, - "model_type": "llama", - "num_attention_heads": 32, - "num_hidden_layers": 8, - "num_key_value_heads": 4, - "pretraining_tp": 1, - "rms_norm_eps": 1e-05, - "rope_scaling": null, - "tie_word_embeddings": false, - "torch_dtype": "float32", - "transformers_version": "4.31.0.dev0", - "use_cache": true, - "vocab_size": 32000 - } - - \ No newline at end of file diff --git a/model_utils/modeling_llama.py b/model_utils/modeling_llama.py deleted file mode 100644 index 88bf8dd..0000000 --- a/model_utils/modeling_llama.py +++ /dev/null @@ -1,896 +0,0 @@ -# coding=utf-8 -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" PyTorch LLaMA model.""" -import math -from typing import List, Optional, Tuple, Union - -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss - -from transformers.activations import ACT2FN -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings -from transformers.models.llama.configuration_llama import LlamaConfig - -from flash_attn import flash_attn_func - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "LlamaConfig" - - -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - -class LlamaRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - - return (self.weight * hidden_states).to(input_dtype) - - -class LlamaRotaryEmbedding(torch.nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) - self.register_buffer("inv_freq", inv_freq) - - # Build here to make `torch.jit.trace` work. - self.max_seq_len_cached = max_position_embeddings - t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. - if seq_len > self.max_seq_len_cached: - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1).to(x.device) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) - return ( - self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - ) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. - cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class LlamaMLP(nn.Module): - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - ): - super().__init__() - self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) - self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) - self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) - self.act_fn = ACT2FN[hidden_act] - - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -class LlamaAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: LlamaConfig): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.max_position_embeddings = config.max_position_embeddings - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - # [bsz, nh, t, hd] - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - # attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - # - # if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - # raise ValueError( - # f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - # f" {attn_weights.size()}" - # ) - # - # if attention_mask is not None: - # if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - # raise ValueError( - # f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - # ) - # attn_weights = attn_weights + attention_mask - # attn_weights = torch.max( - # attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device) - # ) - # - # # upcast attention to fp32 - # attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - # attn_output = torch.matmul(attn_weights, value_states) - # - # if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - # raise ValueError( - # f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - # f" {attn_output.size()}" - # ) - # - # attn_output = attn_output.transpose(1, 2) - - attn_output = flash_attn_func( - q=query_states.transpose(1, 2).to(torch.bfloat16), - k=key_states.transpose(1, 2).to(torch.bfloat16), - v=value_states.transpose(1, 2).to(torch.bfloat16), - causal=True) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = attn_output.to(query_states.dtype) - - attn_output = self.o_proj(attn_output) - - # if not output_attentions: - # attn_weights = None - assert not output_attentions - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class LlamaDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig): - super().__init__() - self.hidden_size = config.hidden_size - self.self_attn = LlamaAttention(config=config) - self.mlp = LlamaMLP( - hidden_size=self.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - ) - self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -LLAMA_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`LlamaConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -@add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, -) -class LlamaPreTrainedModel(PreTrainedModel): - config_class = LlamaConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["LlamaDecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, LlamaModel): - module.gradient_checkpointing = value - - -LLAMA_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, -) -class LlamaModel(LlamaPreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] - - Args: - config: LlamaConfig - """ - - def __init__(self, config: LlamaConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) - self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device - ) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) - - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class LlamaForCausalLM(LlamaPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.model = LlamaModel(config) - - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM - - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you consciours? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs - ): - if past_key_values: - input_ids = input_ids[:, -1:] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) - return reordered_past - - -@add_start_docstrings( - """ - The LLaMa Model transformer with a sequence classification head on top (linear layer). - - [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). - """, - LLAMA_START_DOCSTRING, -) -class LlamaForSequenceClassification(LlamaPreTrainedModel): - _keys_to_ignore_on_load_missing = [r"lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = LlamaModel(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) - else: - sequence_lengths = -1 - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) diff --git a/model_utils/modeling_mobillama.py b/model_utils/modeling_mobillama.py deleted file mode 100644 index 45e9447..0000000 --- a/model_utils/modeling_mobillama.py +++ /dev/null @@ -1,897 +0,0 @@ -# coding=utf-8 -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" PyTorch LLaMA model.""" -import math -from typing import List, Optional, Tuple, Union - -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss - -from transformers.activations import ACT2FN -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings -from transformers.models.llama.configuration_llama import LlamaConfig - -from flash_attn import flash_attn_func - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "LlamaConfig" - - -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - -class LlamaRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - - return (self.weight * hidden_states).to(input_dtype) - - -class LlamaRotaryEmbedding(torch.nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) - self.register_buffer("inv_freq", inv_freq) - - # Build here to make `torch.jit.trace` work. - self.max_seq_len_cached = max_position_embeddings - t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. - if seq_len > self.max_seq_len_cached: - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1).to(x.device) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) - return ( - self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - ) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. - cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class LlamaMLP(nn.Module): - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - ): - super().__init__() - self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) - self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) - self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) - self.act_fn = ACT2FN[hidden_act] - - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -class LlamaAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: LlamaConfig): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.max_position_embeddings = config.max_position_embeddings - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - # [bsz, nh, t, hd] - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - # attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - # - # if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - # raise ValueError( - # f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - # f" {attn_weights.size()}" - # ) - # - # if attention_mask is not None: - # if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - # raise ValueError( - # f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - # ) - # attn_weights = attn_weights + attention_mask - # attn_weights = torch.max( - # attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device) - # ) - # - # # upcast attention to fp32 - # attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - # attn_output = torch.matmul(attn_weights, value_states) - # - # if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - # raise ValueError( - # f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - # f" {attn_output.size()}" - # ) - # - # attn_output = attn_output.transpose(1, 2) - - attn_output = flash_attn_func( - q=query_states.transpose(1, 2).to(torch.bfloat16), - k=key_states.transpose(1, 2).to(torch.bfloat16), - v=value_states.transpose(1, 2).to(torch.bfloat16), - causal=True) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = attn_output.to(query_states.dtype) - - attn_output = self.o_proj(attn_output) - - # if not output_attentions: - # attn_weights = None - assert not output_attentions - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class LlamaDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig, mlp): - super().__init__() - self.hidden_size = config.hidden_size - self.self_attn = LlamaAttention(config=config) - self.mlp = mlp #LlamaMLP(hidden_size=self.hidden_size,intermediate_size=config.intermediate_size,hidden_act=config.hidden_act,) - self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -LLAMA_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`LlamaConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -@add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, -) -class LlamaPreTrainedModel(PreTrainedModel): - config_class = LlamaConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["LlamaDecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, LlamaModel): - module.gradient_checkpointing = value - - -LLAMA_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, -) -class LlamaModel(LlamaPreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] - - Args: - config: LlamaConfig - """ - - def __init__(self, config: LlamaConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - mlp = LlamaMLP( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - ) - self.layers = nn.ModuleList([LlamaDecoderLayer(config, mlp) for _ in range(config.num_hidden_layers)]) - self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device - ) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) - - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class LlamaForCausalLM(LlamaPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.model = LlamaModel(config) - - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM - - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you consciours? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs - ): - if past_key_values: - input_ids = input_ids[:, -1:] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) - return reordered_past - - -@add_start_docstrings( - """ - The LLaMa Model transformer with a sequence classification head on top (linear layer). - - [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). - """, - LLAMA_START_DOCSTRING, -) -class LlamaForSequenceClassification(LlamaPreTrainedModel): - _keys_to_ignore_on_load_missing = [r"lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = LlamaModel(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) - else: - sequence_lengths = -1 - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) diff --git a/pretrain.sh b/pretrain.sh deleted file mode 100644 index 67b4a6d..0000000 --- a/pretrain.sh +++ /dev/null @@ -1,11 +0,0 @@ -#!/bin/sh -#SBATCH --job-name=mobillama -#SBATCH --account -#SBATCH --partition= -#SBATCH --nodes=20 -#SBATCH --ntasks-per-node=8 -#SBATCH --cpus-per-task=14 -#SBATCH --gres=gpu:8 -#SBATCH -t 3-00:00:00 - -srun python main_mobillama.py --n_nodes 20 --run_wandb \ No newline at end of file diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 452e1a4..0000000 --- a/requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -torch>=2.1.1 -lightning>=2.1.2 -flash_attn>=2.3.3 -transformers>=4.36.2 -fire -wandb \ No newline at end of file