First commit

This commit is contained in:
yiranyyu 2024-02-01 14:45:00 +08:00
commit 24bb62bce8
31 changed files with 2978 additions and 0 deletions

3
.gitignore vendored Normal file
View File

@ -0,0 +1,3 @@
*.bk
__pycache__
.DS_Store

201
LICENSE Normal file
View File

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2024 OpenBMB
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

371
README.md Normal file
View File

@ -0,0 +1,371 @@
<div align="center">
<!-- <!-- <h1 style="color: #33A6B8; font-family: Helvetica"> OmniLMM </h1> -->
<img src="./assets/title-2.png" width="200em" ></img>
**Large multi-modal models for strong performance and efficient deployment**
<p align="center">
OmniLMM-3B <a href="https://huggingface.co/openbmb/MiniCPM-V/">🤗</a> <a href="http://120.92.209.146:80/">🤖</a> |
OmniLMM-12B <a href="https://huggingface.co/openbmb/OmniLMM-12B/">🤗</a> <a href="http://120.92.209.146:8081">🤖</a>
</p>
</div>
**OmniLMM** is a family of open-source large multimodal models (LMMs) adept at vision & language modeling. The model processes images and text inputs and delivers high-quality text outputs. We release two featured versions of OmniLMM that are targeted at **strong performance and efficient deployment**:
- **OmniLMM-12B**: Leading performance among comparable-sized models on multiple benchmarks.
- **OmniLMM-3B**: Frontier end device multi-modal conversation with promising performance.
[中文文档](./README_zh.md)
## Contents
- [OmniLMM-12B](#omnilmm-12b)
- [OmniLMM-3B](#omnilmm-3b)
- [Demo](#demo)
- [Install](#install)
- [Inference](#inference)
- [Model Zoo](#model-zoo)
## OmniLMM-12B
**OmniLMM-12B** is the most capable version. The model is built based on EVA02-5B and Zephyr-7B-β, connected with a perceiver resampler layer, and trained on multimodal data in a curriculum fashion. The model has three notable features:
- 🔥 **Strong Performance.**
OmniLMM-12B achieves **leading performance** among models with comparable sizes, surpassing established LMMs on multiple benchmarks (including MME, MMBench, SEED-Bench, etc). The model also **supports OCR capability** and endows **rich multimodal world knowledge**.
- 🏆 **Trustworthy Behavior.**
LMMs are known for suffering from hallucination, often generating text that is not factually grounded in images (e.g., faithfully describing non-existing objects in images). OmniLMM-12B is **the first state-of-the-art open-source LMM aligned via multimodal RLHF for trustworthy behavior** (using our recent [RLHF-V](https://rlhf-v.github.io/) technique). It **ranks #1** among open-source models on [MMHal-Bench](https://huggingface.co/datasets/Shengcao1006/MMHal-Bench), and **outperforms GPT-4V** on [Object HalBench](https://arxiv.org/abs/2312.00849).
- 🕹 **Real-time Multimodal Interaction.**
We combine the OmniLMM-12B and GPT-3.5 into a **real-time multimodal interactive assistant**. The assistant accepts video streams from the camera and speech streams from the microphone and emits speech output. While still primary, we find the model can **replicate some of the fun cases shown in the Gemini Demo video, without any video edition**.
### Evaluation
<table>
<thead>
<tr>
<th align="left">Model</th>
<th>Size</th>
<th>MME</th>
<th nowrap="nowrap">MMB dev (en)</th>
<th nowrap="nowrap" >MMMU val</th>
<th nowrap="nowrap" >MMHal-Bench</th>
<th nowrap="nowrap" >Object HalBench</th>
<th nowrap="nowrap" >SeedBench-I</th>
<th>MathVista</th>
<th nowrap="nowrap" >LLaVA Bench W</th>
</tr>
</thead>
<tbody align="center">
<tr>
<td align="left">GPT-4V†</td>
<td>-</td>
<td>1409</td>
<td>75.1 </td>
<td>56.8</td>
<td>3.53 / 70.8</td>
<td>86.4 / 92.7</td>
<td>71.6 </td>
<td>47.8 </td>
<td>93.1 </td>
</tr>
<tr>
<td nowrap="nowrap" align="left">Qwen-VL-Plus†</td>
<td>-</td>
<td>1681</td>
<td>66.2 </td>
<td>45.2</td>
<td>- </td>
<td>- </td>
<td>65.7 </td>
<td>36.0 </td>
<td>73.7 </td>
</tr>
<tr>
<td align="left">Yi-VL 6B</td>
<td align="right">6.7B </td>
<td>- </td>
<td>68.2 </td>
<td>39.1 </td>
<td>- </td>
<td>- </td>
<td>66.1 </td>
<td>28.0 </td>
<td>39.9 </td>
</tr>
<tr>
<td nowrap="nowrap" align="left" >Qwen-VL-Chat</td>
<td align="right">9.6B</td>
<td>1488</td>
<td>60.6 </td>
<td>35.9</td>
<td>2.93 / 59.4</td>
<td>56.2 / 80.0</td>
<td>64.8 </td>
<td>33.8 </td>
<td>67.7 </td>
</tr>
<tr>
<td align="left" >CogVLM</td>
<td align="right">17.4B</td>
<td>1438</td>
<td>63.7 </td>
<td>32.1 </td>
<td>2.68 / 52.1 </td>
<td>73.6 / 87.4 </td>
<td>68.8 </td>
<td>34.7 </td>
<td>73.9 </td>
</tr>
<tr>
<td align="left" >LLaVA 1.5</td>
<td align="right">13.6B </td>
<td>1531 </td>
<td>68.2 </td>
<td>36.4 </td>
<td>2.71 / 51.0 </td>
<td>53.7 / 77.4 </td>
<td>68.1 </td>
<td>26.4 </td>
<td>64.6 </td>
</tr>
<tr>
<td nowrap="nowrap" align="left" ><b>OmniLMM-12B</b></td>
<td align="right">11.6B </td>
<td>1637 </td>
<td>71.6 </td>
<td>40.7 </td>
<td>3.45 / 68.8 </td>
<td>90.3 / 95.5 </td>
<td>71.1 </td>
<td>34.9 </td>
<td>72.0 </td>
</tr>
</tbody>
</table>
<small>†: Proprietary models</small>
### Examples
<table align="center" >
<p align="center" >
<img src="assets/omnilmm-12b-examples.png" />
</p>
</table>
<div align="center" >
<video controls src="https://github.com/OpenBMB/OmniLMM/assets/157115220/c1fd3562-1ab1-4534-8139-79e9137b5398" type="video/mp4" />
</div>
## OmniLMM-3B
**OmniLMM-3B** (i.e., MiniCPM-V) is an efficient version with promising performance for deployment. The model is built based on SigLip-400M and [MiniCPM-2.4B](https://github.com/OpenBMB/MiniCPM/), connected by a perceiver resampler. Notable features of OmniLMM-3B include:
- ⚡️ **High Efficiency.**
OmniLMM-3B can be **efficiently deployed on most GPU cards and personal computers**, and **even on end devices such as mobile phones**. In terms of visual encoding, we compress the image representations into 64 tokens via a perceiver resampler, which is significantly fewer than other LMMs based on MLP architecture (typically > 512 tokens). This allows OmniLMM-3B to operate with **much less memory cost and higher speed during inference**.
- 🔥 **Promising Performance.**
OmniLMM-3B achieves **state-of-the-art performance** on multiple benchmarks (including MMMU, MME, and MMbech, etc) among models with comparable sizes, surpassing existing LMMs built on Phi-2. It even **achieves comparable or better performance than the 9.6B Qwen-VL-Chat**.
- 🙌 **Bilingual Support.**
OmniLMM-3B is **the first edge-deployable LMM supporting bilingual multimodal interaction in English and Chinese**. This is achieved by generalizing multimodal capabilities across languages, a technique from our ICLR 2024 spotlight [paper](https://arxiv.org/abs/2308.12038).
### Evaluation
<div align="center">
<table style="margin: 0px auto;">
<thead>
<tr>
<th align="left">Model</th>
<th>Size</th>
<th>MME</th>
<th nowrap="nowrap" >MMB dev (en)</th>
<th nowrap="nowrap" >MMB dev (zh)</th>
<th nowrap="nowrap" >MMMU val</th>
<th nowrap="nowrap" >CMMMU val</th>
</tr>
</thead>
<tbody align="center">
<tr>
<td align="left">LLaVA-Phi</td>
<td align="right">3B</td>
<td>1335</td>
<td>59.8</td>
<td>- </td>
<td>- </td>
<td>- </td>
</tr>
<tr>
<td nowrap="nowrap" align="left">MobileVLM</td>
<td align="right">3B</td>
<td>1289</td>
<td>59.6</td>
<td>- </td>
<td>- </td>
<td>- </td>
</tr>
<tr>
<td nowrap="nowrap" align="left" >Imp-v1</td>
<td align="right">3B</td>
<td>1434</td>
<td>66.5</td>
<td>- </td>
<td>- </td>
<td>- </td>
</tr>
<tr>
<td align="left" >Qwen-VL-Chat</td>
<td align="right" >9.6B</td>
<td>1487</td>
<td>60.6 </td>
<td>56.7 </td>
<td>35.9 </td>
<td>30.7 </td>
</tr>
<tr>
<td nowrap="nowrap" align="left" >CogVLM</td>
<td align="right">17.4B </td>
<td>1438 </td>
<td>63.7 </td>
<td>53.8 </td>
<td>32.1 </td>
<td>- </td>
</tr>
<tr>
<td nowrap="nowrap" align="left" ><b>OmniLMM-3B</b></td>
<td align="right">3B </td>
<td>1452 </td>
<td>67.3 </td>
<td>61.9 </td>
<td>34.7 </td>
<td>32.1 </td>
</tr>
</tbody>
</table>
</div>
### Examples
OmniLLM-3B is the first LMM deloyed on end devices. The demo video is the raw screen recording without edition.
<table align="center" >
<p align="center" >
<img src="assets/Snake_cn_Mushroom_en.gif" width=36%/>
</p>
</table>
## Demo
Click here to try out the Demo of [OmniLMM-12B](http://120.92.209.146:8081) and [OmniLMM-3B](http://120.92.209.146:80).
## Install
1. Clone this repository and navigate to the source folder
```bash
git clone https://github.com/OpenBMB/OmniLMM.git
cd OmniLMM
```
2. Create conda environment
```Shell
conda create -n OmniLMM python=3.10 -y
conda activate OmniLMM
```
3. Install dependencies
```shell
pip install -r requirements.txt
```
## Inference
### Model Zoo
| Model | Description | Download Link |
|:----------------------|:-------------------|:---------------:|
| OmniLMM-12B | The most capable version with strong performance. | [🤗](https://huggingface.co/openbmb/OmniLMM-12B) &nbsp;&nbsp; <a url="https://modelscope.cn/models/OpenBMB/OmniLMM-12B/files"> <img src="./assets/modelscope_logo.png" width="20px"></img></a> |
| OmniLMM-3B | The efficient version for end device deployment. | [🤗](https://huggingface.co/openbmb/MiniCPM-V) &nbsp;&nbsp; <a url="https://modelscope.cn/models/OpenBMB/MiniCPM-V/files"> <img src="./assets/modelscope_logo.png" width="20px"></img></a> |
### Multi-turn Conversation
Please refer to the following codes to run `OmniLMM`.
<div align="center">
<img src="assets/COCO_test2015_000000262144.jpg" width="660px">
</div>
```python
from chat import OmniLMMChat, img2base64
chat_model = OmniLMMChat('openbmb/OmniLMM-12B') # or 'openbmb/MiniCPM-V'
im_64 = img2base64('./assets/COCO_test2015_000000262144.jpg')
# First round chat
msgs = [{"role": "user", "content": "What are the people doing?"}]
inputs = {"image": im_64, "question": json.dumps(msgs)}
answer = chat_model.process(inputs)
print(answer)
# Second round chat
# pass history context of multi-turn conversation
msgs.append({"role": "assistant", "content": answer})
msgs.append({"role": "user", "content": "Describe the image"})
inputs = {"image": im_64, "question": json.dumps(msgs)}
answer = chat_model.process(inputs)
print(answer)
```
We can obtain the following results:
```
"The people in the image are playing baseball. One person is pitching a ball, another one is swinging a bat to hit it, and there's also an umpire present who appears to be watching the game closely."
"The image depicts a baseball game in progress. A pitcher is throwing the ball, while another player is swinging his bat to hit it. An umpire can be seen observing the play closely."
```
## ✅ TODO
- [ ] Fine-tuning support
- [ ] Local Web-UI deployment
- [ ] Code release for real-time interactive assistant
## Model License
The code in this repo is released according to [Apache-2.0](https://github.com/OpenBMB/MiniCPM/blob/main/LICENSE)
The usage of OmniLMMs' parameters is subject to "[General Model License Agreement - Source Notes - Publicity Restrictions - Commercial License](https://github.com/OpenBMB/General-Model-License/blob/main/通用模型许可协议-来源说明-宣传限制-商业授权.md)"
The parameters are fully open to acedemic research
Please contact cpm@modelbest.cn to obtain a written authorization for commercial uses. Free commercial use is also allowed after registration.
## Statement
As LMMs, OmniLMMs generate contents by learning a large mount of multimodal corpora, but it cannot comprehend, express personal opinions or make value judgement. Anything generated by OmniLMMs does not represent the views and positions of the model developers
We will not be liable for any problems arising from the use of OmniLMM open source models, including but not limited to data security issues, risk of public opinion, or any risks and problems arising from the misdirection, misuse, dissemination or misuse of the model.
## 🏫 Institutions
This project is developed by the following institutions:
- <img src="assets/thunlp.png" width="28px"> [THUNLP](https://nlp.csai.tsinghua.edu.cn/)
- <img src="assets/modelbest.png" width="28px"> [ModelBest](https://modelbest.cn/)
- <img src="assets/zhihu.webp" width="28px"> [Zhihu](https://www.zhihu.com/ )

390
README_zh.md Normal file
View File

@ -0,0 +1,390 @@
<div align="center">
<!-- <!-- <h1 style="color: #33A6B8; font-family: Helvetica"> OmniLMM </h1> -->
<img src="./assets/title-2.png" width="200em" ></img>
**性能强大且部署高效的多模态大模型**
<p align="center">
OmniLMM-3B <a href="https://huggingface.co/openbmb/MiniCPM-V/">🤗</a> <a href="http://120.92.209.146:80/">🤖</a> |
OmniLMM-12B <a href="https://huggingface.co/openbmb/OmniLMM-12B/">🤗</a> <a href="http://120.92.209.146:8081">🤖</a>
</p>
</div>
**OmniLMM** 是一系列善于处理图文输入的开源多模态大模型LMMs。该系列模型接受图像和文本输入并提供高质量的文本输出。我们发布了两个版本的 OmniLMM旨在实现**强大的性能和高效的部署**
- **OmniLMM-12B**:相比同规模其他模型在多个基准测试中具有领先性能。
- **OmniLMM-3B**:可在终端设备上部署并具备先进的多模态对话能力。
[English Document](./README.md)
## 目录
- [OmniLMM-12B](#omnilmm-12b)
- [OmniLMM-3B](#omnilmm-3b)
- [体验](#demo)
- [安装](#install)
- [推理](#inference)
- [模型库](#model-zoo)
## OmniLMM-12B
**OmniLMM-12B** 是当前系列中性能最强大的版本。该模型使用一个感知重采样层连接 EVA02-5B 和 Zephyr-7B-β 来构建,采用了课程学习的方法在多模态数据上进行训练。该模型具有三个显著特征:
- 🔥 **卓越性能。**
OmniLMM-12B 相比其他同规模模型在多个基准测试中取得**领先的性能**(包括 MME、MMBench、SEED-Bench 等)。该模型还**支持OCR功能**,并掌握了**丰富的多模态世界知识**。
- 🏆 **可信行为。**
LMMs 的幻觉问题备受关注模型经常生成和图像中的事实不符的文本例如信誓旦旦地描述图片中并不存在的物体。OmniLMM-12B是 **第一个通过多模态 RLHF 对齐的最新开源 LMM 来实现可信行为**(通过我们最近提出的 [RLHF-V](https://rlhf-v.github.io/) 技术)。该模型在 [MMHal-Bench](https://huggingface.co/datasets/Shengcao1006/MMHal-Bench) 幻觉评测基准上位列开源模型中**第一**,并在 [Object HalBench](https://arxiv.org/abs/2312.00849) 中**超过了 GPT-4V**。
- 🕹 **实时多模态交互。**
我们将 OmniLMM-12B 和 GPT-3.5 结合成一个**实时多模态交互助手**。该助手接受来自相机的视频流和来自麦克风的语音流,并发出语音输出。虽然还处于初级阶段,但我们也发现该模型**无需视频编辑**就可以**复现出现在 Gemini 演示视频中的一些有趣例子**。
### 性能评估
<table>
<thead>
<tr>
<th align="left">Model</th>
<th>Size</th>
<th>MME</th>
<th nowrap="nowrap">MMB dev (en)</th>
<th nowrap="nowrap" >MMMU val</th>
<th nowrap="nowrap" >MMHal-Bench</th>
<th nowrap="nowrap" >Object HalBench</th>
<th nowrap="nowrap" >SeedBench-I</th>
<th>MathVista</th>
<th nowrap="nowrap" >LLaVA Bench W</th>
</tr>
</thead>
<tbody align="center">
<tr>
<td align="left">GPT-4V†</td>
<td>-</td>
<td>1409</td>
<td>75.1 </td>
<td>56.8</td>
<td>3.53 / 70.8</td>
<td>86.4 / 92.7</td>
<td>71.6 </td>
<td>47.8 </td>
<td>93.1 </td>
</tr>
<tr>
<td nowrap="nowrap" align="left">Qwen-VL-Plus†</td>
<td>-</td>
<td>1681</td>
<td>66.2 </td>
<td>45.2</td>
<td>- </td>
<td>- </td>
<td>65.7 </td>
<td>36.0 </td>
<td>73.7 </td>
</tr>
<tr>
<td align="left">Yi-VL 6B</td>
<td align="right">6.7B </td>
<td>- </td>
<td>68.2 </td>
<td>39.1 </td>
<td>- </td>
<td>- </td>
<td>66.1 </td>
<td>28.0 </td>
<td>39.9 </td>
</tr>
<tr>
<td nowrap="nowrap" align="left" >Qwen-VL-Chat</td>
<td align="right">9.6B</td>
<td>1488</td>
<td>60.6 </td>
<td>35.9</td>
<td>2.93 / 59.4</td>
<td>56.2 / 80.0</td>
<td>64.8 </td>
<td>33.8 </td>
<td>67.7 </td>
</tr>
<tr>
<td align="left" >CogVLM</td>
<td align="right">17.4B</td>
<td>1438</td>
<td>63.7 </td>
<td>32.1 </td>
<td>2.68 / 52.1 </td>
<td>73.6 / 87.4 </td>
<td>68.8 </td>
<td>34.7 </td>
<td>73.9 </td>
</tr>
<tr>
<td align="left" >LLaVA 1.5</td>
<td align="right">13.6B </td>
<td>1531 </td>
<td>68.2 </td>
<td>36.4 </td>
<td>2.71 / 51.0 </td>
<td>53.7 / 77.4 </td>
<td>68.1 </td>
<td>26.4 </td>
<td>64.6 </td>
</tr>
<tr>
<td nowrap="nowrap" align="left" ><b>OmniLMM-12B</b></td>
<td align="right">11.6B </td>
<td>1637 </td>
<td>71.6 </td>
<td>40.7 </td>
<td>3.45 / 68.8 </td>
<td>90.3 / 95.5 </td>
<td>71.1 </td>
<td>34.9 </td>
<td>72.0 </td>
</tr>
</tbody>
</table>
<small>†: 闭源模型</small>
## OmniLMM-3B
**OmniLMM-3B**(即 MiniCPM-V是一种我们的高效率版本模型可用于终端机器上的部署。该模型基于 SigLip-400M 和 MiniCPM-2.4B 构建通过感知器重采样器连接。OmniLMM-3B的显著特点包括
- ⚡️ **高效率。**
OmniLMM-3B 可以**高效地部署在大多数GPU卡和个人电脑上**,甚至**在移动手机等终端设备上**。在视觉编码方面,我们通过感知器重采样器将图像表示压缩为 64 个 token远远少于基于MLP架构的其他LMMs通常大于 512 个 token。这使得 OmniLMM-3B 在推理期间**内存成本更低且速度更快**。
- 🔥 **优秀的性能。**
OmniLMM-3B 在与相似大小模型相比的多个基准测试中实现了**最先进的性能**,超过了基于 Phi-2构建的现有 LMMs。它甚至**实现了与9.6B Qwen-VL-Chat 相媲美或更好的性能**。
- 🙌 **双语支持。**
OmniLMM-3B 是**第一个支持英语和中文双语多模态交互的终端可部署 LMM**。这是通过跨语言泛化多模态能力实现的,这是我们 ICLR 2024 [spotlight 论文](https://arxiv.org/abs/2308.12038)中的一项技术。
### Evaluation
<div align="center">
<table style="margin: 0px auto;">
<thead>
<tr>
<th align="left">Model</th>
<th>Size</th>
<th>MME</th>
<th nowrap="nowrap" >MMB dev (en)</th>
<th nowrap="nowrap" >MMB dev (zh)</th>
<th nowrap="nowrap" >MMMU val</th>
<th nowrap="nowrap" >CMMMU val</th>
</tr>
</thead>
<tbody align="center">
<tr>
<td align="left">LLaVA-Phi</td>
<td align="right">3B</td>
<td>1335</td>
<td>59.8</td>
<td>- </td>
<td>- </td>
<td>- </td>
</tr>
<tr>
<td nowrap="nowrap" align="left">MobileVLM</td>
<td align="right">3B</td>
<td>1289</td>
<td>59.6</td>
<td>- </td>
<td>- </td>
<td>- </td>
</tr>
<tr>
<td nowrap="nowrap" align="left" >Imp-v1</td>
<td align="right">3B</td>
<td>1434</td>
<td>66.5</td>
<td>- </td>
<td>- </td>
<td>- </td>
</tr>
<tr>
<td align="left" >Qwen-VL-Chat</td>
<td align="right" >9.6B</td>
<td>1487</td>
<td>60.6 </td>
<td>56.7 </td>
<td>35.9 </td>
<td>30.7 </td>
</tr>
<tr>
<td nowrap="nowrap" align="left" >CogVLM</td>
<td align="right">17.4B </td>
<td>1438 </td>
<td>63.7 </td>
<td>53.8 </td>
<td>32.1 </td>
<td>- </td>
</tr>
<tr>
<td nowrap="nowrap" align="left" ><b>OmniLMM-3B</b></td>
<td align="right">3B </td>
<td>1452 </td>
<td>67.3 </td>
<td>61.9 </td>
<td>34.7 </td>
<td>32.1 </td>
</tr>
</tbody>
</table>
</div>
### 样例展示
<table align="center" >
<p align="center" >
<img src="assets/Snake_cn_Mushroom_en.gif" width=48%/>
</p>
</table>
## 体验
你可以通过以下链接尝试使用我们的网页端推理服务: [OmniLMM-12B](http://120.92.209.146:8081) [OmniLMM-3B](http://120.92.209.146:80).
## 安装
1. Clone this repository and navigate to the source folder
```bash
git clone https://github.com/OpenBMB/OmniLMM.git
cd OmniLMM
```
2. Create conda environment
```Shell
conda create -n OmniLMM python=3.10 -y
conda activate OmniLMM
```
3. Install dependencies
```shell
pip install -r requirements.txt
```
## 推理
### 模型库
| 模型 | 简介 | 下载链接 |
|:----------------------|:-------------------|:---------------:|
| OmniLMM-12B | 更强大的性能表现 | [🤗](https://huggingface.co/openbmb/OmniLMM-12B) &nbsp;&nbsp; <a url="https://modelscope.cn/models/OpenBMB/OmniLMM-12B/files"> <img src="./assets/modelscope_logo.png" width="20px"></img></a> |
| OmniLMM-3B | 支持终端设备上的高效部署,性能优秀 | [🤗](https://huggingface.co/openbmb/MiniCPM-V) &nbsp;&nbsp; <a url="https://modelscope.cn/models/OpenBMB/MiniCPM-V/files"> <img src="./assets/modelscope_logo.png" width="20px"></img></a> |
### 多轮对话
请参考以下代码运行 `OmniLMM` 的推理服务。
<div align="center">
<img src="assets/COCO_test2015_000000262144.jpg" width="660px">
</div>
##### OmniLMM-12B
```python
from chat import OmniLMMChat, img2base64
chat_model = OmniLMMChat('openbmb/OmniLMM-12B')
im_64 = img2base64('./assets/COCO_test2015_000000262144.jpg')
# First round chat
msgs = [{"role": "user", "content": "What are the people doing?"}]
inputs = {"image": im_64, "question": json.dumps(msgs)}
answer = chat_model.process(inputs)
print(answer)
# Second round chat
# pass history context of multi-turn conversation
msgs.append({"role": "assistant", "content": answer})
msgs.append({"role": "user", "content": "Describe the image"})
inputs = {"image": im_64, "question": json.dumps(msgs)}
answer = chat_model.process(inputs)
print(answer)
```
We can obtain the following results:
```
"The people in the image are playing baseball. One person is pitching a ball, another one is swinging a bat to hit it, and there's also an umpire present who appears to be watching the game closely."
"The image depicts a baseball game in progress. A pitcher is throwing the ball, while another player is swinging his bat to hit it. An umpire can be seen observing the play closely."
```
##### OmniLMM-3B
```python
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer
model_path='openbmb/MiniCPM-V'
model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model.eval().cuda()
image = Image.open('./assets/COCO_test2015_000000262144.jpg').convert('RGB')
question = '请描述一下该图像'
res, context, _ = model.chat(
image=image,
question=question,
context=None,
tokenizer=tokenizer,
sampling=True,
temperature=0.7
)
print(res)
```
## ✅ 未来计划
- [ ] 支持模型微调
- [ ] 本地可视化部署
- [ ] 实时多模态交互代码开源
## 模型协议
本仓库中代码依照 Apache-2.0 协议开源
OmniLMMs 模型权重的使用则需要遵循 “[通用模型许可协议-来源说明-宣传限制-商业授权](https://github.com/OpenBMB/General-Model-License/blob/main/通用模型许可协议-来源说明-宣传限制-商业授权.md)”。
OmniLMMs 模型权重对学术研究完全开放。
如需将模型用于商业用途,请联系 cpm@modelbest.cn 来获取书面授权,在登记后亦允许免费商业使用。
## 声明
作为多模态大模型OmniLMMs 通过学习大量的多模态语料来生成内容,但它无法理解、表达个人观点或价值判断,它所输出的任何内容都不代表模型开发者的观点和立场。
因此用户在使用 OmniLMMs 生成的内容时,应自行负责对其进行评估和验证。
如果由于使用 OmniLMMs 开源模型而导致的任何问题,包括但不限于数据安全问题、公共舆论风险,或模型被误导、滥用、传播或不当利用所带来的任何风险和问题,我们将不承担任何责任。
## 🏫 机构
本项目由以下机构共同开发:
- <img src="assets/thunlp.png" width="28px"> [清华大学自然语言处理实验室](https://nlp.csai.tsinghua.edu.cn/)
- <img src="assets/modelbest.png" width="28px"> [面壁智能](https://modelbest.cn/)
- <img src="assets/zhihu.webp" width="28px"> [知乎](https://www.zhihu.com/ )

Binary file not shown.

After

Width:  |  Height:  |  Size: 188 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.4 MiB

BIN
assets/demo_video.mp4 Normal file

Binary file not shown.

Binary file not shown.

After

Width:  |  Height:  |  Size: 819 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.4 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.4 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.3 MiB

BIN
assets/gif_cases/蛇_cn.gif Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.5 MiB

BIN
assets/modelbest.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 48 KiB

BIN
assets/modelscope_logo.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.1 MiB

BIN
assets/thunlp.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 24 KiB

BIN
assets/title-1.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 45 KiB

BIN
assets/title-2.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 46 KiB

BIN
assets/title-3.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 46 KiB

BIN
assets/zhihu.webp Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 36 KiB

194
chat.py Normal file
View File

@ -0,0 +1,194 @@
import os
import torch
import json
from PIL import Image
import base64
import io
from accelerate import load_checkpoint_and_dispatch, init_empty_weights
from transformers import AutoTokenizer, AutoModel
from omnilmm.utils import disable_torch_init
from omnilmm.model.omnilmm import OmniLMMForCausalLM
from omnilmm.model.utils import build_transform
from omnilmm.train.train_utils import omni_preprocess
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
def init_omni_lmm(model_path):
torch.backends.cuda.matmul.allow_tf32 = True
disable_torch_init()
model_name = os.path.expanduser(model_path)
print(f'Load omni_lmm model and tokenizer from {model_name}')
tokenizer = AutoTokenizer.from_pretrained(
model_name, model_max_length=2048)
if False:
# model on multiple devices for small size gpu memory (Nvidia 3090 24G x2)
with init_empty_weights():
model = OmniLMMForCausalLM.from_pretrained(model_name, tune_clip=True, torch_dtype=torch.bfloat16)
model = load_checkpoint_and_dispatch(model, model_name, dtype=torch.bfloat16,
device_map="auto", no_split_module_classes=['Eva','MistralDecoderLayer', 'ModuleList', 'Resampler']
)
else:
model = OmniLMMForCausalLM.from_pretrained(
model_name, tune_clip=True, torch_dtype=torch.bfloat16
).to(device='cuda', dtype=torch.bfloat16)
image_processor = build_transform(
is_train=False, input_size=model.model.config.image_size, std_mode='OPENAI_CLIP')
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
assert mm_use_im_start_end
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN,
DEFAULT_IM_END_TOKEN], special_tokens=True)
vision_config = model.model.vision_config
vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
[DEFAULT_IMAGE_PATCH_TOKEN])[0]
vision_config.use_im_start_end = mm_use_im_start_end
vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids(
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
image_token_len = model.model.config.num_query
return model, image_processor, image_token_len, tokenizer
def expand_question_into_multimodal(question_text, image_token_len, im_st_token, im_ed_token, im_patch_token):
if '<image>' in question_text[0]['content']:
question_text[0]['content'] = question_text[0]['content'].replace(
'<image>', im_st_token + im_patch_token * image_token_len + im_ed_token)
else:
question_text[0]['content'] = im_st_token + im_patch_token * \
image_token_len + im_ed_token + '\n' + question_text[0]['content']
return question_text
def wrap_question_for_omni_lmm(question, image_token_len, tokenizer):
question = expand_question_into_multimodal(
question, image_token_len, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN)
conversation = question
data_dict = omni_preprocess(sources=[conversation],
tokenizer=tokenizer,
generation=True)
data_dict = dict(input_ids=data_dict["input_ids"][0],
labels=data_dict["labels"][0])
return data_dict
class OmniLMM12B:
def __init__(self, model_path) -> None:
model, img_processor, image_token_len, tokenizer = init_omni_lmm(model_path)
self.model = model
self.image_token_len = image_token_len
self.image_transform = img_processor
self.tokenizer = tokenizer
self.model.eval()
def decode(self, image, input_ids):
with torch.inference_mode():
output = self.model.generate_vllm(
input_ids=input_ids.unsqueeze(0).cuda(),
images=image.unsqueeze(0).half().cuda(),
temperature=0.6,
max_new_tokens=1024,
# num_beams=num_beams,
do_sample=True,
output_scores=True,
return_dict_in_generate=True,
repetition_penalty=1.1,
top_k=30,
top_p=0.9,
)
response = self.tokenizer.decode(
output.sequences[0], skip_special_tokens=True)
response = response.strip()
return response
def chat(self, input):
try:
image = Image.open(io.BytesIO(base64.b64decode(input['image']))).convert('RGB')
except Exception as e:
return "Image decode error"
msgs = json.loads(input['question'])
input_ids = wrap_question_for_omni_lmm(
msgs, self.image_token_len, self.tokenizer)['input_ids']
input_ids = torch.as_tensor(input_ids)
#print('input_ids', input_ids)
image = self.image_transform(image)
out = self.decode(image, input_ids)
return out
def img2base64(file_name):
with open(file_name, 'rb') as f:
encoded_string = base64.b64encode(f.read())
return encoded_string
class OmniLMM3B:
def __init__(self, model_path) -> None:
self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(dtype=torch.bfloat16)
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
self.model.eval().cuda()
def chat(self, input):
try:
image = Image.open(io.BytesIO(base64.b64decode(input['image']))).convert('RGB')
except Exception as e:
return "Image decode error"
msgs = json.loads(input['question'])
answer, context, _ = self.model.chat(
image=image,
msgs=msgs,
context=None,
tokenizer=self.tokenizer,
sampling=True,
temperature=0.7
)
return answer
class OmniLMMChat:
def __init__(self, model_path) -> None:
if '12B' in model_path:
self.model = OmniLMM12B(model_path)
else:
self.model = OmniLMM3B(model_path)
def chat(self, input):
return self.model.chat(input)
if __name__ == '__main__':
model_path = 'openbmb/OmniLMM-12B'
chat_model = OmniLMMChat(model_path)
im_64 = img2base64('./assets/COCO_test2015_000000262144.jpg')
# first round chat
msgs = [{"role": "user", "content": "What are the people doing?"}]
input = {"image": im_64, "question": json.dumps(msgs, ensure_ascii=True)}
answer = chat_model.chat(input)
print(msgs[-1]["content"]+'\n', answer)
# second round chat
msgs.append({"role": "assistant", "content": answer})
msgs.append({"role": "user", "content": "Describe the image"})
input = {"image": im_64,"question": json.dumps(msgs, ensure_ascii=True)}
answer = chat_model.chat(input)
print(msgs[-1]["content"]+'\n', answer)

0
omnilmm/__init__.py Normal file
View File

4
omnilmm/constants.py Normal file
View File

@ -0,0 +1,4 @@
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15
LOGDIR = "."

320
omnilmm/conversation.py Normal file
View File

@ -0,0 +1,320 @@
import dataclasses
from enum import auto, Enum
from typing import List, Tuple
class SeparatorStyle(Enum):
"""Different separator style."""
SINGLE = auto()
TWO = auto()
@dataclasses.dataclass
class Conversation:
"""A class that keeps all conversation history."""
system: str
roles: List[str]
messages: List[List[str]]
offset: int
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
sep: str = "###"
sep2: str = None
version: str = "Unknown"
skip_next: bool = False
def get_prompt(self):
if self.sep_style == SeparatorStyle.SINGLE:
ret = self.system + self.sep
for role, message in self.messages:
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + ": " + message + self.sep
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.TWO:
seps = [self.sep, self.sep2]
ret = self.system + seps[0]
for i, (role, message) in enumerate(self.messages):
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
return ret
else:
raise ValueError(f"Invalid style: {self.sep_style}")
def append_message(self, role, message):
self.messages.append([role, message])
def get_images(self, return_pil=False):
images = []
for i, (role, msg) in enumerate(self.messages[self.offset:]):
if i % 2 == 0:
if type(msg) is tuple:
import base64
from io import BytesIO
from PIL import Image
msg, image, image_process_mode = msg
if image_process_mode == "Pad":
def expand2square(pil_img, background_color=(122, 116, 104)):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(
pil_img.mode, (width, width), background_color)
result.paste(
pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(
pil_img.mode, (height, height), background_color)
result.paste(
pil_img, ((height - width) // 2, 0))
return result
image = expand2square(image)
elif image_process_mode == "Crop":
pass
elif image_process_mode == "Resize":
image = image.resize((224, 224))
else:
raise ValueError(
f"Invalid image_process_mode: {image_process_mode}")
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
max_len, min_len = 800, 400
shortest_edge = int(
min(max_len / aspect_ratio, min_len, min_hw))
longest_edge = int(shortest_edge * aspect_ratio)
W, H = image.size
if H > W:
H, W = longest_edge, shortest_edge
else:
H, W = shortest_edge, longest_edge
image = image.resize((W, H))
if return_pil:
images.append(image)
else:
buffered = BytesIO()
image.save(buffered, format="JPEG")
img_b64_str = base64.b64encode(
buffered.getvalue()).decode()
images.append(img_b64_str)
return images
def to_gradio_chatbot(self):
ret = []
for i, (role, msg) in enumerate(self.messages[self.offset:]):
if i % 2 == 0:
if type(msg) is tuple:
import base64
from io import BytesIO
msg, image, image_process_mode = msg
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
max_len, min_len = 800, 400
shortest_edge = int(
min(max_len / aspect_ratio, min_len, min_hw))
longest_edge = int(shortest_edge * aspect_ratio)
W, H = image.size
if H > W:
H, W = longest_edge, shortest_edge
else:
H, W = shortest_edge, longest_edge
image = image.resize((W, H))
# image = image.resize((224, 224))
buffered = BytesIO()
image.save(buffered, format="JPEG")
img_b64_str = base64.b64encode(
buffered.getvalue()).decode()
img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
msg = msg.replace('<image>', img_str)
ret.append([msg, None])
else:
ret[-1][-1] = msg
return ret
def copy(self):
return Conversation(
system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
sep2=self.sep2)
def dict(self):
if len(self.get_images()) > 0:
return {
"system": self.system,
"roles": self.roles,
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
"offset": self.offset,
"sep": self.sep,
"sep2": self.sep2,
}
return {
"system": self.system,
"roles": self.roles,
"messages": self.messages,
"offset": self.offset,
"sep": self.sep,
"sep2": self.sep2,
}
conv_v1 = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("Human", "Assistant"),
messages=(
("Human", "Give three tips for staying healthy."),
("Assistant",
"Sure, here are three tips for staying healthy:\n"
"1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. "
"It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, "
"and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or "
"75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening "
"activities at least two days per week.\n"
"2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, "
"vegetables, whole grains, lean proteins, and healthy fats can help support "
"your overall health. Try to limit your intake of processed and high-sugar foods, "
"and aim to drink plenty of water throughout the day.\n"
"3. Get enough sleep: Getting enough quality sleep is essential for your physical "
"and mental health. Adults should aim for seven to nine hours of sleep per night. "
"Establish a regular sleep schedule and try to create a relaxing bedtime routine to "
"help improve the quality of your sleep.")
),
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)
conv_v1_2 = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("Human", "Assistant"),
messages=(
("Human", "What are the key differences between renewable and non-renewable energy sources?"),
("Assistant",
"Renewable energy sources are those that can be replenished naturally in a relatively "
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
"renewable and non-renewable energy sources:\n"
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
"energy sources are finite and will eventually run out.\n"
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
"and other negative effects.\n"
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
"have lower operational costs than non-renewable sources.\n"
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
"locations than non-renewable sources.\n"
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
"non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
),
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)
conv_vicuna_v1_1 = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
roles=("USER", "ASSISTANT"),
version="v1",
messages=(),
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
)
conv_bair_v1 = Conversation(
system="BEGINNING OF CONVERSATION:",
roles=("USER", "GPT"),
messages=(),
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
)
simple_conv = Conversation(
system="You are LLaVA, a large language model trained by UW Madison WAIV Lab, based on LLaMA architecture."
"You are designed to assist human with a variety of tasks using natural language."
"Follow the instructions carefully.",
roles=("Human", "Assistant"),
messages=(
("Human", "Hi!"),
("Assistant", "Hi there! How can I help you today?\n")
),
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)
simple_conv_multimodal = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
roles=("Human", "Assistant"),
messages=(
),
offset=0,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)
simple_conv_legacy = Conversation(
system="You are LLaVA, a large language model trained by UW Madison WAIV Lab."
"You are designed to assist human with a variety of tasks using natural language."
"Follow the instructions carefully.",
roles=("Human", "Assistant"),
messages=(
("Human", "Hi!\n\n### Response:"),
("Assistant", "Hi there! How can I help you today?\n")
),
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)
conv_llava_v1 = Conversation(
system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
"You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
"Follow the instructions carefully and explain your answers in detail.",
roles=("USER", "ASSISTANT"),
version="v1",
messages=(),
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
)
default_conversation = conv_v1_2
conv_templates = {
"default": conv_v1_2,
"simple": simple_conv,
"simple_legacy": simple_conv_legacy,
"multimodal": simple_conv_multimodal,
"llava_v1": conv_llava_v1,
# fastchat
"v1": conv_v1_2,
"bair_v1": conv_bair_v1,
"vicuna_v1_1": conv_vicuna_v1_1,
}
if __name__ == "__main__":
print(default_conversation.get_prompt())

View File

@ -0,0 +1 @@
from .omnilmm import OmniLMMForCausalLM

457
omnilmm/model/omnilmm.py Normal file
View File

@ -0,0 +1,457 @@
import gc
import math
import timm
import torch
from torch import Tensor
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from typing import List, Optional, Tuple, Union
from transformers import AutoConfig, AutoModelForCausalLM
from transformers import MistralForCausalLM, MistralModel, MistralConfig
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from omnilmm.model.utils import build_transform
from omnilmm.model.resampler import Resampler
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
class OmniLMMConfig(MistralConfig):
model_type = "omnilmm"
class Identity(torch.nn.Identity):
def forward(self, input: Tensor, **kwargs) -> Tensor:
return super().forward(input)
def create_vision_module(config):
vision_tower = timm.create_model('eva02_enormous_patch14_clip_224.laion2b_plus',
pretrained=False,
num_classes=0,
dynamic_img_size=True,
dynamic_img_pad=True)
if isinstance(vision_tower, timm.models.VisionTransformer):
if vision_tower.attn_pool is not None:
vision_tower.attn_pool = Identity()
# use 2nd last layer's output
vision_tower.blocks[-1] = Identity()
embed_dim = config.hidden_size
resampler = Resampler(
grid_size=int(math.sqrt(config.num_query)),
embed_dim=embed_dim,
num_heads=embed_dim // 128,
kv_dim=vision_tower.embed_dim,
)
return vision_tower, resampler
class OmniLMMModel(MistralModel):
config_class = OmniLMMConfig
def __init__(self, config: OmniLMMConfig, mm_vision_tower=None, mm_hidden_size=None, tune_clip=True):
super(OmniLMMModel, self).__init__(config)
if hasattr(config, "mm_vision_tower"):
vision_tower, resampler = create_vision_module(config)
print(__file__, 'skip loading vision tower weights')
# HACK: for FSDP
self.vision_tower = [vision_tower]
self.resampler = resampler
if tune_clip:
self.vision_tower = self.vision_tower[0]
self.vision_config = lambda x: None
def initialize_vision_modules(self, vision_tower, no_randaug, num_query, image_size, tune_clip=False):
self.config.mm_vision_tower = vision_tower
self.config.use_mm_proj = True
self.config.num_query = num_query
self.config.image_size = image_size
if not hasattr(self, 'vision_tower'):
vision_tower, resampler = create_vision_module(self.config)
state_dict = torch.load(
'/tt/data/public/multimodal/multimodal_model_ckpts/timm/eva02_enormous_patch14_clip_224.laion2b_plus.pt')
vision_tower.load_state_dict(state_dict, strict=False)
del state_dict
gc.collect()
else:
if isinstance(self.vision_tower, list):
vision_tower = self.vision_tower[0]
else:
vision_tower = self.vision_tower
resampler = self.resampler
self.vision_tower = vision_tower if tune_clip else [vision_tower]
self.resampler = resampler
train_img_transform = build_transform(
is_train=True, randaug=not no_randaug, input_size=self.config.image_size, std_mode='OPENAI_CLIP')
eval_img_transform = build_transform(
is_train=False, input_size=self.config.image_size, std_mode='OPENAI_CLIP')
return dict(
image_processor=(train_img_transform, eval_img_transform),
image_token_len=num_query,
vision_config=self.vision_config
)
def get_vision_embedding(self, pixel_values):
if isinstance(self.vision_tower, list):
vision_tower = self.vision_tower[0] # HACK: for FSDP
else:
vision_tower = self.vision_tower
dtype = vision_tower.pos_embed.data.dtype
vision_embedding = vision_tower.forward_features(
pixel_values.type(dtype))
if hasattr(vision_tower, 'num_prefix_tokens') and vision_tower.num_prefix_tokens > 0:
vision_embedding = vision_embedding[:,
vision_tower.num_prefix_tokens:]
res = self.resampler(vision_embedding)
return res
def get_vllm_embedding(self, data):
if 'vision_hidden_states' not in data:
pixel_values_list = data['pixel_values']
vision_hidden_states = []
for pixel_values in pixel_values_list:
if len(pixel_values) > 0:
vision_hidden_states.append(self.get_vision_embedding(pixel_values.unsqueeze(0))[0])
else:
vision_hidden_states.append([])
else:
vision_hidden_states = data['vision_hidden_states']
#vllm_embedding = self.llm.model.embed_tokens(data['input_ids']) * self.llm.config.scale_emb
inputs_embeds = self.embed_tokens(data['input_ids'])
vision_hidden_states = [i.type(inputs_embeds.dtype)
if isinstance(i, torch.Tensor) else i for i in vision_hidden_states
]
# HACK: replace back original embeddings for LLaVA pretraining
orig_embeds_params = getattr(self, 'orig_embeds_params', None)
new_input_embeds = []
cur_image_idx = 0
for cur_input_ids, cur_input_embeds in zip(data['input_ids'], inputs_embeds):
if (cur_input_ids == self.vision_config.im_patch_token).sum() == 0:
# multimodal LLM, but the current sample is not multimodal
cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
new_input_embeds.append(cur_input_embeds)
continue
if self.vision_config.use_im_start_end:
cur_image_features = vision_hidden_states[cur_image_idx]
num_patches = cur_image_features.shape[0]
if (cur_input_ids == self.vision_config.im_start_token).sum() != (cur_input_ids == self.vision_config.im_end_token).sum():
raise ValueError(
"The number of image start tokens and image end tokens should be the same.")
image_start_tokens = torch.where(
cur_input_ids == self.vision_config.im_start_token)[0]
for image_start_token_pos in image_start_tokens:
cur_image_features = vision_hidden_states[cur_image_idx].to(
device=cur_input_embeds.device)
num_patches = cur_image_features.shape[0]
if cur_input_ids[image_start_token_pos + num_patches + 1] != self.vision_config.im_end_token:
raise ValueError(
"The image end token should follow the image start token.")
if orig_embeds_params is not None:
cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos].detach(), cur_input_embeds[image_start_token_pos:image_start_token_pos+1], cur_image_features,
cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2], cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()), dim=0)
else:
cur_new_input_embeds = torch.cat(
(cur_input_embeds[:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
cur_image_idx += 1
new_input_embeds.append(cur_new_input_embeds)
else:
raise NotImplementedError
inputs_embeds = torch.stack(new_input_embeds, dim=0)
return inputs_embeds, vision_hidden_states
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
**kwargs
) -> Union[Tuple, BaseModelOutputWithPast]:
# HACK: replace back original embeddings for LLaVA pretraining
orig_embeds_params = getattr(self, 'orig_embeds_params', None)
if inputs_embeds is None and past_key_values is None:
inputs_embeds = self.embed_tokens(input_ids)
vision_tower = getattr(self, 'vision_tower', None)
if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
if type(images) is list:
image_features = []
for image in images:
image_forward_out = self.get_vision_embedding(image.unsqueeze(0))[
0]
image_features.append(image_forward_out)
else:
image_features = self.get_vision_embedding(images)
dummy_image_features = torch.zeros(
self.config.num_query,
self.config.hidden_size,
device=inputs_embeds.device,
dtype=inputs_embeds.dtype)
new_input_embeds = []
cur_image_idx = 0
for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
if (cur_input_ids == self.vision_config.im_patch_token).sum() == 0:
# multimodal LLM, but the current sample is not multimodal
cur_input_embeds = cur_input_embeds + \
(0. * dummy_image_features).sum()
new_input_embeds.append(cur_input_embeds)
continue
if self.vision_config.use_im_start_end:
cur_image_features = image_features[cur_image_idx]
num_patches = cur_image_features.shape[0]
if (cur_input_ids == self.vision_config.im_start_token).sum() != (cur_input_ids == self.vision_config.im_end_token).sum():
raise ValueError(
"The number of image start tokens and image end tokens should be the same.")
image_start_tokens = torch.where(
cur_input_ids == self.vision_config.im_start_token)[0]
for image_start_token_pos in image_start_tokens:
cur_image_features = image_features[cur_image_idx].to(
device=cur_input_embeds.device)
num_patches = cur_image_features.shape[0]
if cur_input_ids[image_start_token_pos + num_patches + 1] != self.vision_config.im_end_token:
raise ValueError(
"The image end token should follow the image start token.")
if orig_embeds_params is not None:
cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos].detach(), cur_input_embeds[image_start_token_pos:image_start_token_pos+1], cur_image_features,
cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2], cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()), dim=0)
else:
cur_new_input_embeds = torch.cat(
(cur_input_embeds[:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
cur_image_idx += 1
new_input_embeds.append(cur_new_input_embeds)
else:
raise NotImplementedError
inputs_embeds = torch.stack(new_input_embeds, dim=0)
input_ids = None
return super(OmniLMMModel, self).forward(
input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, use_cache=use_cache,
output_attentions=output_attentions, output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs
)
class OmniLMMForCausalLM(MistralForCausalLM):
config_class = OmniLMMConfig
def __init__(self, config, mm_vision_tower=None, tune_clip=True):
super(MistralForCausalLM, self).__init__(config)
self.model = OmniLMMModel(
config, mm_vision_tower=mm_vision_tower, tune_clip=tune_clip)
self.lm_head = nn.Linear(
config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
**kwargs
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# print(f'@@@ At forward, labels: {labels.shape}-{labels}', flush=True)
# print(f'@@@ At forward, input_ids: {input_ids.shape}-{input_ids}', flush=True)
# print(f'@@@ At forward, input_ids: {attention_mask.shape}-{attention_mask}', flush=True)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
images=images,
**kwargs
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model/pipeline parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
# TODO could be removed for generate_vllm()
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
if past_key_values:
input_ids = input_ids[:, -1:]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"images": kwargs.get("images", None),
}
)
return model_inputs
def generate_vllm(
self,
input_ids: torch.LongTensor = None,
images: Optional[torch.FloatTensor] = None,
vision_hidden_states=None,
return_vision_hidden_states=False,
**kwargs
):
model_inputs = {'input_ids': input_ids}
if vision_hidden_states is None:
model_inputs['pixel_values'] = images
else:
model_inputs['vision_hidden_states'] = vision_hidden_states
with torch.inference_mode():
inputs_embeds, vision_hidden_states = self.model.get_vllm_embedding(model_inputs)
result = self.generate(
inputs_embeds=inputs_embeds,
**kwargs
)
if return_vision_hidden_states:
return result, vision_hidden_states
return result
def initialize_vision_tokenizer(self, mm_use_im_start_end, tokenizer, device,
tune_mm_mlp_adapter=False):
self.model.vision_config.use_im_start_end = mm_use_im_start_end
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
self.resize_token_embeddings(len(tokenizer))
if mm_use_im_start_end:
num_new_tokens = tokenizer.add_tokens(
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
self.resize_token_embeddings(len(tokenizer))
self.model.vision_config.im_start_token, self.model.vision_config.im_end_token = tokenizer.convert_tokens_to_ids(
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
if num_new_tokens > 0:
input_embeddings = self.get_input_embeddings().weight.data
output_embeddings = self.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
# for new sft data
num_new_tokens = tokenizer.add_tokens(
['<box>', '</box>', '<ref>', '</ref>', '<quad>', '</quad>'], special_tokens=True)
self.resize_token_embeddings(len(tokenizer))
if num_new_tokens > 0:
input_embeddings = self.get_input_embeddings().weight.data
output_embeddings = self.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
if tune_mm_mlp_adapter:
self.model.orig_embeds_params = [
self.get_input_embeddings().weight.data.clone().to(device=device)]
for p in self.get_input_embeddings().parameters():
p.requires_grad = True
for p in self.get_output_embeddings().parameters():
p.requires_grad = False
self.model.vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
[DEFAULT_IMAGE_PATCH_TOKEN])[0]
print(f'Tokenizer: {tokenizer}\n patch_token_id: {self.model.vision_config.im_patch_token}, visoin_config: {self.model.vision_config}', flush=True)
# exit()
AutoConfig.register("omnilmm", OmniLMMConfig)
AutoModelForCausalLM.register(OmniLMMConfig, OmniLMMForCausalLM)

171
omnilmm/model/resampler.py Normal file
View File

@ -0,0 +1,171 @@
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
import math
import requests
from io import BytesIO
from functools import partial
from PIL import Image
from typing import Callable, Optional, Sequence, Tuple, List, Union
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import trunc_normal_
from torchvision import transforms
from torchvision.transforms import InterpolationMode
def get_abs_pos(abs_pos, tgt_size):
# abs_pos: L, C
# tgt_size: M
# return: M, C
src_size = int(math.sqrt(abs_pos.size(0)))
tgt_size = int(math.sqrt(tgt_size))
dtype = abs_pos.dtype
if src_size != tgt_size:
return F.interpolate(
abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
size=(tgt_size, tgt_size),
mode="bicubic",
align_corners=False,
).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
else:
return abs_pos
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token:
pos_embed = np.concatenate(
[np.zeros([1, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(
embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(
embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float32)
omega /= embed_dim / 2.
omega = 1. / 10000 ** omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
class Resampler(nn.Module):
"""
A 2D perceiver-resampler network with one cross attention layers by
(grid_size**2) learnable queries and 2d sincos pos_emb
Outputs:
A tensor with the shape of (grid_size**2, embed_dim)
"""
def __init__(
self,
grid_size,
embed_dim,
num_heads,
kv_dim=None,
norm_layer=partial(nn.LayerNorm, eps=1e-6)
):
super().__init__()
self.num_queries = grid_size ** 2
self.embed_dim = embed_dim
self.num_heads = num_heads
self.pos_embed = nn.Parameter(
torch.from_numpy(get_2d_sincos_pos_embed(
embed_dim, grid_size)).float()
).requires_grad_(False)
self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
trunc_normal_(self.query, std=.02)
if kv_dim is not None and kv_dim != embed_dim:
self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
else:
self.kv_proj = nn.Identity()
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
self.ln_q = norm_layer(embed_dim)
self.ln_kv = norm_layer(embed_dim)
self.ln_post = norm_layer(embed_dim)
self.proj = nn.Parameter(
(embed_dim ** -0.5) * torch.randn(embed_dim, embed_dim))
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x, attn_mask=None):
pos_embed = get_abs_pos(self.pos_embed, x.size(1))
x = self.kv_proj(x)
x = self.ln_kv(x).permute(1, 0, 2)
N = x.shape[1]
q = self.ln_q(self.query)
# print((self._repeat(q, N) + self.pos_embed.unsqueeze(1)).dtype, (x + pos_embed.unsqueeze(1)).dtype, x.dtype)
out = self.attn(
self._repeat(q, N) + self.pos_embed.unsqueeze(1),
x + pos_embed.unsqueeze(1),
x,
attn_mask=attn_mask)[0]
x = out.permute(1, 0, 2)
x = self.ln_post(x)
x = x @ self.proj
return x
def _repeat(self, query, N: int):
return query.unsqueeze(1).repeat(1, N, 1)

555
omnilmm/model/utils.py Normal file
View File

@ -0,0 +1,555 @@
from torchvision import transforms
from timm.data.transforms import RandomResizedCropAndInterpolation
from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from transformers import AutoConfig
from PIL import Image
from io import BytesIO
import torch.distributed as dist
import numpy as np
import pickle
import base64
import cv2
import os
import torch
from transformers import AutoConfig, StoppingCriteria
try:
from timm.data.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
except ImportError:
OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
def auto_upgrade(config):
cfg = AutoConfig.from_pretrained(config)
if 'llava' in config and cfg.model_type != 'llava':
print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
confirm = input(
"Please confirm that you want to upgrade the checkpoint. [Y/N]")
if confirm.lower() in ["y", "yes"]:
print("Upgrading checkpoint...")
assert len(cfg.architectures) == 1
setattr(cfg.__class__, "model_type", "llava")
cfg.architectures[0] = 'LlavaLlamaForCausalLM'
cfg.save_pretrained(config)
print("Checkpoint upgraded.")
else:
print("Checkpoint upgrade aborted.")
exit(1)
class KeywordsStoppingCriteria(StoppingCriteria):
def __init__(self, keywords, tokenizer, input_ids):
self.keywords = keywords
self.tokenizer = tokenizer
self.start_len = None
self.input_ids = input_ids
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
if self.start_len is None:
self.start_len = self.input_ids.shape[1]
else:
outputs = self.tokenizer.batch_decode(
output_ids[:, self.start_len:], skip_special_tokens=True)[0]
for keyword in self.keywords:
if keyword in outputs:
return True
return False
def auto_upgrade(config):
cfg = AutoConfig.from_pretrained(config)
if 'llava' in config and cfg.model_type != 'llava':
print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
confirm = input(
"Please confirm that you want to upgrade the checkpoint. [Y/N]")
if confirm.lower() in ["y", "yes"]:
print("Upgrading checkpoint...")
assert len(cfg.architectures) == 1
setattr(cfg.__class__, "model_type", "llava")
cfg.architectures[0] = 'LlavaLlamaForCausalLM'
cfg.save_pretrained(config)
print("Checkpoint upgraded.")
else:
print("Checkpoint upgrade aborted.")
exit(1)
# aug functions
def identity_func(img):
return img
def autocontrast_func(img, cutoff=0):
'''
same output as PIL.ImageOps.autocontrast
'''
n_bins = 256
def tune_channel(ch):
n = ch.size
cut = cutoff * n // 100
if cut == 0:
high, low = ch.max(), ch.min()
else:
hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
low = np.argwhere(np.cumsum(hist) > cut)
low = 0 if low.shape[0] == 0 else low[0]
high = np.argwhere(np.cumsum(hist[::-1]) > cut)
high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
if high <= low:
table = np.arange(n_bins)
else:
scale = (n_bins - 1) / (high - low)
table = np.arange(n_bins) * scale - low * scale
table[table < 0] = 0
table[table > n_bins - 1] = n_bins - 1
table = table.clip(0, 255).astype(np.uint8)
return table[ch]
channels = [tune_channel(ch) for ch in cv2.split(img)]
out = cv2.merge(channels)
return out
def equalize_func(img):
'''
same output as PIL.ImageOps.equalize
PIL's implementation is different from cv2.equalize
'''
n_bins = 256
def tune_channel(ch):
hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
non_zero_hist = hist[hist != 0].reshape(-1)
step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
if step == 0:
return ch
n = np.empty_like(hist)
n[0] = step // 2
n[1:] = hist[:-1]
table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
return table[ch]
channels = [tune_channel(ch) for ch in cv2.split(img)]
out = cv2.merge(channels)
return out
def rotate_func(img, degree, fill=(0, 0, 0)):
'''
like PIL, rotate by degree, not radians
'''
H, W = img.shape[0], img.shape[1]
center = W / 2, H / 2
M = cv2.getRotationMatrix2D(center, degree, 1)
out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
return out
def solarize_func(img, thresh=128):
'''
same output as PIL.ImageOps.posterize
'''
table = np.array([el if el < thresh else 255 - el for el in range(256)])
table = table.clip(0, 255).astype(np.uint8)
out = table[img]
return out
def color_func(img, factor):
'''
same output as PIL.ImageEnhance.Color
'''
# implementation according to PIL definition, quite slow
# degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
# out = blend(degenerate, img, factor)
# M = (
# np.eye(3) * factor
# + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
# )[np.newaxis, np.newaxis, :]
M = (
np.float32([
[0.886, -0.114, -0.114],
[-0.587, 0.413, -0.587],
[-0.299, -0.299, 0.701]]) * factor
+ np.float32([[0.114], [0.587], [0.299]])
)
out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
return out
def contrast_func(img, factor):
"""
same output as PIL.ImageEnhance.Contrast
"""
mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
table = np.array([(
el - mean) * factor + mean
for el in range(256)
]).clip(0, 255).astype(np.uint8)
out = table[img]
return out
def brightness_func(img, factor):
'''
same output as PIL.ImageEnhance.Contrast
'''
table = (np.arange(256, dtype=np.float32) *
factor).clip(0, 255).astype(np.uint8)
out = table[img]
return out
def sharpness_func(img, factor):
'''
The differences the this result and PIL are all on the 4 boundaries, the center
areas are same
'''
kernel = np.ones((3, 3), dtype=np.float32)
kernel[1][1] = 5
kernel /= 13
degenerate = cv2.filter2D(img, -1, kernel)
if factor == 0.0:
out = degenerate
elif factor == 1.0:
out = img
else:
out = img.astype(np.float32)
degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
out[1:-1, 1:-1, :] = degenerate + factor * \
(out[1:-1, 1:-1, :] - degenerate)
out = out.astype(np.uint8)
return out
def shear_x_func(img, factor, fill=(0, 0, 0)):
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, factor, 0], [0, 1, 0]])
out = cv2.warpAffine(img, M, (W, H), borderValue=fill,
flags=cv2.INTER_LINEAR).astype(np.uint8)
return out
def translate_x_func(img, offset, fill=(0, 0, 0)):
'''
same output as PIL.Image.transform
'''
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, 0, -offset], [0, 1, 0]])
out = cv2.warpAffine(img, M, (W, H), borderValue=fill,
flags=cv2.INTER_LINEAR).astype(np.uint8)
return out
def translate_y_func(img, offset, fill=(0, 0, 0)):
'''
same output as PIL.Image.transform
'''
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, 0, 0], [0, 1, -offset]])
out = cv2.warpAffine(img, M, (W, H), borderValue=fill,
flags=cv2.INTER_LINEAR).astype(np.uint8)
return out
def posterize_func(img, bits):
'''
same output as PIL.ImageOps.posterize
'''
out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
return out
def shear_y_func(img, factor, fill=(0, 0, 0)):
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, 0, 0], [factor, 1, 0]])
out = cv2.warpAffine(img, M, (W, H), borderValue=fill,
flags=cv2.INTER_LINEAR).astype(np.uint8)
return out
def cutout_func(img, pad_size, replace=(0, 0, 0)):
replace = np.array(replace, dtype=np.uint8)
H, W = img.shape[0], img.shape[1]
rh, rw = np.random.random(2)
pad_size = pad_size // 2
ch, cw = int(rh * H), int(rw * W)
x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
out = img.copy()
out[x1:x2, y1:y2, :] = replace
return out
# level to args
def enhance_level_to_args(MAX_LEVEL):
def level_to_args(level):
return ((level / MAX_LEVEL) * 1.8 + 0.1,)
return level_to_args
def shear_level_to_args(MAX_LEVEL, replace_value):
def level_to_args(level):
level = (level / MAX_LEVEL) * 0.3
if np.random.random() > 0.5:
level = -level
return (level, replace_value)
return level_to_args
def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
def level_to_args(level):
level = (level / MAX_LEVEL) * float(translate_const)
if np.random.random() > 0.5:
level = -level
return (level, replace_value)
return level_to_args
def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
def level_to_args(level):
level = int((level / MAX_LEVEL) * cutout_const)
return (level, replace_value)
return level_to_args
def solarize_level_to_args(MAX_LEVEL):
def level_to_args(level):
level = int((level / MAX_LEVEL) * 256)
return (level, )
return level_to_args
def none_level_to_args(level):
return ()
def posterize_level_to_args(MAX_LEVEL):
def level_to_args(level):
level = int((level / MAX_LEVEL) * 4)
return (level, )
return level_to_args
def rotate_level_to_args(MAX_LEVEL, replace_value):
def level_to_args(level):
level = (level / MAX_LEVEL) * 30
if np.random.random() < 0.5:
level = -level
return (level, replace_value)
return level_to_args
func_dict = {
'Identity': identity_func,
'AutoContrast': autocontrast_func,
'Equalize': equalize_func,
'Rotate': rotate_func,
'Solarize': solarize_func,
'Color': color_func,
'Contrast': contrast_func,
'Brightness': brightness_func,
'Sharpness': sharpness_func,
'ShearX': shear_x_func,
'TranslateX': translate_x_func,
'TranslateY': translate_y_func,
'Posterize': posterize_func,
'ShearY': shear_y_func,
}
translate_const = 10
MAX_LEVEL = 10
replace_value = (128, 128, 128)
arg_dict = {
'Identity': none_level_to_args,
'AutoContrast': none_level_to_args,
'Equalize': none_level_to_args,
'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
'Solarize': solarize_level_to_args(MAX_LEVEL),
'Color': enhance_level_to_args(MAX_LEVEL),
'Contrast': enhance_level_to_args(MAX_LEVEL),
'Brightness': enhance_level_to_args(MAX_LEVEL),
'Sharpness': enhance_level_to_args(MAX_LEVEL),
'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
'TranslateX': translate_level_to_args(
translate_const, MAX_LEVEL, replace_value
),
'TranslateY': translate_level_to_args(
translate_const, MAX_LEVEL, replace_value
),
'Posterize': posterize_level_to_args(MAX_LEVEL),
'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
}
class RandomAugment(object):
def __init__(self, N=2, M=10, isPIL=False, augs=[]):
self.N = N
self.M = M
self.isPIL = isPIL
if augs:
self.augs = augs
else:
self.augs = list(arg_dict.keys())
def get_random_ops(self):
sampled_ops = np.random.choice(self.augs, self.N)
return [(op, 0.5, self.M) for op in sampled_ops]
def __call__(self, img):
if self.isPIL:
img = np.array(img)
ops = self.get_random_ops()
for name, prob, level in ops:
if np.random.random() > prob:
continue
args = arg_dict[name](level)
img = func_dict[name](img, *args)
return img
def build_transform(is_train, randaug=True, input_size=224, interpolation='bicubic', std_mode='IMAGENET_INCEPTION'):
if std_mode == 'IMAGENET_INCEPTION':
mean = IMAGENET_INCEPTION_MEAN
std = IMAGENET_INCEPTION_STD
elif std_mode == 'OPENAI_CLIP':
mean = OPENAI_CLIP_MEAN
std = OPENAI_CLIP_STD
else:
raise NotImplementedError
if is_train:
crop_scale = float(os.environ.get('TRAIN_CROP_SCALE', 0.9999))
t = [
RandomResizedCropAndInterpolation(
input_size, scale=(crop_scale, 1.0), interpolation='bicubic'),
# transforms.RandomHorizontalFlip(),
]
if randaug and os.environ.get('TRAIN_DO_AUG', 'False') == 'True':
print(f'@@@@@ Do random aug during training', flush=True)
t.append(
RandomAugment(
2, 7, isPIL=True,
augs=[
'Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate',
]))
else:
print(f'@@@@@ Skip random aug during training', flush=True)
t += [
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
]
t = transforms.Compose(t)
else:
t = transforms.Compose([
transforms.Resize((input_size, input_size),
interpolation=transforms.InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std)
])
return t
def img2b64(img_path):
img = Image.open(img_path) # path to file
img_buffer = BytesIO()
img.save(img_buffer, format=img.format)
byte_data = img_buffer.getvalue()
base64_str = base64.b64encode(byte_data) # bytes
base64_str = base64_str.decode("utf-8") # str
return base64_str
def str2b64(str):
return base64.b64encode(str.encode('utf-8')).decode('utf-8')
def b642str(b64):
return base64.b64decode(b64).decode('utf-8')
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def all_gather(data):
"""
Run all_gather on arbitrary picklable data (not necessarily tensors)
Args:
data: any picklable object
Returns:
list[data]: list of data gathered from each rank
"""
world_size = get_world_size()
if world_size == 1:
return [data]
# serialized to a Tensor
buffer = pickle.dumps(data)
storage = torch.ByteStorage.from_buffer(buffer)
tensor = torch.ByteTensor(storage).to("cuda")
# obtain Tensor size of each rank
local_size = torch.LongTensor([tensor.numel()]).to("cuda")
size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
dist.all_gather(size_list, local_size)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)
# receiving Tensor from all ranks
# we pad the tensor because torch all_gather does not support
# gathering tensors of different shapes
tensor_list = []
for _ in size_list:
tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
if local_size != max_size:
padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
tensor = torch.cat((tensor, padding), dim=0)
dist.all_gather(tensor_list, tensor)
data_list = []
for size, tensor in zip(size_list, tensor_list):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))
return data_list
def mean(lst):
return sum(lst) / len(lst)
def stop_gradient_by_name(name: str):
def apply_fn(module):
if hasattr(module, name):
getattr(module, name).requires_grad_(False)
return apply_fn

View File

@ -0,0 +1,153 @@
import os
import gc
import copy
import time
import torch
import warnings
import transformers
import numpy as np
from typing import Dict, Optional, Sequence
from omnilmm import conversation as conversation_lib
IGNORE_INDEX = -100
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
def _tokenize_fn(strings: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer) -> Dict:
"""Tokenize a list of strings."""
tokenized_list = [
tokenizer(
text,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
) for text in strings
]
input_ids = labels = [
tokenized.input_ids[0] for tokenized in tokenized_list
]
input_ids_lens = labels_lens = [
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
for tokenized in tokenized_list
]
return dict(
input_ids=input_ids,
labels=labels,
input_ids_lens=input_ids_lens,
labels_lens=labels_lens,
)
def omni_preprocess(sources,
tokenizer: transformers.PreTrainedTokenizer,
generation=False):
system_content = 'You are an artificial intelligence assistant, which gives helpful, detailed, and polite answers to the human\'s questions.'
ignore_index = -100
response_template = '\n<|assistant|>\n'
instruction_template = '\n<|user|>\n'
response_token_ids = tokenizer.encode(
response_template, add_special_tokens=False)
instruction_token_ids = tokenizer.encode(
instruction_template, add_special_tokens=False)
batch_input_ids = []
batch_labels = []
for i in range(len(sources)):
new_source = []
prev_role = 'unexpect'
for conv_turn in sources[i]:
role = conv_turn['from'] if 'from' in conv_turn else conv_turn['role']
content = conv_turn['value'] if 'value' in conv_turn else conv_turn['content']
role = 'user' if role == 'human' else role
role = 'assistant' if role == 'gpt' else role
assert role in ['user', 'assistant']
assert role != prev_role, f'role={role}, prev_role={prev_role}'
prev_role = role
new_turn = {
'role': role,
'content': content
}
new_source.append(new_turn)
if new_source[0]['role'] != 'system':
new_source.insert(0, {'role': 'system', 'content': system_content})
# TODO: this automatically add '\n' to the end
res_text = tokenizer.apply_chat_template(
new_source, tokenize=False, add_generation_prompt=generation)
if not generation:
res_text = res_text.strip()
conversations_tokenized = _tokenize_fn([res_text], tokenizer)
res_input_ids = conversations_tokenized["input_ids"][0]
# since labels and input_ids are reference towards the same object
res_labels = copy.deepcopy(conversations_tokenized["labels"][0])
response_token_ids_idxs = []
human_token_ids_idxs = []
for assistant_idx in np.where(res_labels == response_token_ids[0])[0]:
# find the indexes of the start of a response.
if (response_token_ids == res_labels[assistant_idx: assistant_idx + len(
response_token_ids)].tolist()
):
response_token_ids_idxs.append(
assistant_idx + len(response_token_ids))
if len(response_token_ids_idxs) == 0:
warnings.warn(
f"Could not find response key `{response_template}` in the "
f'following instance: @===>{tokenizer.decode(res_input_ids)}<===@ '
f'Raw text is @===>{res_text}<===@'
f'Raw source is @===>{new_source}<===@'
f"This instance will be ignored in loss calculation. "
f"Note, if this happens often, consider increasing the `max_seq_length`."
)
res_labels[:] = ignore_index
human_token_ids = instruction_token_ids
for human_idx in np.where(res_labels == human_token_ids[0])[0]:
# find the indexes of the start of a human answer.
if human_token_ids == res_labels[human_idx: human_idx + len(human_token_ids)].tolist():
human_token_ids_idxs.append(human_idx)
if len(human_token_ids_idxs) == 0:
warnings.warn(
f"Could not find instruction key `{instruction_template}` in the "
f'following instance: @===>{tokenizer.decode(res_input_ids)}<===@ '
f'Raw text is @===>{res_text}<===@'
f'Raw source is @===>{new_source}<===@'
f"This instance will be ignored in loss calculation. "
f"Note, if this happens often, consider increasing the `max_seq_length`."
)
res_labels[:] = ignore_index
for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)):
# Make pytorch loss function ignore all non response tokens
if idx != 0:
res_labels[start:end] = ignore_index
else:
res_labels[:end] = ignore_index
if len(response_token_ids_idxs) < len(human_token_ids_idxs):
res_labels[human_token_ids_idxs[-1]:] = ignore_index
batch_input_ids.append(res_input_ids)
batch_labels.append(res_labels)
return dict(input_ids=batch_input_ids, labels=batch_labels)

127
omnilmm/utils.py Normal file
View File

@ -0,0 +1,127 @@
import datetime
import logging
import logging.handlers
import os
import sys
import requests
from omnilmm.constants import LOGDIR
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
handler = None
def build_logger(logger_name, logger_filename):
global handler
formatter = logging.Formatter(
fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
# Set the format of root handlers
if not logging.getLogger().handlers:
logging.basicConfig(level=logging.INFO)
logging.getLogger().handlers[0].setFormatter(formatter)
# Redirect stdout and stderr to loggers
stdout_logger = logging.getLogger("stdout")
stdout_logger.setLevel(logging.INFO)
sl = StreamToLogger(stdout_logger, logging.INFO)
sys.stdout = sl
stderr_logger = logging.getLogger("stderr")
stderr_logger.setLevel(logging.ERROR)
sl = StreamToLogger(stderr_logger, logging.ERROR)
sys.stderr = sl
# Get logger
logger = logging.getLogger(logger_name)
logger.setLevel(logging.INFO)
# Add a file handler for all loggers
if handler is None:
os.makedirs(LOGDIR, exist_ok=True)
filename = os.path.join(LOGDIR, logger_filename)
handler = logging.handlers.TimedRotatingFileHandler(
filename, when='D', utc=True)
handler.setFormatter(formatter)
for name, item in logging.root.manager.loggerDict.items():
if isinstance(item, logging.Logger):
item.addHandler(handler)
return logger
class StreamToLogger(object):
"""
Fake file-like stream object that redirects writes to a logger instance.
"""
def __init__(self, logger, log_level=logging.INFO):
self.terminal = sys.stdout
self.logger = logger
self.log_level = log_level
self.linebuf = ''
def __getattr__(self, attr):
return getattr(self.terminal, attr)
def write(self, buf):
temp_linebuf = self.linebuf + buf
self.linebuf = ''
for line in temp_linebuf.splitlines(True):
# From the io.TextIOWrapper docs:
# On output, if newline is None, any '\n' characters written
# are translated to the system default line separator.
# By default sys.stdout.write() expects '\n' newlines and then
# translates them so this is still cross platform.
if line[-1] == '\n':
self.logger.log(self.log_level, line.rstrip())
else:
self.linebuf += line
def flush(self):
if self.linebuf != '':
self.logger.log(self.log_level, self.linebuf.rstrip())
self.linebuf = ''
def disable_torch_init():
"""
Disable the redundant torch default initialization to accelerate model creation.
"""
import torch
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
def violates_moderation(text):
"""
Check whether the text violates OpenAI moderation API.
"""
url = "https://api.openai.com/v1/moderations"
headers = {"Content-Type": "application/json",
"Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
text = text.replace("\n", "")
data = "{" + '"input": ' + f'"{text}"' + "}"
data = data.encode("utf-8")
try:
ret = requests.post(url, headers=headers, data=data, timeout=5)
flagged = ret.json()["results"][0]["flagged"]
except requests.exceptions.RequestException as e:
flagged = False
except KeyError as e:
flagged = False
return flagged
def pretty_print_semaphore(semaphore):
if semaphore is None:
return "None"
return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"

31
requirements.txt Normal file
View File

@ -0,0 +1,31 @@
packaging==23.2
addict==2.4.0
base_utils==1.0.14
editdistance==0.6.2
einops==0.7.0
fairscale==0.4.0
jsonlines==4.0.0
markdown2==2.4.10
matplotlib==3.7.4
more_itertools==10.1.0
nltk==3.8.1
numpy==1.24.4
opencv_python_headless==4.5.5.64
openpyxl==3.1.2
Pillow==10.1.0
sacrebleu==2.3.2
seaborn==0.13.0
shortuuid==1.0.11
spacy==3.7.2
timm==0.9.10
torch==2.0.1
torchvision==0.15.2
tqdm==4.66.1
protobuf==4.25.0
transformers==4.36.0
typing_extensions==4.8.0
uvicorn==0.24.0.post1
#xformers==0.0.22.post7
#flash_attn==2.3.4
sentencepiece==0.1.99
accelerate==0.24.1