First commit
3
.gitignore
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
*.bk
|
||||
__pycache__
|
||||
.DS_Store
|
201
LICENSE
Normal file
@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2024 OpenBMB
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
371
README.md
Normal file
@ -0,0 +1,371 @@
|
||||
<div align="center">
|
||||
|
||||
<!-- <!-- <h1 style="color: #33A6B8; font-family: Helvetica"> OmniLMM </h1> -->
|
||||
|
||||
<img src="./assets/title-2.png" width="200em" ></img>
|
||||
|
||||
**Large multi-modal models for strong performance and efficient deployment**
|
||||
<p align="center">
|
||||
OmniLMM-3B <a href="https://huggingface.co/openbmb/MiniCPM-V/">🤗</a> <a href="http://120.92.209.146:80/">🤖</a> |
|
||||
OmniLMM-12B <a href="https://huggingface.co/openbmb/OmniLMM-12B/">🤗</a> <a href="http://120.92.209.146:8081">🤖</a>
|
||||
</p>
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
**OmniLMM** is a family of open-source large multimodal models (LMMs) adept at vision & language modeling. The model processes images and text inputs and delivers high-quality text outputs. We release two featured versions of OmniLMM that are targeted at **strong performance and efficient deployment**:
|
||||
|
||||
- **OmniLMM-12B**: Leading performance among comparable-sized models on multiple benchmarks.
|
||||
|
||||
- **OmniLMM-3B**: Frontier end device multi-modal conversation with promising performance.
|
||||
|
||||
|
||||
[中文文档](./README_zh.md)
|
||||
|
||||
## Contents
|
||||
- [OmniLMM-12B](#omnilmm-12b)
|
||||
- [OmniLMM-3B](#omnilmm-3b)
|
||||
- [Demo](#demo)
|
||||
- [Install](#install)
|
||||
- [Inference](#inference)
|
||||
- [Model Zoo](#model-zoo)
|
||||
|
||||
## OmniLMM-12B
|
||||
**OmniLMM-12B** is the most capable version. The model is built based on EVA02-5B and Zephyr-7B-β, connected with a perceiver resampler layer, and trained on multimodal data in a curriculum fashion. The model has three notable features:
|
||||
|
||||
- 🔥 **Strong Performance.**
|
||||
|
||||
OmniLMM-12B achieves **leading performance** among models with comparable sizes, surpassing established LMMs on multiple benchmarks (including MME, MMBench, SEED-Bench, etc). The model also **supports OCR capability** and endows **rich multimodal world knowledge**.
|
||||
|
||||
- 🏆 **Trustworthy Behavior.**
|
||||
|
||||
LMMs are known for suffering from hallucination, often generating text that is not factually grounded in images (e.g., faithfully describing non-existing objects in images). OmniLMM-12B is **the first state-of-the-art open-source LMM aligned via multimodal RLHF for trustworthy behavior** (using our recent [RLHF-V](https://rlhf-v.github.io/) technique). It **ranks #1** among open-source models on [MMHal-Bench](https://huggingface.co/datasets/Shengcao1006/MMHal-Bench), and **outperforms GPT-4V** on [Object HalBench](https://arxiv.org/abs/2312.00849).
|
||||
|
||||
- 🕹 **Real-time Multimodal Interaction.**
|
||||
|
||||
We combine the OmniLMM-12B and GPT-3.5 into a **real-time multimodal interactive assistant**. The assistant accepts video streams from the camera and speech streams from the microphone and emits speech output. While still primary, we find the model can **replicate some of the fun cases shown in the Gemini Demo video, without any video edition**.
|
||||
|
||||
### Evaluation
|
||||
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
<th align="left">Model</th>
|
||||
<th>Size</th>
|
||||
<th>MME</th>
|
||||
<th nowrap="nowrap">MMB dev (en)</th>
|
||||
<th nowrap="nowrap" >MMMU val</th>
|
||||
<th nowrap="nowrap" >MMHal-Bench</th>
|
||||
<th nowrap="nowrap" >Object HalBench</th>
|
||||
<th nowrap="nowrap" >SeedBench-I</th>
|
||||
<th>MathVista</th>
|
||||
<th nowrap="nowrap" >LLaVA Bench W</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody align="center">
|
||||
<tr>
|
||||
<td align="left">GPT-4V†</td>
|
||||
<td>-</td>
|
||||
<td>1409</td>
|
||||
<td>75.1 </td>
|
||||
<td>56.8</td>
|
||||
<td>3.53 / 70.8</td>
|
||||
<td>86.4 / 92.7</td>
|
||||
<td>71.6 </td>
|
||||
<td>47.8 </td>
|
||||
<td>93.1 </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap" align="left">Qwen-VL-Plus†</td>
|
||||
<td>-</td>
|
||||
<td>1681</td>
|
||||
<td>66.2 </td>
|
||||
<td>45.2</td>
|
||||
<td>- </td>
|
||||
<td>- </td>
|
||||
<td>65.7 </td>
|
||||
<td>36.0 </td>
|
||||
<td>73.7 </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="left">Yi-VL 6B</td>
|
||||
<td align="right">6.7B </td>
|
||||
<td>- </td>
|
||||
<td>68.2 </td>
|
||||
<td>39.1 </td>
|
||||
<td>- </td>
|
||||
<td>- </td>
|
||||
<td>66.1 </td>
|
||||
<td>28.0 </td>
|
||||
<td>39.9 </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap" align="left" >Qwen-VL-Chat</td>
|
||||
<td align="right">9.6B</td>
|
||||
<td>1488</td>
|
||||
<td>60.6 </td>
|
||||
<td>35.9</td>
|
||||
<td>2.93 / 59.4</td>
|
||||
<td>56.2 / 80.0</td>
|
||||
<td>64.8 </td>
|
||||
<td>33.8 </td>
|
||||
<td>67.7 </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="left" >CogVLM</td>
|
||||
<td align="right">17.4B</td>
|
||||
<td>1438</td>
|
||||
<td>63.7 </td>
|
||||
<td>32.1 </td>
|
||||
<td>2.68 / 52.1 </td>
|
||||
<td>73.6 / 87.4 </td>
|
||||
<td>68.8 </td>
|
||||
<td>34.7 </td>
|
||||
<td>73.9 </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="left" >LLaVA 1.5</td>
|
||||
<td align="right">13.6B </td>
|
||||
<td>1531 </td>
|
||||
<td>68.2 </td>
|
||||
<td>36.4 </td>
|
||||
<td>2.71 / 51.0 </td>
|
||||
<td>53.7 / 77.4 </td>
|
||||
<td>68.1 </td>
|
||||
<td>26.4 </td>
|
||||
<td>64.6 </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap" align="left" ><b>OmniLMM-12B</b></td>
|
||||
<td align="right">11.6B </td>
|
||||
<td>1637 </td>
|
||||
<td>71.6 </td>
|
||||
<td>40.7 </td>
|
||||
<td>3.45 / 68.8 </td>
|
||||
<td>90.3 / 95.5 </td>
|
||||
<td>71.1 </td>
|
||||
<td>34.9 </td>
|
||||
<td>72.0 </td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
<small>†: Proprietary models</small>
|
||||
|
||||
### Examples
|
||||
|
||||
<table align="center" >
|
||||
<p align="center" >
|
||||
<img src="assets/omnilmm-12b-examples.png" />
|
||||
</p>
|
||||
</table>
|
||||
|
||||
<div align="center" >
|
||||
<video controls src="https://github.com/OpenBMB/OmniLMM/assets/157115220/c1fd3562-1ab1-4534-8139-79e9137b5398" type="video/mp4" />
|
||||
</div>
|
||||
|
||||
## OmniLMM-3B
|
||||
**OmniLMM-3B** (i.e., MiniCPM-V) is an efficient version with promising performance for deployment. The model is built based on SigLip-400M and [MiniCPM-2.4B](https://github.com/OpenBMB/MiniCPM/), connected by a perceiver resampler. Notable features of OmniLMM-3B include:
|
||||
|
||||
- ⚡️ **High Efficiency.**
|
||||
|
||||
OmniLMM-3B can be **efficiently deployed on most GPU cards and personal computers**, and **even on end devices such as mobile phones**. In terms of visual encoding, we compress the image representations into 64 tokens via a perceiver resampler, which is significantly fewer than other LMMs based on MLP architecture (typically > 512 tokens). This allows OmniLMM-3B to operate with **much less memory cost and higher speed during inference**.
|
||||
|
||||
- 🔥 **Promising Performance.**
|
||||
|
||||
OmniLMM-3B achieves **state-of-the-art performance** on multiple benchmarks (including MMMU, MME, and MMbech, etc) among models with comparable sizes, surpassing existing LMMs built on Phi-2. It even **achieves comparable or better performance than the 9.6B Qwen-VL-Chat**.
|
||||
|
||||
- 🙌 **Bilingual Support.**
|
||||
|
||||
OmniLMM-3B is **the first edge-deployable LMM supporting bilingual multimodal interaction in English and Chinese**. This is achieved by generalizing multimodal capabilities across languages, a technique from our ICLR 2024 spotlight [paper](https://arxiv.org/abs/2308.12038).
|
||||
|
||||
### Evaluation
|
||||
|
||||
<div align="center">
|
||||
|
||||
<table style="margin: 0px auto;">
|
||||
<thead>
|
||||
<tr>
|
||||
<th align="left">Model</th>
|
||||
<th>Size</th>
|
||||
<th>MME</th>
|
||||
<th nowrap="nowrap" >MMB dev (en)</th>
|
||||
<th nowrap="nowrap" >MMB dev (zh)</th>
|
||||
<th nowrap="nowrap" >MMMU val</th>
|
||||
<th nowrap="nowrap" >CMMMU val</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody align="center">
|
||||
<tr>
|
||||
<td align="left">LLaVA-Phi</td>
|
||||
<td align="right">3B</td>
|
||||
<td>1335</td>
|
||||
<td>59.8</td>
|
||||
<td>- </td>
|
||||
<td>- </td>
|
||||
<td>- </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap" align="left">MobileVLM</td>
|
||||
<td align="right">3B</td>
|
||||
<td>1289</td>
|
||||
<td>59.6</td>
|
||||
<td>- </td>
|
||||
<td>- </td>
|
||||
<td>- </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap" align="left" >Imp-v1</td>
|
||||
<td align="right">3B</td>
|
||||
<td>1434</td>
|
||||
<td>66.5</td>
|
||||
<td>- </td>
|
||||
<td>- </td>
|
||||
<td>- </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="left" >Qwen-VL-Chat</td>
|
||||
<td align="right" >9.6B</td>
|
||||
<td>1487</td>
|
||||
<td>60.6 </td>
|
||||
<td>56.7 </td>
|
||||
<td>35.9 </td>
|
||||
<td>30.7 </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap" align="left" >CogVLM</td>
|
||||
<td align="right">17.4B </td>
|
||||
<td>1438 </td>
|
||||
<td>63.7 </td>
|
||||
<td>53.8 </td>
|
||||
<td>32.1 </td>
|
||||
<td>- </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap" align="left" ><b>OmniLMM-3B</b></td>
|
||||
<td align="right">3B </td>
|
||||
<td>1452 </td>
|
||||
<td>67.3 </td>
|
||||
<td>61.9 </td>
|
||||
<td>34.7 </td>
|
||||
<td>32.1 </td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
</div>
|
||||
|
||||
### Examples
|
||||
|
||||
OmniLLM-3B is the first LMM deloyed on end devices. The demo video is the raw screen recording without edition.
|
||||
|
||||
<table align="center" >
|
||||
<p align="center" >
|
||||
<img src="assets/Snake_cn_Mushroom_en.gif" width=36%/>
|
||||
</p>
|
||||
</table>
|
||||
|
||||
## Demo
|
||||
Click here to try out the Demo of [OmniLMM-12B](http://120.92.209.146:8081) and [OmniLMM-3B](http://120.92.209.146:80).
|
||||
|
||||
## Install
|
||||
|
||||
1. Clone this repository and navigate to the source folder
|
||||
|
||||
```bash
|
||||
git clone https://github.com/OpenBMB/OmniLMM.git
|
||||
cd OmniLMM
|
||||
```
|
||||
|
||||
2. Create conda environment
|
||||
|
||||
```Shell
|
||||
conda create -n OmniLMM python=3.10 -y
|
||||
conda activate OmniLMM
|
||||
```
|
||||
|
||||
3. Install dependencies
|
||||
|
||||
```shell
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Inference
|
||||
|
||||
### Model Zoo
|
||||
| Model | Description | Download Link |
|
||||
|:----------------------|:-------------------|:---------------:|
|
||||
| OmniLMM-12B | The most capable version with strong performance. | [🤗](https://huggingface.co/openbmb/OmniLMM-12B) <a url="https://modelscope.cn/models/OpenBMB/OmniLMM-12B/files"> <img src="./assets/modelscope_logo.png" width="20px"></img></a> |
|
||||
| OmniLMM-3B | The efficient version for end device deployment. | [🤗](https://huggingface.co/openbmb/MiniCPM-V) <a url="https://modelscope.cn/models/OpenBMB/MiniCPM-V/files"> <img src="./assets/modelscope_logo.png" width="20px"></img></a> |
|
||||
|
||||
|
||||
### Multi-turn Conversation
|
||||
Please refer to the following codes to run `OmniLMM`.
|
||||
|
||||
<div align="center">
|
||||
<img src="assets/COCO_test2015_000000262144.jpg" width="660px">
|
||||
</div>
|
||||
|
||||
|
||||
```python
|
||||
from chat import OmniLMMChat, img2base64
|
||||
|
||||
chat_model = OmniLMMChat('openbmb/OmniLMM-12B') # or 'openbmb/MiniCPM-V'
|
||||
|
||||
im_64 = img2base64('./assets/COCO_test2015_000000262144.jpg')
|
||||
|
||||
# First round chat
|
||||
msgs = [{"role": "user", "content": "What are the people doing?"}]
|
||||
|
||||
inputs = {"image": im_64, "question": json.dumps(msgs)}
|
||||
answer = chat_model.process(inputs)
|
||||
print(answer)
|
||||
|
||||
# Second round chat
|
||||
# pass history context of multi-turn conversation
|
||||
msgs.append({"role": "assistant", "content": answer})
|
||||
msgs.append({"role": "user", "content": "Describe the image"})
|
||||
|
||||
inputs = {"image": im_64, "question": json.dumps(msgs)}
|
||||
answer = chat_model.process(inputs)
|
||||
print(answer)
|
||||
```
|
||||
|
||||
We can obtain the following results:
|
||||
```
|
||||
"The people in the image are playing baseball. One person is pitching a ball, another one is swinging a bat to hit it, and there's also an umpire present who appears to be watching the game closely."
|
||||
|
||||
"The image depicts a baseball game in progress. A pitcher is throwing the ball, while another player is swinging his bat to hit it. An umpire can be seen observing the play closely."
|
||||
```
|
||||
|
||||
|
||||
## ✅ TODO
|
||||
|
||||
- [ ] Fine-tuning support
|
||||
- [ ] Local Web-UI deployment
|
||||
- [ ] Code release for real-time interactive assistant
|
||||
|
||||
## Model License
|
||||
|
||||
The code in this repo is released according to [Apache-2.0](https://github.com/OpenBMB/MiniCPM/blob/main/LICENSE)
|
||||
|
||||
The usage of OmniLMMs' parameters is subject to "[General Model License Agreement - Source Notes - Publicity Restrictions - Commercial License](https://github.com/OpenBMB/General-Model-License/blob/main/通用模型许可协议-来源说明-宣传限制-商业授权.md)"
|
||||
|
||||
The parameters are fully open to acedemic research
|
||||
|
||||
Please contact cpm@modelbest.cn to obtain a written authorization for commercial uses. Free commercial use is also allowed after registration.
|
||||
|
||||
## Statement
|
||||
|
||||
As LMMs, OmniLMMs generate contents by learning a large mount of multimodal corpora, but it cannot comprehend, express personal opinions or make value judgement. Anything generated by OmniLMMs does not represent the views and positions of the model developers
|
||||
|
||||
We will not be liable for any problems arising from the use of OmniLMM open source models, including but not limited to data security issues, risk of public opinion, or any risks and problems arising from the misdirection, misuse, dissemination or misuse of the model.
|
||||
|
||||
|
||||
## 🏫 Institutions
|
||||
|
||||
This project is developed by the following institutions:
|
||||
|
||||
- <img src="assets/thunlp.png" width="28px"> [THUNLP](https://nlp.csai.tsinghua.edu.cn/)
|
||||
- <img src="assets/modelbest.png" width="28px"> [ModelBest](https://modelbest.cn/)
|
||||
- <img src="assets/zhihu.webp" width="28px"> [Zhihu](https://www.zhihu.com/ )
|
||||
|
390
README_zh.md
Normal file
@ -0,0 +1,390 @@
|
||||
<div align="center">
|
||||
|
||||
<!-- <!-- <h1 style="color: #33A6B8; font-family: Helvetica"> OmniLMM </h1> -->
|
||||
|
||||
<img src="./assets/title-2.png" width="200em" ></img>
|
||||
|
||||
**性能强大且部署高效的多模态大模型**
|
||||
<p align="center">
|
||||
OmniLMM-3B <a href="https://huggingface.co/openbmb/MiniCPM-V/">🤗</a> <a href="http://120.92.209.146:80/">🤖</a> |
|
||||
OmniLMM-12B <a href="https://huggingface.co/openbmb/OmniLMM-12B/">🤗</a> <a href="http://120.92.209.146:8081">🤖</a>
|
||||
</p>
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
**OmniLMM** 是一系列善于处理图文输入的开源多模态大模型(LMMs)。该系列模型接受图像和文本输入,并提供高质量的文本输出。我们发布了两个版本的 OmniLMM,旨在实现**强大的性能和高效的部署**:
|
||||
|
||||
- **OmniLMM-12B**:相比同规模其他模型在多个基准测试中具有领先性能。
|
||||
|
||||
- **OmniLMM-3B**:可在终端设备上部署并具备先进的多模态对话能力。
|
||||
|
||||
[English Document](./README.md)
|
||||
|
||||
## 目录
|
||||
- [OmniLMM-12B](#omnilmm-12b)
|
||||
- [OmniLMM-3B](#omnilmm-3b)
|
||||
- [体验](#demo)
|
||||
- [安装](#install)
|
||||
- [推理](#inference)
|
||||
- [模型库](#model-zoo)
|
||||
|
||||
## OmniLMM-12B
|
||||
**OmniLMM-12B** 是当前系列中性能最强大的版本。该模型使用一个感知重采样层连接 EVA02-5B 和 Zephyr-7B-β 来构建,采用了课程学习的方法在多模态数据上进行训练。该模型具有三个显著特征:
|
||||
|
||||
- 🔥 **卓越性能。**
|
||||
|
||||
OmniLMM-12B 相比其他同规模模型在多个基准测试中取得**领先的性能**(包括 MME、MMBench、SEED-Bench 等)。该模型还**支持OCR功能**,并掌握了**丰富的多模态世界知识**。
|
||||
|
||||
- 🏆 **可信行为。**
|
||||
|
||||
LMMs 的幻觉问题备受关注,模型经常生成和图像中的事实不符的文本(例如,信誓旦旦地描述图片中并不存在的物体)。OmniLMM-12B是 **第一个通过多模态 RLHF 对齐的最新开源 LMM 来实现可信行为**(通过我们最近提出的 [RLHF-V](https://rlhf-v.github.io/) 技术)。该模型在 [MMHal-Bench](https://huggingface.co/datasets/Shengcao1006/MMHal-Bench) 幻觉评测基准上位列开源模型中**第一**,并在 [Object HalBench](https://arxiv.org/abs/2312.00849) 中**超过了 GPT-4V**。
|
||||
|
||||
- 🕹 **实时多模态交互。**
|
||||
|
||||
我们将 OmniLMM-12B 和 GPT-3.5 结合成一个**实时多模态交互助手**。该助手接受来自相机的视频流和来自麦克风的语音流,并发出语音输出。虽然还处于初级阶段,但我们也发现该模型**无需视频编辑**就可以**复现出现在 Gemini 演示视频中的一些有趣例子**。
|
||||
|
||||
### 性能评估
|
||||
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
<th align="left">Model</th>
|
||||
<th>Size</th>
|
||||
<th>MME</th>
|
||||
<th nowrap="nowrap">MMB dev (en)</th>
|
||||
<th nowrap="nowrap" >MMMU val</th>
|
||||
<th nowrap="nowrap" >MMHal-Bench</th>
|
||||
<th nowrap="nowrap" >Object HalBench</th>
|
||||
<th nowrap="nowrap" >SeedBench-I</th>
|
||||
<th>MathVista</th>
|
||||
<th nowrap="nowrap" >LLaVA Bench W</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody align="center">
|
||||
<tr>
|
||||
<td align="left">GPT-4V†</td>
|
||||
<td>-</td>
|
||||
<td>1409</td>
|
||||
<td>75.1 </td>
|
||||
<td>56.8</td>
|
||||
<td>3.53 / 70.8</td>
|
||||
<td>86.4 / 92.7</td>
|
||||
<td>71.6 </td>
|
||||
<td>47.8 </td>
|
||||
<td>93.1 </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap" align="left">Qwen-VL-Plus†</td>
|
||||
<td>-</td>
|
||||
<td>1681</td>
|
||||
<td>66.2 </td>
|
||||
<td>45.2</td>
|
||||
<td>- </td>
|
||||
<td>- </td>
|
||||
<td>65.7 </td>
|
||||
<td>36.0 </td>
|
||||
<td>73.7 </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="left">Yi-VL 6B</td>
|
||||
<td align="right">6.7B </td>
|
||||
<td>- </td>
|
||||
<td>68.2 </td>
|
||||
<td>39.1 </td>
|
||||
<td>- </td>
|
||||
<td>- </td>
|
||||
<td>66.1 </td>
|
||||
<td>28.0 </td>
|
||||
<td>39.9 </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap" align="left" >Qwen-VL-Chat</td>
|
||||
<td align="right">9.6B</td>
|
||||
<td>1488</td>
|
||||
<td>60.6 </td>
|
||||
<td>35.9</td>
|
||||
<td>2.93 / 59.4</td>
|
||||
<td>56.2 / 80.0</td>
|
||||
<td>64.8 </td>
|
||||
<td>33.8 </td>
|
||||
<td>67.7 </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="left" >CogVLM</td>
|
||||
<td align="right">17.4B</td>
|
||||
<td>1438</td>
|
||||
<td>63.7 </td>
|
||||
<td>32.1 </td>
|
||||
<td>2.68 / 52.1 </td>
|
||||
<td>73.6 / 87.4 </td>
|
||||
<td>68.8 </td>
|
||||
<td>34.7 </td>
|
||||
<td>73.9 </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="left" >LLaVA 1.5</td>
|
||||
<td align="right">13.6B </td>
|
||||
<td>1531 </td>
|
||||
<td>68.2 </td>
|
||||
<td>36.4 </td>
|
||||
<td>2.71 / 51.0 </td>
|
||||
<td>53.7 / 77.4 </td>
|
||||
<td>68.1 </td>
|
||||
<td>26.4 </td>
|
||||
<td>64.6 </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap" align="left" ><b>OmniLMM-12B</b></td>
|
||||
<td align="right">11.6B </td>
|
||||
<td>1637 </td>
|
||||
<td>71.6 </td>
|
||||
<td>40.7 </td>
|
||||
<td>3.45 / 68.8 </td>
|
||||
<td>90.3 / 95.5 </td>
|
||||
<td>71.1 </td>
|
||||
<td>34.9 </td>
|
||||
<td>72.0 </td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
<small>†: 闭源模型</small>
|
||||
|
||||
|
||||
## OmniLMM-3B
|
||||
|
||||
**OmniLMM-3B**(即 MiniCPM-V)是一种我们的高效率版本模型,可用于终端机器上的部署。该模型基于 SigLip-400M 和 MiniCPM-2.4B 构建,通过感知器重采样器连接。OmniLMM-3B的显著特点包括:
|
||||
|
||||
- ⚡️ **高效率。**
|
||||
|
||||
OmniLMM-3B 可以**高效地部署在大多数GPU卡和个人电脑上**,甚至**在移动手机等终端设备上**。在视觉编码方面,我们通过感知器重采样器将图像表示压缩为 64 个 token,远远少于基于MLP架构的其他LMMs(通常大于 512 个 token)。这使得 OmniLMM-3B 在推理期间**内存成本更低且速度更快**。
|
||||
|
||||
- 🔥 **优秀的性能。**
|
||||
|
||||
OmniLMM-3B 在与相似大小模型相比的多个基准测试中实现了**最先进的性能**,超过了基于 Phi-2构建的现有 LMMs。它甚至**实现了与9.6B Qwen-VL-Chat 相媲美或更好的性能**。
|
||||
|
||||
- 🙌 **双语支持。**
|
||||
|
||||
OmniLMM-3B 是**第一个支持英语和中文双语多模态交互的终端可部署 LMM**。这是通过跨语言泛化多模态能力实现的,这是我们 ICLR 2024 [spotlight 论文](https://arxiv.org/abs/2308.12038)中的一项技术。
|
||||
|
||||
### Evaluation
|
||||
|
||||
<div align="center">
|
||||
|
||||
<table style="margin: 0px auto;">
|
||||
<thead>
|
||||
<tr>
|
||||
<th align="left">Model</th>
|
||||
<th>Size</th>
|
||||
<th>MME</th>
|
||||
<th nowrap="nowrap" >MMB dev (en)</th>
|
||||
<th nowrap="nowrap" >MMB dev (zh)</th>
|
||||
<th nowrap="nowrap" >MMMU val</th>
|
||||
<th nowrap="nowrap" >CMMMU val</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody align="center">
|
||||
<tr>
|
||||
<td align="left">LLaVA-Phi</td>
|
||||
<td align="right">3B</td>
|
||||
<td>1335</td>
|
||||
<td>59.8</td>
|
||||
<td>- </td>
|
||||
<td>- </td>
|
||||
<td>- </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap" align="left">MobileVLM</td>
|
||||
<td align="right">3B</td>
|
||||
<td>1289</td>
|
||||
<td>59.6</td>
|
||||
<td>- </td>
|
||||
<td>- </td>
|
||||
<td>- </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap" align="left" >Imp-v1</td>
|
||||
<td align="right">3B</td>
|
||||
<td>1434</td>
|
||||
<td>66.5</td>
|
||||
<td>- </td>
|
||||
<td>- </td>
|
||||
<td>- </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="left" >Qwen-VL-Chat</td>
|
||||
<td align="right" >9.6B</td>
|
||||
<td>1487</td>
|
||||
<td>60.6 </td>
|
||||
<td>56.7 </td>
|
||||
<td>35.9 </td>
|
||||
<td>30.7 </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap" align="left" >CogVLM</td>
|
||||
<td align="right">17.4B </td>
|
||||
<td>1438 </td>
|
||||
<td>63.7 </td>
|
||||
<td>53.8 </td>
|
||||
<td>32.1 </td>
|
||||
<td>- </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap" align="left" ><b>OmniLMM-3B</b></td>
|
||||
<td align="right">3B </td>
|
||||
<td>1452 </td>
|
||||
<td>67.3 </td>
|
||||
<td>61.9 </td>
|
||||
<td>34.7 </td>
|
||||
<td>32.1 </td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
</div>
|
||||
|
||||
### 样例展示
|
||||
|
||||
<table align="center" >
|
||||
<p align="center" >
|
||||
<img src="assets/Snake_cn_Mushroom_en.gif" width=48%/>
|
||||
</p>
|
||||
</table>
|
||||
|
||||
## 体验
|
||||
|
||||
你可以通过以下链接尝试使用我们的网页端推理服务: [OmniLMM-12B](http://120.92.209.146:8081) | [OmniLMM-3B](http://120.92.209.146:80).
|
||||
|
||||
## 安装
|
||||
|
||||
1. Clone this repository and navigate to the source folder
|
||||
|
||||
```bash
|
||||
git clone https://github.com/OpenBMB/OmniLMM.git
|
||||
cd OmniLMM
|
||||
```
|
||||
|
||||
2. Create conda environment
|
||||
|
||||
```Shell
|
||||
conda create -n OmniLMM python=3.10 -y
|
||||
conda activate OmniLMM
|
||||
```
|
||||
|
||||
3. Install dependencies
|
||||
|
||||
```shell
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## 推理
|
||||
|
||||
### 模型库
|
||||
|
||||
| 模型 | 简介 | 下载链接 |
|
||||
|:----------------------|:-------------------|:---------------:|
|
||||
| OmniLMM-12B | 更强大的性能表现 | [🤗](https://huggingface.co/openbmb/OmniLMM-12B) <a url="https://modelscope.cn/models/OpenBMB/OmniLMM-12B/files"> <img src="./assets/modelscope_logo.png" width="20px"></img></a> |
|
||||
| OmniLMM-3B | 支持终端设备上的高效部署,性能优秀 | [🤗](https://huggingface.co/openbmb/MiniCPM-V) <a url="https://modelscope.cn/models/OpenBMB/MiniCPM-V/files"> <img src="./assets/modelscope_logo.png" width="20px"></img></a> |
|
||||
|
||||
|
||||
### 多轮对话
|
||||
|
||||
请参考以下代码运行 `OmniLMM` 的推理服务。
|
||||
|
||||
<div align="center">
|
||||
<img src="assets/COCO_test2015_000000262144.jpg" width="660px">
|
||||
</div>
|
||||
|
||||
##### OmniLMM-12B
|
||||
|
||||
```python
|
||||
from chat import OmniLMMChat, img2base64
|
||||
|
||||
chat_model = OmniLMMChat('openbmb/OmniLMM-12B')
|
||||
|
||||
im_64 = img2base64('./assets/COCO_test2015_000000262144.jpg')
|
||||
|
||||
# First round chat
|
||||
msgs = [{"role": "user", "content": "What are the people doing?"}]
|
||||
|
||||
inputs = {"image": im_64, "question": json.dumps(msgs)}
|
||||
answer = chat_model.process(inputs)
|
||||
print(answer)
|
||||
|
||||
# Second round chat
|
||||
# pass history context of multi-turn conversation
|
||||
msgs.append({"role": "assistant", "content": answer})
|
||||
msgs.append({"role": "user", "content": "Describe the image"})
|
||||
|
||||
inputs = {"image": im_64, "question": json.dumps(msgs)}
|
||||
answer = chat_model.process(inputs)
|
||||
print(answer)
|
||||
```
|
||||
|
||||
We can obtain the following results:
|
||||
```
|
||||
"The people in the image are playing baseball. One person is pitching a ball, another one is swinging a bat to hit it, and there's also an umpire present who appears to be watching the game closely."
|
||||
|
||||
"The image depicts a baseball game in progress. A pitcher is throwing the ball, while another player is swinging his bat to hit it. An umpire can be seen observing the play closely."
|
||||
```
|
||||
|
||||
##### OmniLMM-3B
|
||||
```python
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
model_path='openbmb/MiniCPM-V'
|
||||
model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(dtype=torch.bfloat16)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
model.eval().cuda()
|
||||
|
||||
image = Image.open('./assets/COCO_test2015_000000262144.jpg').convert('RGB')
|
||||
|
||||
question = '请描述一下该图像'
|
||||
res, context, _ = model.chat(
|
||||
image=image,
|
||||
question=question,
|
||||
context=None,
|
||||
tokenizer=tokenizer,
|
||||
sampling=True,
|
||||
temperature=0.7
|
||||
)
|
||||
print(res)
|
||||
```
|
||||
|
||||
## ✅ 未来计划
|
||||
|
||||
- [ ] 支持模型微调
|
||||
- [ ] 本地可视化部署
|
||||
- [ ] 实时多模态交互代码开源
|
||||
|
||||
|
||||
## 模型协议
|
||||
|
||||
本仓库中代码依照 Apache-2.0 协议开源
|
||||
|
||||
OmniLMMs 模型权重的使用则需要遵循 “[通用模型许可协议-来源说明-宣传限制-商业授权](https://github.com/OpenBMB/General-Model-License/blob/main/通用模型许可协议-来源说明-宣传限制-商业授权.md)”。
|
||||
|
||||
OmniLMMs 模型权重对学术研究完全开放。
|
||||
|
||||
如需将模型用于商业用途,请联系 cpm@modelbest.cn 来获取书面授权,在登记后亦允许免费商业使用。
|
||||
|
||||
|
||||
## 声明
|
||||
|
||||
作为多模态大模型,OmniLMMs 通过学习大量的多模态语料来生成内容,但它无法理解、表达个人观点或价值判断,它所输出的任何内容都不代表模型开发者的观点和立场。
|
||||
|
||||
因此用户在使用 OmniLMMs 生成的内容时,应自行负责对其进行评估和验证。
|
||||
|
||||
如果由于使用 OmniLMMs 开源模型而导致的任何问题,包括但不限于数据安全问题、公共舆论风险,或模型被误导、滥用、传播或不当利用所带来的任何风险和问题,我们将不承担任何责任。
|
||||
|
||||
|
||||
## 🏫 机构
|
||||
|
||||
本项目由以下机构共同开发:
|
||||
|
||||
- <img src="assets/thunlp.png" width="28px"> [清华大学自然语言处理实验室](https://nlp.csai.tsinghua.edu.cn/)
|
||||
- <img src="assets/modelbest.png" width="28px"> [面壁智能](https://modelbest.cn/)
|
||||
- <img src="assets/zhihu.webp" width="28px"> [知乎](https://www.zhihu.com/ )
|
||||
|
BIN
assets/COCO_test2015_000000262144.jpg
Normal file
After Width: | Height: | Size: 188 KiB |
BIN
assets/Snake_cn_Mushroom_en.gif
Normal file
After Width: | Height: | Size: 3.4 MiB |
BIN
assets/demo_video.mp4
Normal file
BIN
assets/gif_cases/Mushroom_en.gif
Normal file
After Width: | Height: | Size: 819 KiB |
BIN
assets/gif_cases/Mushroom_en_Snake_cn.gif
Normal file
After Width: | Height: | Size: 3.4 MiB |
BIN
assets/gif_cases/Snake_en.gif
Normal file
After Width: | Height: | Size: 2.4 MiB |
BIN
assets/gif_cases/蘑菇_cn.gif
Normal file
After Width: | Height: | Size: 1.3 MiB |
BIN
assets/gif_cases/蛇_cn.gif
Normal file
After Width: | Height: | Size: 2.5 MiB |
BIN
assets/modelbest.png
Normal file
After Width: | Height: | Size: 48 KiB |
BIN
assets/modelscope_logo.png
Normal file
After Width: | Height: | Size: 6.0 KiB |
BIN
assets/omnilmm-12b-examples.png
Normal file
After Width: | Height: | Size: 5.1 MiB |
BIN
assets/thunlp.png
Normal file
After Width: | Height: | Size: 24 KiB |
BIN
assets/title-1.png
Normal file
After Width: | Height: | Size: 45 KiB |
BIN
assets/title-2.png
Normal file
After Width: | Height: | Size: 46 KiB |
BIN
assets/title-3.png
Normal file
After Width: | Height: | Size: 46 KiB |
BIN
assets/zhihu.webp
Normal file
After Width: | Height: | Size: 36 KiB |
194
chat.py
Normal file
@ -0,0 +1,194 @@
|
||||
import os
|
||||
import torch
|
||||
import json
|
||||
from PIL import Image
|
||||
import base64
|
||||
import io
|
||||
from accelerate import load_checkpoint_and_dispatch, init_empty_weights
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
|
||||
from omnilmm.utils import disable_torch_init
|
||||
from omnilmm.model.omnilmm import OmniLMMForCausalLM
|
||||
from omnilmm.model.utils import build_transform
|
||||
from omnilmm.train.train_utils import omni_preprocess
|
||||
|
||||
DEFAULT_IMAGE_TOKEN = "<image>"
|
||||
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
|
||||
DEFAULT_IM_START_TOKEN = "<im_start>"
|
||||
DEFAULT_IM_END_TOKEN = "<im_end>"
|
||||
|
||||
|
||||
|
||||
def init_omni_lmm(model_path):
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
disable_torch_init()
|
||||
model_name = os.path.expanduser(model_path)
|
||||
print(f'Load omni_lmm model and tokenizer from {model_name}')
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_name, model_max_length=2048)
|
||||
|
||||
if False:
|
||||
# model on multiple devices for small size gpu memory (Nvidia 3090 24G x2)
|
||||
with init_empty_weights():
|
||||
model = OmniLMMForCausalLM.from_pretrained(model_name, tune_clip=True, torch_dtype=torch.bfloat16)
|
||||
model = load_checkpoint_and_dispatch(model, model_name, dtype=torch.bfloat16,
|
||||
device_map="auto", no_split_module_classes=['Eva','MistralDecoderLayer', 'ModuleList', 'Resampler']
|
||||
)
|
||||
else:
|
||||
model = OmniLMMForCausalLM.from_pretrained(
|
||||
model_name, tune_clip=True, torch_dtype=torch.bfloat16
|
||||
).to(device='cuda', dtype=torch.bfloat16)
|
||||
|
||||
image_processor = build_transform(
|
||||
is_train=False, input_size=model.model.config.image_size, std_mode='OPENAI_CLIP')
|
||||
|
||||
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
|
||||
assert mm_use_im_start_end
|
||||
|
||||
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN,
|
||||
DEFAULT_IM_END_TOKEN], special_tokens=True)
|
||||
|
||||
|
||||
vision_config = model.model.vision_config
|
||||
vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
|
||||
[DEFAULT_IMAGE_PATCH_TOKEN])[0]
|
||||
vision_config.use_im_start_end = mm_use_im_start_end
|
||||
vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids(
|
||||
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
|
||||
image_token_len = model.model.config.num_query
|
||||
|
||||
return model, image_processor, image_token_len, tokenizer
|
||||
|
||||
def expand_question_into_multimodal(question_text, image_token_len, im_st_token, im_ed_token, im_patch_token):
|
||||
if '<image>' in question_text[0]['content']:
|
||||
question_text[0]['content'] = question_text[0]['content'].replace(
|
||||
'<image>', im_st_token + im_patch_token * image_token_len + im_ed_token)
|
||||
else:
|
||||
question_text[0]['content'] = im_st_token + im_patch_token * \
|
||||
image_token_len + im_ed_token + '\n' + question_text[0]['content']
|
||||
return question_text
|
||||
|
||||
def wrap_question_for_omni_lmm(question, image_token_len, tokenizer):
|
||||
question = expand_question_into_multimodal(
|
||||
question, image_token_len, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN)
|
||||
|
||||
conversation = question
|
||||
data_dict = omni_preprocess(sources=[conversation],
|
||||
tokenizer=tokenizer,
|
||||
generation=True)
|
||||
|
||||
data_dict = dict(input_ids=data_dict["input_ids"][0],
|
||||
labels=data_dict["labels"][0])
|
||||
return data_dict
|
||||
|
||||
|
||||
|
||||
class OmniLMM12B:
|
||||
def __init__(self, model_path) -> None:
|
||||
model, img_processor, image_token_len, tokenizer = init_omni_lmm(model_path)
|
||||
self.model = model
|
||||
self.image_token_len = image_token_len
|
||||
self.image_transform = img_processor
|
||||
self.tokenizer = tokenizer
|
||||
self.model.eval()
|
||||
|
||||
def decode(self, image, input_ids):
|
||||
with torch.inference_mode():
|
||||
output = self.model.generate_vllm(
|
||||
input_ids=input_ids.unsqueeze(0).cuda(),
|
||||
images=image.unsqueeze(0).half().cuda(),
|
||||
temperature=0.6,
|
||||
max_new_tokens=1024,
|
||||
# num_beams=num_beams,
|
||||
do_sample=True,
|
||||
output_scores=True,
|
||||
return_dict_in_generate=True,
|
||||
repetition_penalty=1.1,
|
||||
top_k=30,
|
||||
top_p=0.9,
|
||||
)
|
||||
|
||||
response = self.tokenizer.decode(
|
||||
output.sequences[0], skip_special_tokens=True)
|
||||
response = response.strip()
|
||||
return response
|
||||
|
||||
def chat(self, input):
|
||||
try:
|
||||
image = Image.open(io.BytesIO(base64.b64decode(input['image']))).convert('RGB')
|
||||
except Exception as e:
|
||||
return "Image decode error"
|
||||
|
||||
msgs = json.loads(input['question'])
|
||||
input_ids = wrap_question_for_omni_lmm(
|
||||
msgs, self.image_token_len, self.tokenizer)['input_ids']
|
||||
input_ids = torch.as_tensor(input_ids)
|
||||
#print('input_ids', input_ids)
|
||||
image = self.image_transform(image)
|
||||
|
||||
out = self.decode(image, input_ids)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def img2base64(file_name):
|
||||
with open(file_name, 'rb') as f:
|
||||
encoded_string = base64.b64encode(f.read())
|
||||
return encoded_string
|
||||
|
||||
class OmniLMM3B:
|
||||
def __init__(self, model_path) -> None:
|
||||
self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(dtype=torch.bfloat16)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
self.model.eval().cuda()
|
||||
|
||||
def chat(self, input):
|
||||
try:
|
||||
image = Image.open(io.BytesIO(base64.b64decode(input['image']))).convert('RGB')
|
||||
except Exception as e:
|
||||
return "Image decode error"
|
||||
|
||||
msgs = json.loads(input['question'])
|
||||
|
||||
answer, context, _ = self.model.chat(
|
||||
image=image,
|
||||
msgs=msgs,
|
||||
context=None,
|
||||
tokenizer=self.tokenizer,
|
||||
sampling=True,
|
||||
temperature=0.7
|
||||
)
|
||||
return answer
|
||||
|
||||
|
||||
class OmniLMMChat:
|
||||
def __init__(self, model_path) -> None:
|
||||
if '12B' in model_path:
|
||||
self.model = OmniLMM12B(model_path)
|
||||
else:
|
||||
self.model = OmniLMM3B(model_path)
|
||||
|
||||
def chat(self, input):
|
||||
return self.model.chat(input)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
model_path = 'openbmb/OmniLMM-12B'
|
||||
chat_model = OmniLMMChat(model_path)
|
||||
|
||||
im_64 = img2base64('./assets/COCO_test2015_000000262144.jpg')
|
||||
|
||||
# first round chat
|
||||
msgs = [{"role": "user", "content": "What are the people doing?"}]
|
||||
input = {"image": im_64, "question": json.dumps(msgs, ensure_ascii=True)}
|
||||
answer = chat_model.chat(input)
|
||||
print(msgs[-1]["content"]+'\n', answer)
|
||||
|
||||
# second round chat
|
||||
msgs.append({"role": "assistant", "content": answer})
|
||||
msgs.append({"role": "user", "content": "Describe the image"})
|
||||
input = {"image": im_64,"question": json.dumps(msgs, ensure_ascii=True)}
|
||||
answer = chat_model.chat(input)
|
||||
print(msgs[-1]["content"]+'\n', answer)
|
||||
|
0
omnilmm/__init__.py
Normal file
4
omnilmm/constants.py
Normal file
@ -0,0 +1,4 @@
|
||||
CONTROLLER_HEART_BEAT_EXPIRATION = 30
|
||||
WORKER_HEART_BEAT_INTERVAL = 15
|
||||
|
||||
LOGDIR = "."
|
320
omnilmm/conversation.py
Normal file
@ -0,0 +1,320 @@
|
||||
import dataclasses
|
||||
from enum import auto, Enum
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
class SeparatorStyle(Enum):
|
||||
"""Different separator style."""
|
||||
SINGLE = auto()
|
||||
TWO = auto()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Conversation:
|
||||
"""A class that keeps all conversation history."""
|
||||
system: str
|
||||
roles: List[str]
|
||||
messages: List[List[str]]
|
||||
offset: int
|
||||
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
||||
sep: str = "###"
|
||||
sep2: str = None
|
||||
version: str = "Unknown"
|
||||
|
||||
skip_next: bool = False
|
||||
|
||||
def get_prompt(self):
|
||||
if self.sep_style == SeparatorStyle.SINGLE:
|
||||
ret = self.system + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
if type(message) is tuple:
|
||||
message, _, _ = message
|
||||
ret += role + ": " + message + self.sep
|
||||
else:
|
||||
ret += role + ":"
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.TWO:
|
||||
seps = [self.sep, self.sep2]
|
||||
ret = self.system + seps[0]
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
if type(message) is tuple:
|
||||
message, _, _ = message
|
||||
ret += role + ": " + message + seps[i % 2]
|
||||
else:
|
||||
ret += role + ":"
|
||||
return ret
|
||||
else:
|
||||
raise ValueError(f"Invalid style: {self.sep_style}")
|
||||
|
||||
def append_message(self, role, message):
|
||||
self.messages.append([role, message])
|
||||
|
||||
def get_images(self, return_pil=False):
|
||||
images = []
|
||||
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
||||
if i % 2 == 0:
|
||||
if type(msg) is tuple:
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
msg, image, image_process_mode = msg
|
||||
if image_process_mode == "Pad":
|
||||
def expand2square(pil_img, background_color=(122, 116, 104)):
|
||||
width, height = pil_img.size
|
||||
if width == height:
|
||||
return pil_img
|
||||
elif width > height:
|
||||
result = Image.new(
|
||||
pil_img.mode, (width, width), background_color)
|
||||
result.paste(
|
||||
pil_img, (0, (width - height) // 2))
|
||||
return result
|
||||
else:
|
||||
result = Image.new(
|
||||
pil_img.mode, (height, height), background_color)
|
||||
result.paste(
|
||||
pil_img, ((height - width) // 2, 0))
|
||||
return result
|
||||
image = expand2square(image)
|
||||
elif image_process_mode == "Crop":
|
||||
pass
|
||||
elif image_process_mode == "Resize":
|
||||
image = image.resize((224, 224))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid image_process_mode: {image_process_mode}")
|
||||
max_hw, min_hw = max(image.size), min(image.size)
|
||||
aspect_ratio = max_hw / min_hw
|
||||
max_len, min_len = 800, 400
|
||||
shortest_edge = int(
|
||||
min(max_len / aspect_ratio, min_len, min_hw))
|
||||
longest_edge = int(shortest_edge * aspect_ratio)
|
||||
W, H = image.size
|
||||
if H > W:
|
||||
H, W = longest_edge, shortest_edge
|
||||
else:
|
||||
H, W = shortest_edge, longest_edge
|
||||
image = image.resize((W, H))
|
||||
if return_pil:
|
||||
images.append(image)
|
||||
else:
|
||||
buffered = BytesIO()
|
||||
image.save(buffered, format="JPEG")
|
||||
img_b64_str = base64.b64encode(
|
||||
buffered.getvalue()).decode()
|
||||
images.append(img_b64_str)
|
||||
return images
|
||||
|
||||
def to_gradio_chatbot(self):
|
||||
ret = []
|
||||
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
||||
if i % 2 == 0:
|
||||
if type(msg) is tuple:
|
||||
import base64
|
||||
from io import BytesIO
|
||||
msg, image, image_process_mode = msg
|
||||
max_hw, min_hw = max(image.size), min(image.size)
|
||||
aspect_ratio = max_hw / min_hw
|
||||
max_len, min_len = 800, 400
|
||||
shortest_edge = int(
|
||||
min(max_len / aspect_ratio, min_len, min_hw))
|
||||
longest_edge = int(shortest_edge * aspect_ratio)
|
||||
W, H = image.size
|
||||
if H > W:
|
||||
H, W = longest_edge, shortest_edge
|
||||
else:
|
||||
H, W = shortest_edge, longest_edge
|
||||
image = image.resize((W, H))
|
||||
# image = image.resize((224, 224))
|
||||
buffered = BytesIO()
|
||||
image.save(buffered, format="JPEG")
|
||||
img_b64_str = base64.b64encode(
|
||||
buffered.getvalue()).decode()
|
||||
img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
|
||||
msg = msg.replace('<image>', img_str)
|
||||
ret.append([msg, None])
|
||||
else:
|
||||
ret[-1][-1] = msg
|
||||
return ret
|
||||
|
||||
def copy(self):
|
||||
return Conversation(
|
||||
system=self.system,
|
||||
roles=self.roles,
|
||||
messages=[[x, y] for x, y in self.messages],
|
||||
offset=self.offset,
|
||||
sep_style=self.sep_style,
|
||||
sep=self.sep,
|
||||
sep2=self.sep2)
|
||||
|
||||
def dict(self):
|
||||
if len(self.get_images()) > 0:
|
||||
return {
|
||||
"system": self.system,
|
||||
"roles": self.roles,
|
||||
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
|
||||
"offset": self.offset,
|
||||
"sep": self.sep,
|
||||
"sep2": self.sep2,
|
||||
}
|
||||
return {
|
||||
"system": self.system,
|
||||
"roles": self.roles,
|
||||
"messages": self.messages,
|
||||
"offset": self.offset,
|
||||
"sep": self.sep,
|
||||
"sep2": self.sep2,
|
||||
}
|
||||
|
||||
|
||||
conv_v1 = Conversation(
|
||||
system="A chat between a curious human and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
||||
roles=("Human", "Assistant"),
|
||||
messages=(
|
||||
("Human", "Give three tips for staying healthy."),
|
||||
("Assistant",
|
||||
"Sure, here are three tips for staying healthy:\n"
|
||||
"1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. "
|
||||
"It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, "
|
||||
"and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or "
|
||||
"75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening "
|
||||
"activities at least two days per week.\n"
|
||||
"2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, "
|
||||
"vegetables, whole grains, lean proteins, and healthy fats can help support "
|
||||
"your overall health. Try to limit your intake of processed and high-sugar foods, "
|
||||
"and aim to drink plenty of water throughout the day.\n"
|
||||
"3. Get enough sleep: Getting enough quality sleep is essential for your physical "
|
||||
"and mental health. Adults should aim for seven to nine hours of sleep per night. "
|
||||
"Establish a regular sleep schedule and try to create a relaxing bedtime routine to "
|
||||
"help improve the quality of your sleep.")
|
||||
),
|
||||
offset=2,
|
||||
sep_style=SeparatorStyle.SINGLE,
|
||||
sep="###",
|
||||
)
|
||||
|
||||
conv_v1_2 = Conversation(
|
||||
system="A chat between a curious human and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
||||
roles=("Human", "Assistant"),
|
||||
messages=(
|
||||
("Human", "What are the key differences between renewable and non-renewable energy sources?"),
|
||||
("Assistant",
|
||||
"Renewable energy sources are those that can be replenished naturally in a relatively "
|
||||
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
|
||||
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
|
||||
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
|
||||
"renewable and non-renewable energy sources:\n"
|
||||
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
|
||||
"energy sources are finite and will eventually run out.\n"
|
||||
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
|
||||
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
|
||||
"and other negative effects.\n"
|
||||
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
|
||||
"have lower operational costs than non-renewable sources.\n"
|
||||
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
|
||||
"locations than non-renewable sources.\n"
|
||||
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
|
||||
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
|
||||
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
|
||||
"non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
|
||||
),
|
||||
offset=2,
|
||||
sep_style=SeparatorStyle.SINGLE,
|
||||
sep="###",
|
||||
)
|
||||
|
||||
conv_vicuna_v1_1 = Conversation(
|
||||
system="A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
||||
roles=("USER", "ASSISTANT"),
|
||||
version="v1",
|
||||
messages=(),
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.TWO,
|
||||
sep=" ",
|
||||
sep2="</s>",
|
||||
)
|
||||
|
||||
conv_bair_v1 = Conversation(
|
||||
system="BEGINNING OF CONVERSATION:",
|
||||
roles=("USER", "GPT"),
|
||||
messages=(),
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.TWO,
|
||||
sep=" ",
|
||||
sep2="</s>",
|
||||
)
|
||||
|
||||
simple_conv = Conversation(
|
||||
system="You are LLaVA, a large language model trained by UW Madison WAIV Lab, based on LLaMA architecture."
|
||||
"You are designed to assist human with a variety of tasks using natural language."
|
||||
"Follow the instructions carefully.",
|
||||
roles=("Human", "Assistant"),
|
||||
messages=(
|
||||
("Human", "Hi!"),
|
||||
("Assistant", "Hi there! How can I help you today?\n")
|
||||
),
|
||||
offset=2,
|
||||
sep_style=SeparatorStyle.SINGLE,
|
||||
sep="###",
|
||||
)
|
||||
|
||||
simple_conv_multimodal = Conversation(
|
||||
system="A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
||||
roles=("Human", "Assistant"),
|
||||
messages=(
|
||||
),
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.SINGLE,
|
||||
sep="###",
|
||||
)
|
||||
|
||||
simple_conv_legacy = Conversation(
|
||||
system="You are LLaVA, a large language model trained by UW Madison WAIV Lab."
|
||||
"You are designed to assist human with a variety of tasks using natural language."
|
||||
"Follow the instructions carefully.",
|
||||
roles=("Human", "Assistant"),
|
||||
messages=(
|
||||
("Human", "Hi!\n\n### Response:"),
|
||||
("Assistant", "Hi there! How can I help you today?\n")
|
||||
),
|
||||
offset=2,
|
||||
sep_style=SeparatorStyle.SINGLE,
|
||||
sep="###",
|
||||
)
|
||||
|
||||
conv_llava_v1 = Conversation(
|
||||
system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
|
||||
"You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
||||
"Follow the instructions carefully and explain your answers in detail.",
|
||||
roles=("USER", "ASSISTANT"),
|
||||
version="v1",
|
||||
messages=(),
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.TWO,
|
||||
sep=" ",
|
||||
sep2="</s>",
|
||||
)
|
||||
|
||||
default_conversation = conv_v1_2
|
||||
conv_templates = {
|
||||
"default": conv_v1_2,
|
||||
"simple": simple_conv,
|
||||
"simple_legacy": simple_conv_legacy,
|
||||
"multimodal": simple_conv_multimodal,
|
||||
"llava_v1": conv_llava_v1,
|
||||
|
||||
# fastchat
|
||||
"v1": conv_v1_2,
|
||||
"bair_v1": conv_bair_v1,
|
||||
"vicuna_v1_1": conv_vicuna_v1_1,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(default_conversation.get_prompt())
|
1
omnilmm/model/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .omnilmm import OmniLMMForCausalLM
|
457
omnilmm/model/omnilmm.py
Normal file
@ -0,0 +1,457 @@
|
||||
|
||||
import gc
|
||||
import math
|
||||
import timm
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import torch.nn as nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
from transformers import MistralForCausalLM, MistralModel, MistralConfig
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
|
||||
from omnilmm.model.utils import build_transform
|
||||
from omnilmm.model.resampler import Resampler
|
||||
|
||||
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
|
||||
DEFAULT_IM_START_TOKEN = "<im_start>"
|
||||
DEFAULT_IM_END_TOKEN = "<im_end>"
|
||||
|
||||
|
||||
class OmniLMMConfig(MistralConfig):
|
||||
model_type = "omnilmm"
|
||||
|
||||
|
||||
class Identity(torch.nn.Identity):
|
||||
def forward(self, input: Tensor, **kwargs) -> Tensor:
|
||||
return super().forward(input)
|
||||
|
||||
|
||||
def create_vision_module(config):
|
||||
vision_tower = timm.create_model('eva02_enormous_patch14_clip_224.laion2b_plus',
|
||||
pretrained=False,
|
||||
num_classes=0,
|
||||
dynamic_img_size=True,
|
||||
dynamic_img_pad=True)
|
||||
|
||||
if isinstance(vision_tower, timm.models.VisionTransformer):
|
||||
if vision_tower.attn_pool is not None:
|
||||
vision_tower.attn_pool = Identity()
|
||||
|
||||
# use 2nd last layer's output
|
||||
vision_tower.blocks[-1] = Identity()
|
||||
|
||||
embed_dim = config.hidden_size
|
||||
resampler = Resampler(
|
||||
grid_size=int(math.sqrt(config.num_query)),
|
||||
embed_dim=embed_dim,
|
||||
num_heads=embed_dim // 128,
|
||||
kv_dim=vision_tower.embed_dim,
|
||||
)
|
||||
return vision_tower, resampler
|
||||
|
||||
|
||||
class OmniLMMModel(MistralModel):
|
||||
config_class = OmniLMMConfig
|
||||
|
||||
def __init__(self, config: OmniLMMConfig, mm_vision_tower=None, mm_hidden_size=None, tune_clip=True):
|
||||
super(OmniLMMModel, self).__init__(config)
|
||||
|
||||
if hasattr(config, "mm_vision_tower"):
|
||||
vision_tower, resampler = create_vision_module(config)
|
||||
|
||||
print(__file__, 'skip loading vision tower weights')
|
||||
|
||||
# HACK: for FSDP
|
||||
self.vision_tower = [vision_tower]
|
||||
self.resampler = resampler
|
||||
if tune_clip:
|
||||
self.vision_tower = self.vision_tower[0]
|
||||
|
||||
self.vision_config = lambda x: None
|
||||
|
||||
def initialize_vision_modules(self, vision_tower, no_randaug, num_query, image_size, tune_clip=False):
|
||||
self.config.mm_vision_tower = vision_tower
|
||||
self.config.use_mm_proj = True
|
||||
self.config.num_query = num_query
|
||||
self.config.image_size = image_size
|
||||
|
||||
if not hasattr(self, 'vision_tower'):
|
||||
vision_tower, resampler = create_vision_module(self.config)
|
||||
state_dict = torch.load(
|
||||
'/tt/data/public/multimodal/multimodal_model_ckpts/timm/eva02_enormous_patch14_clip_224.laion2b_plus.pt')
|
||||
vision_tower.load_state_dict(state_dict, strict=False)
|
||||
del state_dict
|
||||
gc.collect()
|
||||
else:
|
||||
if isinstance(self.vision_tower, list):
|
||||
vision_tower = self.vision_tower[0]
|
||||
else:
|
||||
vision_tower = self.vision_tower
|
||||
resampler = self.resampler
|
||||
self.vision_tower = vision_tower if tune_clip else [vision_tower]
|
||||
self.resampler = resampler
|
||||
|
||||
train_img_transform = build_transform(
|
||||
is_train=True, randaug=not no_randaug, input_size=self.config.image_size, std_mode='OPENAI_CLIP')
|
||||
eval_img_transform = build_transform(
|
||||
is_train=False, input_size=self.config.image_size, std_mode='OPENAI_CLIP')
|
||||
|
||||
return dict(
|
||||
image_processor=(train_img_transform, eval_img_transform),
|
||||
image_token_len=num_query,
|
||||
vision_config=self.vision_config
|
||||
)
|
||||
|
||||
def get_vision_embedding(self, pixel_values):
|
||||
if isinstance(self.vision_tower, list):
|
||||
vision_tower = self.vision_tower[0] # HACK: for FSDP
|
||||
else:
|
||||
vision_tower = self.vision_tower
|
||||
|
||||
dtype = vision_tower.pos_embed.data.dtype
|
||||
vision_embedding = vision_tower.forward_features(
|
||||
pixel_values.type(dtype))
|
||||
if hasattr(vision_tower, 'num_prefix_tokens') and vision_tower.num_prefix_tokens > 0:
|
||||
vision_embedding = vision_embedding[:,
|
||||
vision_tower.num_prefix_tokens:]
|
||||
res = self.resampler(vision_embedding)
|
||||
return res
|
||||
|
||||
def get_vllm_embedding(self, data):
|
||||
|
||||
if 'vision_hidden_states' not in data:
|
||||
pixel_values_list = data['pixel_values']
|
||||
vision_hidden_states = []
|
||||
for pixel_values in pixel_values_list:
|
||||
if len(pixel_values) > 0:
|
||||
vision_hidden_states.append(self.get_vision_embedding(pixel_values.unsqueeze(0))[0])
|
||||
else:
|
||||
vision_hidden_states.append([])
|
||||
else:
|
||||
vision_hidden_states = data['vision_hidden_states']
|
||||
|
||||
#vllm_embedding = self.llm.model.embed_tokens(data['input_ids']) * self.llm.config.scale_emb
|
||||
inputs_embeds = self.embed_tokens(data['input_ids'])
|
||||
vision_hidden_states = [i.type(inputs_embeds.dtype)
|
||||
if isinstance(i, torch.Tensor) else i for i in vision_hidden_states
|
||||
]
|
||||
|
||||
|
||||
# HACK: replace back original embeddings for LLaVA pretraining
|
||||
orig_embeds_params = getattr(self, 'orig_embeds_params', None)
|
||||
|
||||
new_input_embeds = []
|
||||
cur_image_idx = 0
|
||||
for cur_input_ids, cur_input_embeds in zip(data['input_ids'], inputs_embeds):
|
||||
if (cur_input_ids == self.vision_config.im_patch_token).sum() == 0:
|
||||
# multimodal LLM, but the current sample is not multimodal
|
||||
cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
|
||||
new_input_embeds.append(cur_input_embeds)
|
||||
continue
|
||||
|
||||
if self.vision_config.use_im_start_end:
|
||||
cur_image_features = vision_hidden_states[cur_image_idx]
|
||||
num_patches = cur_image_features.shape[0]
|
||||
if (cur_input_ids == self.vision_config.im_start_token).sum() != (cur_input_ids == self.vision_config.im_end_token).sum():
|
||||
raise ValueError(
|
||||
"The number of image start tokens and image end tokens should be the same.")
|
||||
image_start_tokens = torch.where(
|
||||
cur_input_ids == self.vision_config.im_start_token)[0]
|
||||
for image_start_token_pos in image_start_tokens:
|
||||
cur_image_features = vision_hidden_states[cur_image_idx].to(
|
||||
device=cur_input_embeds.device)
|
||||
num_patches = cur_image_features.shape[0]
|
||||
if cur_input_ids[image_start_token_pos + num_patches + 1] != self.vision_config.im_end_token:
|
||||
raise ValueError(
|
||||
"The image end token should follow the image start token.")
|
||||
if orig_embeds_params is not None:
|
||||
cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos].detach(), cur_input_embeds[image_start_token_pos:image_start_token_pos+1], cur_image_features,
|
||||
cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2], cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()), dim=0)
|
||||
else:
|
||||
cur_new_input_embeds = torch.cat(
|
||||
(cur_input_embeds[:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
|
||||
cur_image_idx += 1
|
||||
new_input_embeds.append(cur_new_input_embeds)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
inputs_embeds = torch.stack(new_input_embeds, dim=0)
|
||||
|
||||
return inputs_embeds, vision_hidden_states
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
images: Optional[torch.FloatTensor] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**kwargs
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
|
||||
# HACK: replace back original embeddings for LLaVA pretraining
|
||||
orig_embeds_params = getattr(self, 'orig_embeds_params', None)
|
||||
|
||||
if inputs_embeds is None and past_key_values is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
vision_tower = getattr(self, 'vision_tower', None)
|
||||
if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
|
||||
|
||||
if type(images) is list:
|
||||
image_features = []
|
||||
for image in images:
|
||||
image_forward_out = self.get_vision_embedding(image.unsqueeze(0))[
|
||||
0]
|
||||
image_features.append(image_forward_out)
|
||||
else:
|
||||
image_features = self.get_vision_embedding(images)
|
||||
|
||||
dummy_image_features = torch.zeros(
|
||||
self.config.num_query,
|
||||
self.config.hidden_size,
|
||||
device=inputs_embeds.device,
|
||||
dtype=inputs_embeds.dtype)
|
||||
|
||||
new_input_embeds = []
|
||||
cur_image_idx = 0
|
||||
for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
|
||||
if (cur_input_ids == self.vision_config.im_patch_token).sum() == 0:
|
||||
# multimodal LLM, but the current sample is not multimodal
|
||||
cur_input_embeds = cur_input_embeds + \
|
||||
(0. * dummy_image_features).sum()
|
||||
new_input_embeds.append(cur_input_embeds)
|
||||
continue
|
||||
|
||||
if self.vision_config.use_im_start_end:
|
||||
cur_image_features = image_features[cur_image_idx]
|
||||
num_patches = cur_image_features.shape[0]
|
||||
if (cur_input_ids == self.vision_config.im_start_token).sum() != (cur_input_ids == self.vision_config.im_end_token).sum():
|
||||
raise ValueError(
|
||||
"The number of image start tokens and image end tokens should be the same.")
|
||||
image_start_tokens = torch.where(
|
||||
cur_input_ids == self.vision_config.im_start_token)[0]
|
||||
for image_start_token_pos in image_start_tokens:
|
||||
cur_image_features = image_features[cur_image_idx].to(
|
||||
device=cur_input_embeds.device)
|
||||
num_patches = cur_image_features.shape[0]
|
||||
if cur_input_ids[image_start_token_pos + num_patches + 1] != self.vision_config.im_end_token:
|
||||
raise ValueError(
|
||||
"The image end token should follow the image start token.")
|
||||
if orig_embeds_params is not None:
|
||||
cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos].detach(), cur_input_embeds[image_start_token_pos:image_start_token_pos+1], cur_image_features,
|
||||
cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2], cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()), dim=0)
|
||||
else:
|
||||
cur_new_input_embeds = torch.cat(
|
||||
(cur_input_embeds[:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
|
||||
cur_image_idx += 1
|
||||
new_input_embeds.append(cur_new_input_embeds)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
inputs_embeds = torch.stack(new_input_embeds, dim=0)
|
||||
input_ids = None
|
||||
|
||||
return super(OmniLMMModel, self).forward(
|
||||
input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds, use_cache=use_cache,
|
||||
output_attentions=output_attentions, output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
||||
class OmniLMMForCausalLM(MistralForCausalLM):
|
||||
config_class = OmniLMMConfig
|
||||
|
||||
def __init__(self, config, mm_vision_tower=None, tune_clip=True):
|
||||
super(MistralForCausalLM, self).__init__(config)
|
||||
self.model = OmniLMMModel(
|
||||
config, mm_vision_tower=mm_vision_tower, tune_clip=tune_clip)
|
||||
|
||||
self.lm_head = nn.Linear(
|
||||
config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
images: Optional[torch.FloatTensor] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**kwargs
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# print(f'@@@ At forward, labels: {labels.shape}-{labels}', flush=True)
|
||||
# print(f'@@@ At forward, input_ids: {input_ids.shape}-{input_ids}', flush=True)
|
||||
# print(f'@@@ At forward, input_ids: {attention_mask.shape}-{attention_mask}', flush=True)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
images=images,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model/pipeline parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
# TODO could be removed for generate_vllm()
|
||||
def prepare_inputs_for_generation(
|
||||
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
||||
):
|
||||
if past_key_values:
|
||||
input_ids = input_ids[:, -1:]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_key_values is None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"attention_mask": attention_mask,
|
||||
"images": kwargs.get("images", None),
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
def generate_vllm(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
images: Optional[torch.FloatTensor] = None,
|
||||
vision_hidden_states=None,
|
||||
return_vision_hidden_states=False,
|
||||
**kwargs
|
||||
):
|
||||
model_inputs = {'input_ids': input_ids}
|
||||
if vision_hidden_states is None:
|
||||
model_inputs['pixel_values'] = images
|
||||
else:
|
||||
model_inputs['vision_hidden_states'] = vision_hidden_states
|
||||
|
||||
with torch.inference_mode():
|
||||
inputs_embeds, vision_hidden_states = self.model.get_vllm_embedding(model_inputs)
|
||||
|
||||
result = self.generate(
|
||||
inputs_embeds=inputs_embeds,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if return_vision_hidden_states:
|
||||
return result, vision_hidden_states
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def initialize_vision_tokenizer(self, mm_use_im_start_end, tokenizer, device,
|
||||
tune_mm_mlp_adapter=False):
|
||||
self.model.vision_config.use_im_start_end = mm_use_im_start_end
|
||||
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
||||
self.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
if mm_use_im_start_end:
|
||||
num_new_tokens = tokenizer.add_tokens(
|
||||
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
|
||||
self.resize_token_embeddings(len(tokenizer))
|
||||
self.model.vision_config.im_start_token, self.model.vision_config.im_end_token = tokenizer.convert_tokens_to_ids(
|
||||
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
|
||||
|
||||
if num_new_tokens > 0:
|
||||
input_embeddings = self.get_input_embeddings().weight.data
|
||||
output_embeddings = self.get_output_embeddings().weight.data
|
||||
|
||||
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
|
||||
dim=0, keepdim=True)
|
||||
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
|
||||
dim=0, keepdim=True)
|
||||
|
||||
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
||||
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
||||
|
||||
# for new sft data
|
||||
num_new_tokens = tokenizer.add_tokens(
|
||||
['<box>', '</box>', '<ref>', '</ref>', '<quad>', '</quad>'], special_tokens=True)
|
||||
self.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
if num_new_tokens > 0:
|
||||
input_embeddings = self.get_input_embeddings().weight.data
|
||||
output_embeddings = self.get_output_embeddings().weight.data
|
||||
|
||||
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
|
||||
dim=0, keepdim=True)
|
||||
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
|
||||
dim=0, keepdim=True)
|
||||
|
||||
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
||||
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
||||
|
||||
if tune_mm_mlp_adapter:
|
||||
self.model.orig_embeds_params = [
|
||||
self.get_input_embeddings().weight.data.clone().to(device=device)]
|
||||
for p in self.get_input_embeddings().parameters():
|
||||
p.requires_grad = True
|
||||
for p in self.get_output_embeddings().parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
self.model.vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
|
||||
[DEFAULT_IMAGE_PATCH_TOKEN])[0]
|
||||
print(f'Tokenizer: {tokenizer}\n patch_token_id: {self.model.vision_config.im_patch_token}, visoin_config: {self.model.vision_config}', flush=True)
|
||||
# exit()
|
||||
|
||||
|
||||
AutoConfig.register("omnilmm", OmniLMMConfig)
|
||||
AutoModelForCausalLM.register(OmniLMMConfig, OmniLMMForCausalLM)
|
171
omnilmm/model/resampler.py
Normal file
@ -0,0 +1,171 @@
|
||||
# Copyright (c) Alibaba Cloud.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from collections import OrderedDict
|
||||
import math
|
||||
import requests
|
||||
from io import BytesIO
|
||||
from functools import partial
|
||||
from PIL import Image
|
||||
from typing import Callable, Optional, Sequence, Tuple, List, Union
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.init import trunc_normal_
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms import InterpolationMode
|
||||
|
||||
|
||||
def get_abs_pos(abs_pos, tgt_size):
|
||||
# abs_pos: L, C
|
||||
# tgt_size: M
|
||||
# return: M, C
|
||||
src_size = int(math.sqrt(abs_pos.size(0)))
|
||||
tgt_size = int(math.sqrt(tgt_size))
|
||||
dtype = abs_pos.dtype
|
||||
|
||||
if src_size != tgt_size:
|
||||
return F.interpolate(
|
||||
abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
|
||||
size=(tgt_size, tgt_size),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
|
||||
else:
|
||||
return abs_pos
|
||||
|
||||
|
||||
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
|
||||
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
|
||||
"""
|
||||
grid_size: int of the grid height and width
|
||||
return:
|
||||
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
||||
"""
|
||||
grid_h = np.arange(grid_size, dtype=np.float32)
|
||||
grid_w = np.arange(grid_size, dtype=np.float32)
|
||||
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||
grid = np.stack(grid, axis=0)
|
||||
|
||||
grid = grid.reshape([2, 1, grid_size, grid_size])
|
||||
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
||||
if cls_token:
|
||||
pos_embed = np.concatenate(
|
||||
[np.zeros([1, embed_dim]), pos_embed], axis=0)
|
||||
return pos_embed
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
||||
assert embed_dim % 2 == 0
|
||||
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid(
|
||||
embed_dim // 2, grid[0]) # (H*W, D/2)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid(
|
||||
embed_dim // 2, grid[1]) # (H*W, D/2)
|
||||
|
||||
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||
"""
|
||||
embed_dim: output dimension for each position
|
||||
pos: a list of positions to be encoded: size (M,)
|
||||
out: (M, D)
|
||||
"""
|
||||
assert embed_dim % 2 == 0
|
||||
omega = np.arange(embed_dim // 2, dtype=np.float32)
|
||||
omega /= embed_dim / 2.
|
||||
omega = 1. / 10000 ** omega # (D/2,)
|
||||
|
||||
pos = pos.reshape(-1) # (M,)
|
||||
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
||||
|
||||
emb_sin = np.sin(out) # (M, D/2)
|
||||
emb_cos = np.cos(out) # (M, D/2)
|
||||
|
||||
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
||||
return emb
|
||||
|
||||
|
||||
class Resampler(nn.Module):
|
||||
"""
|
||||
A 2D perceiver-resampler network with one cross attention layers by
|
||||
(grid_size**2) learnable queries and 2d sincos pos_emb
|
||||
Outputs:
|
||||
A tensor with the shape of (grid_size**2, embed_dim)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
grid_size,
|
||||
embed_dim,
|
||||
num_heads,
|
||||
kv_dim=None,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6)
|
||||
):
|
||||
super().__init__()
|
||||
self.num_queries = grid_size ** 2
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.pos_embed = nn.Parameter(
|
||||
torch.from_numpy(get_2d_sincos_pos_embed(
|
||||
embed_dim, grid_size)).float()
|
||||
).requires_grad_(False)
|
||||
|
||||
self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
|
||||
trunc_normal_(self.query, std=.02)
|
||||
|
||||
if kv_dim is not None and kv_dim != embed_dim:
|
||||
self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
|
||||
else:
|
||||
self.kv_proj = nn.Identity()
|
||||
|
||||
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
|
||||
self.ln_q = norm_layer(embed_dim)
|
||||
self.ln_kv = norm_layer(embed_dim)
|
||||
|
||||
self.ln_post = norm_layer(embed_dim)
|
||||
self.proj = nn.Parameter(
|
||||
(embed_dim ** -0.5) * torch.randn(embed_dim, embed_dim))
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
def forward(self, x, attn_mask=None):
|
||||
|
||||
pos_embed = get_abs_pos(self.pos_embed, x.size(1))
|
||||
|
||||
x = self.kv_proj(x)
|
||||
x = self.ln_kv(x).permute(1, 0, 2)
|
||||
|
||||
N = x.shape[1]
|
||||
q = self.ln_q(self.query)
|
||||
# print((self._repeat(q, N) + self.pos_embed.unsqueeze(1)).dtype, (x + pos_embed.unsqueeze(1)).dtype, x.dtype)
|
||||
out = self.attn(
|
||||
self._repeat(q, N) + self.pos_embed.unsqueeze(1),
|
||||
x + pos_embed.unsqueeze(1),
|
||||
x,
|
||||
attn_mask=attn_mask)[0]
|
||||
x = out.permute(1, 0, 2)
|
||||
|
||||
x = self.ln_post(x)
|
||||
x = x @ self.proj
|
||||
return x
|
||||
|
||||
def _repeat(self, query, N: int):
|
||||
return query.unsqueeze(1).repeat(1, N, 1)
|
555
omnilmm/model/utils.py
Normal file
@ -0,0 +1,555 @@
|
||||
from torchvision import transforms
|
||||
from timm.data.transforms import RandomResizedCropAndInterpolation
|
||||
from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
from transformers import AutoConfig
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
import torch.distributed as dist
|
||||
import numpy as np
|
||||
import pickle
|
||||
import base64
|
||||
import cv2
|
||||
import os
|
||||
import torch
|
||||
from transformers import AutoConfig, StoppingCriteria
|
||||
|
||||
try:
|
||||
from timm.data.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||
except ImportError:
|
||||
OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
|
||||
OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
|
||||
|
||||
|
||||
def auto_upgrade(config):
|
||||
cfg = AutoConfig.from_pretrained(config)
|
||||
if 'llava' in config and cfg.model_type != 'llava':
|
||||
print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
|
||||
print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
|
||||
confirm = input(
|
||||
"Please confirm that you want to upgrade the checkpoint. [Y/N]")
|
||||
if confirm.lower() in ["y", "yes"]:
|
||||
print("Upgrading checkpoint...")
|
||||
assert len(cfg.architectures) == 1
|
||||
setattr(cfg.__class__, "model_type", "llava")
|
||||
cfg.architectures[0] = 'LlavaLlamaForCausalLM'
|
||||
cfg.save_pretrained(config)
|
||||
print("Checkpoint upgraded.")
|
||||
else:
|
||||
print("Checkpoint upgrade aborted.")
|
||||
exit(1)
|
||||
|
||||
|
||||
class KeywordsStoppingCriteria(StoppingCriteria):
|
||||
def __init__(self, keywords, tokenizer, input_ids):
|
||||
self.keywords = keywords
|
||||
self.tokenizer = tokenizer
|
||||
self.start_len = None
|
||||
self.input_ids = input_ids
|
||||
|
||||
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
||||
if self.start_len is None:
|
||||
self.start_len = self.input_ids.shape[1]
|
||||
else:
|
||||
outputs = self.tokenizer.batch_decode(
|
||||
output_ids[:, self.start_len:], skip_special_tokens=True)[0]
|
||||
for keyword in self.keywords:
|
||||
if keyword in outputs:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def auto_upgrade(config):
|
||||
cfg = AutoConfig.from_pretrained(config)
|
||||
if 'llava' in config and cfg.model_type != 'llava':
|
||||
print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
|
||||
print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
|
||||
confirm = input(
|
||||
"Please confirm that you want to upgrade the checkpoint. [Y/N]")
|
||||
if confirm.lower() in ["y", "yes"]:
|
||||
print("Upgrading checkpoint...")
|
||||
assert len(cfg.architectures) == 1
|
||||
setattr(cfg.__class__, "model_type", "llava")
|
||||
cfg.architectures[0] = 'LlavaLlamaForCausalLM'
|
||||
cfg.save_pretrained(config)
|
||||
print("Checkpoint upgraded.")
|
||||
else:
|
||||
print("Checkpoint upgrade aborted.")
|
||||
exit(1)
|
||||
|
||||
# aug functions
|
||||
|
||||
|
||||
def identity_func(img):
|
||||
return img
|
||||
|
||||
|
||||
def autocontrast_func(img, cutoff=0):
|
||||
'''
|
||||
same output as PIL.ImageOps.autocontrast
|
||||
'''
|
||||
n_bins = 256
|
||||
|
||||
def tune_channel(ch):
|
||||
n = ch.size
|
||||
cut = cutoff * n // 100
|
||||
if cut == 0:
|
||||
high, low = ch.max(), ch.min()
|
||||
else:
|
||||
hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
|
||||
low = np.argwhere(np.cumsum(hist) > cut)
|
||||
low = 0 if low.shape[0] == 0 else low[0]
|
||||
high = np.argwhere(np.cumsum(hist[::-1]) > cut)
|
||||
high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
|
||||
if high <= low:
|
||||
table = np.arange(n_bins)
|
||||
else:
|
||||
scale = (n_bins - 1) / (high - low)
|
||||
table = np.arange(n_bins) * scale - low * scale
|
||||
table[table < 0] = 0
|
||||
table[table > n_bins - 1] = n_bins - 1
|
||||
table = table.clip(0, 255).astype(np.uint8)
|
||||
return table[ch]
|
||||
|
||||
channels = [tune_channel(ch) for ch in cv2.split(img)]
|
||||
out = cv2.merge(channels)
|
||||
return out
|
||||
|
||||
|
||||
def equalize_func(img):
|
||||
'''
|
||||
same output as PIL.ImageOps.equalize
|
||||
PIL's implementation is different from cv2.equalize
|
||||
'''
|
||||
n_bins = 256
|
||||
|
||||
def tune_channel(ch):
|
||||
hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
|
||||
non_zero_hist = hist[hist != 0].reshape(-1)
|
||||
step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
|
||||
if step == 0:
|
||||
return ch
|
||||
n = np.empty_like(hist)
|
||||
n[0] = step // 2
|
||||
n[1:] = hist[:-1]
|
||||
table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
|
||||
return table[ch]
|
||||
|
||||
channels = [tune_channel(ch) for ch in cv2.split(img)]
|
||||
out = cv2.merge(channels)
|
||||
return out
|
||||
|
||||
|
||||
def rotate_func(img, degree, fill=(0, 0, 0)):
|
||||
'''
|
||||
like PIL, rotate by degree, not radians
|
||||
'''
|
||||
H, W = img.shape[0], img.shape[1]
|
||||
center = W / 2, H / 2
|
||||
M = cv2.getRotationMatrix2D(center, degree, 1)
|
||||
out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
|
||||
return out
|
||||
|
||||
|
||||
def solarize_func(img, thresh=128):
|
||||
'''
|
||||
same output as PIL.ImageOps.posterize
|
||||
'''
|
||||
table = np.array([el if el < thresh else 255 - el for el in range(256)])
|
||||
table = table.clip(0, 255).astype(np.uint8)
|
||||
out = table[img]
|
||||
return out
|
||||
|
||||
|
||||
def color_func(img, factor):
|
||||
'''
|
||||
same output as PIL.ImageEnhance.Color
|
||||
'''
|
||||
# implementation according to PIL definition, quite slow
|
||||
# degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
|
||||
# out = blend(degenerate, img, factor)
|
||||
# M = (
|
||||
# np.eye(3) * factor
|
||||
# + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
|
||||
# )[np.newaxis, np.newaxis, :]
|
||||
M = (
|
||||
np.float32([
|
||||
[0.886, -0.114, -0.114],
|
||||
[-0.587, 0.413, -0.587],
|
||||
[-0.299, -0.299, 0.701]]) * factor
|
||||
+ np.float32([[0.114], [0.587], [0.299]])
|
||||
)
|
||||
out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
|
||||
return out
|
||||
|
||||
|
||||
def contrast_func(img, factor):
|
||||
"""
|
||||
same output as PIL.ImageEnhance.Contrast
|
||||
"""
|
||||
mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
|
||||
table = np.array([(
|
||||
el - mean) * factor + mean
|
||||
for el in range(256)
|
||||
]).clip(0, 255).astype(np.uint8)
|
||||
out = table[img]
|
||||
return out
|
||||
|
||||
|
||||
def brightness_func(img, factor):
|
||||
'''
|
||||
same output as PIL.ImageEnhance.Contrast
|
||||
'''
|
||||
table = (np.arange(256, dtype=np.float32) *
|
||||
factor).clip(0, 255).astype(np.uint8)
|
||||
out = table[img]
|
||||
return out
|
||||
|
||||
|
||||
def sharpness_func(img, factor):
|
||||
'''
|
||||
The differences the this result and PIL are all on the 4 boundaries, the center
|
||||
areas are same
|
||||
'''
|
||||
kernel = np.ones((3, 3), dtype=np.float32)
|
||||
kernel[1][1] = 5
|
||||
kernel /= 13
|
||||
degenerate = cv2.filter2D(img, -1, kernel)
|
||||
if factor == 0.0:
|
||||
out = degenerate
|
||||
elif factor == 1.0:
|
||||
out = img
|
||||
else:
|
||||
out = img.astype(np.float32)
|
||||
degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
|
||||
out[1:-1, 1:-1, :] = degenerate + factor * \
|
||||
(out[1:-1, 1:-1, :] - degenerate)
|
||||
out = out.astype(np.uint8)
|
||||
return out
|
||||
|
||||
|
||||
def shear_x_func(img, factor, fill=(0, 0, 0)):
|
||||
H, W = img.shape[0], img.shape[1]
|
||||
M = np.float32([[1, factor, 0], [0, 1, 0]])
|
||||
out = cv2.warpAffine(img, M, (W, H), borderValue=fill,
|
||||
flags=cv2.INTER_LINEAR).astype(np.uint8)
|
||||
return out
|
||||
|
||||
|
||||
def translate_x_func(img, offset, fill=(0, 0, 0)):
|
||||
'''
|
||||
same output as PIL.Image.transform
|
||||
'''
|
||||
H, W = img.shape[0], img.shape[1]
|
||||
M = np.float32([[1, 0, -offset], [0, 1, 0]])
|
||||
out = cv2.warpAffine(img, M, (W, H), borderValue=fill,
|
||||
flags=cv2.INTER_LINEAR).astype(np.uint8)
|
||||
return out
|
||||
|
||||
|
||||
def translate_y_func(img, offset, fill=(0, 0, 0)):
|
||||
'''
|
||||
same output as PIL.Image.transform
|
||||
'''
|
||||
H, W = img.shape[0], img.shape[1]
|
||||
M = np.float32([[1, 0, 0], [0, 1, -offset]])
|
||||
out = cv2.warpAffine(img, M, (W, H), borderValue=fill,
|
||||
flags=cv2.INTER_LINEAR).astype(np.uint8)
|
||||
return out
|
||||
|
||||
|
||||
def posterize_func(img, bits):
|
||||
'''
|
||||
same output as PIL.ImageOps.posterize
|
||||
'''
|
||||
out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
|
||||
return out
|
||||
|
||||
|
||||
def shear_y_func(img, factor, fill=(0, 0, 0)):
|
||||
H, W = img.shape[0], img.shape[1]
|
||||
M = np.float32([[1, 0, 0], [factor, 1, 0]])
|
||||
out = cv2.warpAffine(img, M, (W, H), borderValue=fill,
|
||||
flags=cv2.INTER_LINEAR).astype(np.uint8)
|
||||
return out
|
||||
|
||||
|
||||
def cutout_func(img, pad_size, replace=(0, 0, 0)):
|
||||
replace = np.array(replace, dtype=np.uint8)
|
||||
H, W = img.shape[0], img.shape[1]
|
||||
rh, rw = np.random.random(2)
|
||||
pad_size = pad_size // 2
|
||||
ch, cw = int(rh * H), int(rw * W)
|
||||
x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
|
||||
y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
|
||||
out = img.copy()
|
||||
out[x1:x2, y1:y2, :] = replace
|
||||
return out
|
||||
|
||||
|
||||
# level to args
|
||||
def enhance_level_to_args(MAX_LEVEL):
|
||||
def level_to_args(level):
|
||||
return ((level / MAX_LEVEL) * 1.8 + 0.1,)
|
||||
return level_to_args
|
||||
|
||||
|
||||
def shear_level_to_args(MAX_LEVEL, replace_value):
|
||||
def level_to_args(level):
|
||||
level = (level / MAX_LEVEL) * 0.3
|
||||
if np.random.random() > 0.5:
|
||||
level = -level
|
||||
return (level, replace_value)
|
||||
|
||||
return level_to_args
|
||||
|
||||
|
||||
def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
|
||||
def level_to_args(level):
|
||||
level = (level / MAX_LEVEL) * float(translate_const)
|
||||
if np.random.random() > 0.5:
|
||||
level = -level
|
||||
return (level, replace_value)
|
||||
|
||||
return level_to_args
|
||||
|
||||
|
||||
def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
|
||||
def level_to_args(level):
|
||||
level = int((level / MAX_LEVEL) * cutout_const)
|
||||
return (level, replace_value)
|
||||
|
||||
return level_to_args
|
||||
|
||||
|
||||
def solarize_level_to_args(MAX_LEVEL):
|
||||
def level_to_args(level):
|
||||
level = int((level / MAX_LEVEL) * 256)
|
||||
return (level, )
|
||||
return level_to_args
|
||||
|
||||
|
||||
def none_level_to_args(level):
|
||||
return ()
|
||||
|
||||
|
||||
def posterize_level_to_args(MAX_LEVEL):
|
||||
def level_to_args(level):
|
||||
level = int((level / MAX_LEVEL) * 4)
|
||||
return (level, )
|
||||
return level_to_args
|
||||
|
||||
|
||||
def rotate_level_to_args(MAX_LEVEL, replace_value):
|
||||
def level_to_args(level):
|
||||
level = (level / MAX_LEVEL) * 30
|
||||
if np.random.random() < 0.5:
|
||||
level = -level
|
||||
return (level, replace_value)
|
||||
|
||||
return level_to_args
|
||||
|
||||
|
||||
func_dict = {
|
||||
'Identity': identity_func,
|
||||
'AutoContrast': autocontrast_func,
|
||||
'Equalize': equalize_func,
|
||||
'Rotate': rotate_func,
|
||||
'Solarize': solarize_func,
|
||||
'Color': color_func,
|
||||
'Contrast': contrast_func,
|
||||
'Brightness': brightness_func,
|
||||
'Sharpness': sharpness_func,
|
||||
'ShearX': shear_x_func,
|
||||
'TranslateX': translate_x_func,
|
||||
'TranslateY': translate_y_func,
|
||||
'Posterize': posterize_func,
|
||||
'ShearY': shear_y_func,
|
||||
}
|
||||
|
||||
translate_const = 10
|
||||
MAX_LEVEL = 10
|
||||
replace_value = (128, 128, 128)
|
||||
arg_dict = {
|
||||
'Identity': none_level_to_args,
|
||||
'AutoContrast': none_level_to_args,
|
||||
'Equalize': none_level_to_args,
|
||||
'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
|
||||
'Solarize': solarize_level_to_args(MAX_LEVEL),
|
||||
'Color': enhance_level_to_args(MAX_LEVEL),
|
||||
'Contrast': enhance_level_to_args(MAX_LEVEL),
|
||||
'Brightness': enhance_level_to_args(MAX_LEVEL),
|
||||
'Sharpness': enhance_level_to_args(MAX_LEVEL),
|
||||
'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
|
||||
'TranslateX': translate_level_to_args(
|
||||
translate_const, MAX_LEVEL, replace_value
|
||||
),
|
||||
'TranslateY': translate_level_to_args(
|
||||
translate_const, MAX_LEVEL, replace_value
|
||||
),
|
||||
'Posterize': posterize_level_to_args(MAX_LEVEL),
|
||||
'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
|
||||
}
|
||||
|
||||
|
||||
class RandomAugment(object):
|
||||
|
||||
def __init__(self, N=2, M=10, isPIL=False, augs=[]):
|
||||
self.N = N
|
||||
self.M = M
|
||||
self.isPIL = isPIL
|
||||
if augs:
|
||||
self.augs = augs
|
||||
else:
|
||||
self.augs = list(arg_dict.keys())
|
||||
|
||||
def get_random_ops(self):
|
||||
sampled_ops = np.random.choice(self.augs, self.N)
|
||||
return [(op, 0.5, self.M) for op in sampled_ops]
|
||||
|
||||
def __call__(self, img):
|
||||
if self.isPIL:
|
||||
img = np.array(img)
|
||||
ops = self.get_random_ops()
|
||||
for name, prob, level in ops:
|
||||
if np.random.random() > prob:
|
||||
continue
|
||||
args = arg_dict[name](level)
|
||||
img = func_dict[name](img, *args)
|
||||
return img
|
||||
|
||||
|
||||
def build_transform(is_train, randaug=True, input_size=224, interpolation='bicubic', std_mode='IMAGENET_INCEPTION'):
|
||||
if std_mode == 'IMAGENET_INCEPTION':
|
||||
mean = IMAGENET_INCEPTION_MEAN
|
||||
std = IMAGENET_INCEPTION_STD
|
||||
elif std_mode == 'OPENAI_CLIP':
|
||||
mean = OPENAI_CLIP_MEAN
|
||||
std = OPENAI_CLIP_STD
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if is_train:
|
||||
crop_scale = float(os.environ.get('TRAIN_CROP_SCALE', 0.9999))
|
||||
t = [
|
||||
RandomResizedCropAndInterpolation(
|
||||
input_size, scale=(crop_scale, 1.0), interpolation='bicubic'),
|
||||
# transforms.RandomHorizontalFlip(),
|
||||
]
|
||||
if randaug and os.environ.get('TRAIN_DO_AUG', 'False') == 'True':
|
||||
print(f'@@@@@ Do random aug during training', flush=True)
|
||||
t.append(
|
||||
RandomAugment(
|
||||
2, 7, isPIL=True,
|
||||
augs=[
|
||||
'Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
|
||||
'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate',
|
||||
]))
|
||||
else:
|
||||
print(f'@@@@@ Skip random aug during training', flush=True)
|
||||
t += [
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=mean, std=std),
|
||||
]
|
||||
t = transforms.Compose(t)
|
||||
else:
|
||||
t = transforms.Compose([
|
||||
transforms.Resize((input_size, input_size),
|
||||
interpolation=transforms.InterpolationMode.BICUBIC),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=mean, std=std)
|
||||
])
|
||||
|
||||
return t
|
||||
|
||||
|
||||
def img2b64(img_path):
|
||||
img = Image.open(img_path) # path to file
|
||||
img_buffer = BytesIO()
|
||||
img.save(img_buffer, format=img.format)
|
||||
byte_data = img_buffer.getvalue()
|
||||
base64_str = base64.b64encode(byte_data) # bytes
|
||||
base64_str = base64_str.decode("utf-8") # str
|
||||
return base64_str
|
||||
|
||||
|
||||
def str2b64(str):
|
||||
return base64.b64encode(str.encode('utf-8')).decode('utf-8')
|
||||
|
||||
|
||||
def b642str(b64):
|
||||
return base64.b64decode(b64).decode('utf-8')
|
||||
|
||||
|
||||
def is_dist_avail_and_initialized():
|
||||
if not dist.is_available():
|
||||
return False
|
||||
if not dist.is_initialized():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_world_size():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 1
|
||||
return dist.get_world_size()
|
||||
|
||||
|
||||
def get_rank():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 0
|
||||
return dist.get_rank()
|
||||
|
||||
|
||||
def all_gather(data):
|
||||
"""
|
||||
Run all_gather on arbitrary picklable data (not necessarily tensors)
|
||||
Args:
|
||||
data: any picklable object
|
||||
Returns:
|
||||
list[data]: list of data gathered from each rank
|
||||
"""
|
||||
world_size = get_world_size()
|
||||
if world_size == 1:
|
||||
return [data]
|
||||
|
||||
# serialized to a Tensor
|
||||
buffer = pickle.dumps(data)
|
||||
storage = torch.ByteStorage.from_buffer(buffer)
|
||||
tensor = torch.ByteTensor(storage).to("cuda")
|
||||
|
||||
# obtain Tensor size of each rank
|
||||
local_size = torch.LongTensor([tensor.numel()]).to("cuda")
|
||||
size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
|
||||
dist.all_gather(size_list, local_size)
|
||||
size_list = [int(size.item()) for size in size_list]
|
||||
max_size = max(size_list)
|
||||
|
||||
# receiving Tensor from all ranks
|
||||
# we pad the tensor because torch all_gather does not support
|
||||
# gathering tensors of different shapes
|
||||
tensor_list = []
|
||||
for _ in size_list:
|
||||
tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
|
||||
if local_size != max_size:
|
||||
padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
|
||||
tensor = torch.cat((tensor, padding), dim=0)
|
||||
dist.all_gather(tensor_list, tensor)
|
||||
|
||||
data_list = []
|
||||
for size, tensor in zip(size_list, tensor_list):
|
||||
buffer = tensor.cpu().numpy().tobytes()[:size]
|
||||
data_list.append(pickle.loads(buffer))
|
||||
|
||||
return data_list
|
||||
|
||||
|
||||
def mean(lst):
|
||||
return sum(lst) / len(lst)
|
||||
|
||||
|
||||
def stop_gradient_by_name(name: str):
|
||||
def apply_fn(module):
|
||||
if hasattr(module, name):
|
||||
getattr(module, name).requires_grad_(False)
|
||||
|
||||
return apply_fn
|
153
omnilmm/train/train_utils.py
Normal file
@ -0,0 +1,153 @@
|
||||
import os
|
||||
import gc
|
||||
import copy
|
||||
import time
|
||||
|
||||
import torch
|
||||
import warnings
|
||||
import transformers
|
||||
|
||||
import numpy as np
|
||||
|
||||
from typing import Dict, Optional, Sequence
|
||||
from omnilmm import conversation as conversation_lib
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
DEFAULT_IMAGE_TOKEN = "<image>"
|
||||
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
|
||||
DEFAULT_IM_START_TOKEN = "<im_start>"
|
||||
DEFAULT_IM_END_TOKEN = "<im_end>"
|
||||
|
||||
|
||||
def _tokenize_fn(strings: Sequence[str],
|
||||
tokenizer: transformers.PreTrainedTokenizer) -> Dict:
|
||||
"""Tokenize a list of strings."""
|
||||
tokenized_list = [
|
||||
tokenizer(
|
||||
text,
|
||||
return_tensors="pt",
|
||||
padding="longest",
|
||||
max_length=tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
) for text in strings
|
||||
]
|
||||
input_ids = labels = [
|
||||
tokenized.input_ids[0] for tokenized in tokenized_list
|
||||
]
|
||||
input_ids_lens = labels_lens = [
|
||||
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
|
||||
for tokenized in tokenized_list
|
||||
]
|
||||
return dict(
|
||||
input_ids=input_ids,
|
||||
labels=labels,
|
||||
input_ids_lens=input_ids_lens,
|
||||
labels_lens=labels_lens,
|
||||
)
|
||||
|
||||
|
||||
|
||||
def omni_preprocess(sources,
|
||||
tokenizer: transformers.PreTrainedTokenizer,
|
||||
generation=False):
|
||||
system_content = 'You are an artificial intelligence assistant, which gives helpful, detailed, and polite answers to the human\'s questions.'
|
||||
ignore_index = -100
|
||||
|
||||
response_template = '\n<|assistant|>\n'
|
||||
instruction_template = '\n<|user|>\n'
|
||||
response_token_ids = tokenizer.encode(
|
||||
response_template, add_special_tokens=False)
|
||||
instruction_token_ids = tokenizer.encode(
|
||||
instruction_template, add_special_tokens=False)
|
||||
|
||||
batch_input_ids = []
|
||||
batch_labels = []
|
||||
for i in range(len(sources)):
|
||||
new_source = []
|
||||
prev_role = 'unexpect'
|
||||
for conv_turn in sources[i]:
|
||||
role = conv_turn['from'] if 'from' in conv_turn else conv_turn['role']
|
||||
content = conv_turn['value'] if 'value' in conv_turn else conv_turn['content']
|
||||
|
||||
role = 'user' if role == 'human' else role
|
||||
role = 'assistant' if role == 'gpt' else role
|
||||
|
||||
assert role in ['user', 'assistant']
|
||||
assert role != prev_role, f'role={role}, prev_role={prev_role}'
|
||||
prev_role = role
|
||||
|
||||
new_turn = {
|
||||
'role': role,
|
||||
'content': content
|
||||
}
|
||||
new_source.append(new_turn)
|
||||
if new_source[0]['role'] != 'system':
|
||||
new_source.insert(0, {'role': 'system', 'content': system_content})
|
||||
|
||||
# TODO: this automatically add '\n' to the end
|
||||
res_text = tokenizer.apply_chat_template(
|
||||
new_source, tokenize=False, add_generation_prompt=generation)
|
||||
if not generation:
|
||||
res_text = res_text.strip()
|
||||
|
||||
conversations_tokenized = _tokenize_fn([res_text], tokenizer)
|
||||
res_input_ids = conversations_tokenized["input_ids"][0]
|
||||
|
||||
# since labels and input_ids are reference towards the same object
|
||||
res_labels = copy.deepcopy(conversations_tokenized["labels"][0])
|
||||
|
||||
response_token_ids_idxs = []
|
||||
human_token_ids_idxs = []
|
||||
|
||||
for assistant_idx in np.where(res_labels == response_token_ids[0])[0]:
|
||||
# find the indexes of the start of a response.
|
||||
if (response_token_ids == res_labels[assistant_idx: assistant_idx + len(
|
||||
response_token_ids)].tolist()
|
||||
):
|
||||
response_token_ids_idxs.append(
|
||||
assistant_idx + len(response_token_ids))
|
||||
|
||||
if len(response_token_ids_idxs) == 0:
|
||||
warnings.warn(
|
||||
f"Could not find response key `{response_template}` in the "
|
||||
f'following instance: @===>{tokenizer.decode(res_input_ids)}<===@ '
|
||||
f'Raw text is @===>{res_text}<===@'
|
||||
f'Raw source is @===>{new_source}<===@'
|
||||
f"This instance will be ignored in loss calculation. "
|
||||
f"Note, if this happens often, consider increasing the `max_seq_length`."
|
||||
)
|
||||
res_labels[:] = ignore_index
|
||||
|
||||
human_token_ids = instruction_token_ids
|
||||
for human_idx in np.where(res_labels == human_token_ids[0])[0]:
|
||||
# find the indexes of the start of a human answer.
|
||||
if human_token_ids == res_labels[human_idx: human_idx + len(human_token_ids)].tolist():
|
||||
human_token_ids_idxs.append(human_idx)
|
||||
|
||||
if len(human_token_ids_idxs) == 0:
|
||||
warnings.warn(
|
||||
f"Could not find instruction key `{instruction_template}` in the "
|
||||
f'following instance: @===>{tokenizer.decode(res_input_ids)}<===@ '
|
||||
f'Raw text is @===>{res_text}<===@'
|
||||
f'Raw source is @===>{new_source}<===@'
|
||||
f"This instance will be ignored in loss calculation. "
|
||||
f"Note, if this happens often, consider increasing the `max_seq_length`."
|
||||
)
|
||||
res_labels[:] = ignore_index
|
||||
|
||||
for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)):
|
||||
# Make pytorch loss function ignore all non response tokens
|
||||
if idx != 0:
|
||||
res_labels[start:end] = ignore_index
|
||||
else:
|
||||
res_labels[:end] = ignore_index
|
||||
|
||||
if len(response_token_ids_idxs) < len(human_token_ids_idxs):
|
||||
res_labels[human_token_ids_idxs[-1]:] = ignore_index
|
||||
|
||||
batch_input_ids.append(res_input_ids)
|
||||
batch_labels.append(res_labels)
|
||||
|
||||
return dict(input_ids=batch_input_ids, labels=batch_labels)
|
||||
|
||||
|
127
omnilmm/utils.py
Normal file
@ -0,0 +1,127 @@
|
||||
import datetime
|
||||
import logging
|
||||
import logging.handlers
|
||||
import os
|
||||
import sys
|
||||
|
||||
import requests
|
||||
|
||||
from omnilmm.constants import LOGDIR
|
||||
|
||||
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
|
||||
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
|
||||
|
||||
handler = None
|
||||
|
||||
|
||||
def build_logger(logger_name, logger_filename):
|
||||
global handler
|
||||
|
||||
formatter = logging.Formatter(
|
||||
fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
# Set the format of root handlers
|
||||
if not logging.getLogger().handlers:
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logging.getLogger().handlers[0].setFormatter(formatter)
|
||||
|
||||
# Redirect stdout and stderr to loggers
|
||||
stdout_logger = logging.getLogger("stdout")
|
||||
stdout_logger.setLevel(logging.INFO)
|
||||
sl = StreamToLogger(stdout_logger, logging.INFO)
|
||||
sys.stdout = sl
|
||||
|
||||
stderr_logger = logging.getLogger("stderr")
|
||||
stderr_logger.setLevel(logging.ERROR)
|
||||
sl = StreamToLogger(stderr_logger, logging.ERROR)
|
||||
sys.stderr = sl
|
||||
|
||||
# Get logger
|
||||
logger = logging.getLogger(logger_name)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
# Add a file handler for all loggers
|
||||
if handler is None:
|
||||
os.makedirs(LOGDIR, exist_ok=True)
|
||||
filename = os.path.join(LOGDIR, logger_filename)
|
||||
handler = logging.handlers.TimedRotatingFileHandler(
|
||||
filename, when='D', utc=True)
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
for name, item in logging.root.manager.loggerDict.items():
|
||||
if isinstance(item, logging.Logger):
|
||||
item.addHandler(handler)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
class StreamToLogger(object):
|
||||
"""
|
||||
Fake file-like stream object that redirects writes to a logger instance.
|
||||
"""
|
||||
|
||||
def __init__(self, logger, log_level=logging.INFO):
|
||||
self.terminal = sys.stdout
|
||||
self.logger = logger
|
||||
self.log_level = log_level
|
||||
self.linebuf = ''
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.terminal, attr)
|
||||
|
||||
def write(self, buf):
|
||||
temp_linebuf = self.linebuf + buf
|
||||
self.linebuf = ''
|
||||
for line in temp_linebuf.splitlines(True):
|
||||
# From the io.TextIOWrapper docs:
|
||||
# On output, if newline is None, any '\n' characters written
|
||||
# are translated to the system default line separator.
|
||||
# By default sys.stdout.write() expects '\n' newlines and then
|
||||
# translates them so this is still cross platform.
|
||||
if line[-1] == '\n':
|
||||
self.logger.log(self.log_level, line.rstrip())
|
||||
else:
|
||||
self.linebuf += line
|
||||
|
||||
def flush(self):
|
||||
if self.linebuf != '':
|
||||
self.logger.log(self.log_level, self.linebuf.rstrip())
|
||||
self.linebuf = ''
|
||||
|
||||
|
||||
def disable_torch_init():
|
||||
"""
|
||||
Disable the redundant torch default initialization to accelerate model creation.
|
||||
"""
|
||||
import torch
|
||||
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
|
||||
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
||||
|
||||
|
||||
def violates_moderation(text):
|
||||
"""
|
||||
Check whether the text violates OpenAI moderation API.
|
||||
"""
|
||||
url = "https://api.openai.com/v1/moderations"
|
||||
headers = {"Content-Type": "application/json",
|
||||
"Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
|
||||
text = text.replace("\n", "")
|
||||
data = "{" + '"input": ' + f'"{text}"' + "}"
|
||||
data = data.encode("utf-8")
|
||||
try:
|
||||
ret = requests.post(url, headers=headers, data=data, timeout=5)
|
||||
flagged = ret.json()["results"][0]["flagged"]
|
||||
except requests.exceptions.RequestException as e:
|
||||
flagged = False
|
||||
except KeyError as e:
|
||||
flagged = False
|
||||
|
||||
return flagged
|
||||
|
||||
|
||||
def pretty_print_semaphore(semaphore):
|
||||
if semaphore is None:
|
||||
return "None"
|
||||
return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
|
31
requirements.txt
Normal file
@ -0,0 +1,31 @@
|
||||
packaging==23.2
|
||||
addict==2.4.0
|
||||
base_utils==1.0.14
|
||||
editdistance==0.6.2
|
||||
einops==0.7.0
|
||||
fairscale==0.4.0
|
||||
jsonlines==4.0.0
|
||||
markdown2==2.4.10
|
||||
matplotlib==3.7.4
|
||||
more_itertools==10.1.0
|
||||
nltk==3.8.1
|
||||
numpy==1.24.4
|
||||
opencv_python_headless==4.5.5.64
|
||||
openpyxl==3.1.2
|
||||
Pillow==10.1.0
|
||||
sacrebleu==2.3.2
|
||||
seaborn==0.13.0
|
||||
shortuuid==1.0.11
|
||||
spacy==3.7.2
|
||||
timm==0.9.10
|
||||
torch==2.0.1
|
||||
torchvision==0.15.2
|
||||
tqdm==4.66.1
|
||||
protobuf==4.25.0
|
||||
transformers==4.36.0
|
||||
typing_extensions==4.8.0
|
||||
uvicorn==0.24.0.post1
|
||||
#xformers==0.0.22.post7
|
||||
#flash_attn==2.3.4
|
||||
sentencepiece==0.1.99
|
||||
accelerate==0.24.1
|