From b7ea34c3ac11d00f635f1942464b86c0e8ce60c1 Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Wed, 13 Mar 2024 10:43:42 +0800 Subject: [PATCH 01/21] Split the language model register tab panel into the separated js file. --- .../web/ui/src/scenes/register_model/index.js | 600 +---------------- .../register_model/register_language.js | 620 ++++++++++++++++++ 2 files changed, 623 insertions(+), 597 deletions(-) create mode 100644 xinference/web/ui/src/scenes/register_model/register_language.js diff --git a/xinference/web/ui/src/scenes/register_model/index.js b/xinference/web/ui/src/scenes/register_model/index.js index de3c04dc43..d211c223c9 100644 --- a/xinference/web/ui/src/scenes/register_model/index.js +++ b/xinference/web/ui/src/scenes/register_model/index.js @@ -1,322 +1,34 @@ import { TabContext, TabList, TabPanel } from '@mui/lab' import { Box, - Checkbox, - FormControl, - FormControlLabel, - FormHelperText, - Radio, - RadioGroup, Tab, } from '@mui/material' -import Alert from '@mui/material/Alert' -import AlertTitle from '@mui/material/AlertTitle' -import Button from '@mui/material/Button' -import TextField from '@mui/material/TextField' -import React, { useContext, useEffect, useState } from 'react' +import React, { useEffect } from 'react' import { useCookies } from 'react-cookie' import { useNavigate } from 'react-router-dom' -import { ApiContext } from '../../components/apiContext' import ErrorMessageSnackBar from '../../components/errorMessageSnackBar' -import fetcher from '../../components/fetcher' import Title from '../../components/Title' -import { useMode } from '../../theme' import RegisterEmbeddingModel from './register_embedding' +import RegisterLanguageModel from './register_language' import RegisterRerankModel from './register_rerank' -const SUPPORTED_LANGUAGES_DICT = { en: 'English', zh: 'Chinese' } -const SUPPORTED_FEATURES = ['Generate', 'Chat'] - -// Convert dictionary of supported languages into list -const SUPPORTED_LANGUAGES = Object.keys(SUPPORTED_LANGUAGES_DICT) const RegisterModel = () => { - const ERROR_COLOR = useMode() - const endPoint = useContext(ApiContext).endPoint - const { setErrorMsg } = useContext(ApiContext) - const [successMsg, setSuccessMsg] = useState('') - const [modelFormat, setModelFormat] = useState('pytorch') - const [modelSize, setModelSize] = useState(7) - const [modelUri, setModelUri] = useState('/path/to/llama-2') - const [quantization, setQuantization] = useState('') - const [formData, setFormData] = useState({ - version: 1, - context_length: 2048, - model_name: 'custom-llama-2', - model_lang: ['en'], - model_ability: ['generate'], - model_description: 'This is a custom model description.', - model_family: '', - model_specs: [], - prompt_style: undefined, - }) - const [promptStyles, setPromptStyles] = useState([]) - const [family, setFamily] = useState({ - chat: [], - generate: [], - }) - const [familyLabel, setFamilyLabel] = useState('') const [tabValue, setTabValue] = React.useState('1') const [cookie] = useCookies(['token']) const navigate = useNavigate() - const errorModelName = formData.model_name.trim().length <= 0 - const errorModelDescription = formData.model_description.length < 0 - const errorContextLength = formData.context_length === 0 - const errorLanguage = - formData.model_lang === undefined || formData.model_lang.length === 0 - const errorAbility = - formData.model_ability === undefined || formData.model_ability.length === 0 - const errorModelSize = - formData.model_specs && - formData.model_specs.some((spec) => { - return 
( - spec.model_size_in_billions === undefined || - spec.model_size_in_billions === 0 - ) - }) - const errorFamily = familyLabel === '' - const errorAny = - errorModelName || - errorModelDescription || - errorContextLength || - errorLanguage || - errorAbility || - errorModelSize || - errorFamily - useEffect(() => { if (cookie.token === '' || cookie.token === undefined) { return } if (cookie.token === 'need_auth') { navigate('/login', { replace: true }) - return } - const getBuiltinFamilies = async () => { - const response = await fetch(endPoint + '/v1/models/families', { - method: 'GET', - headers: { - 'Content-Type': 'application/json', - }, - }) - if (!response.ok) { - const errorData = await response.json() // Assuming the server returns error details in JSON format - setErrorMsg( - `Server error: ${response.status} - ${ - errorData.detail || 'Unknown error' - }` - ) - } else { - const data = await response.json() - data.chat.push('other') - data.generate.push('other') - setFamily(data) - } - } - - const getBuiltInPromptStyles = async () => { - const response = await fetch(endPoint + '/v1/models/prompts', { - method: 'GET', - headers: { - 'Content-Type': 'application/json', - }, - }) - if (!response.ok) { - const errorData = await response.json() // Assuming the server returns error details in JSON format - setErrorMsg( - `Server error: ${response.status} - ${ - errorData.detail || 'Unknown error' - }` - ) - } else { - const data = await response.json() - let res = [] - for (const key in data) { - let v = data[key] - v['name'] = key - res.push(v) - } - setPromptStyles(res) - } - } - // avoid keep requesting backend to get prompts - if (promptStyles.length === 0) { - getBuiltInPromptStyles().catch((error) => { - setErrorMsg( - error.message || - 'An unexpected error occurred when getting builtin prompt styles.' - ) - console.error('Error: ', error) - }) - } - if (family.chat.length === 0) { - getBuiltinFamilies().catch((error) => { - setErrorMsg( - error.message || - 'An unexpected error occurred when getting builtin prompt styles.' 
- ) - console.error('Error: ', error) - }) - } }, [cookie.token]) - const getFamilyByAbility = () => { - if (formData.model_ability.includes('chat')) { - return family.chat - } else { - return family.generate - } - } - - const isModelFormatPytorch = () => { - return modelFormat === 'pytorch' - } - - const isModelFormatGPTQ = () => { - return modelFormat === 'gptq' - } - - const isModelFormatAWQ = () => { - return modelFormat === 'awq' - } - - const getPathComponents = (path) => { - const normalizedPath = path.replace(/\\/g, '/') - const baseDir = normalizedPath.substring(0, normalizedPath.lastIndexOf('/')) - const filename = normalizedPath.substring( - normalizedPath.lastIndexOf('/') + 1 - ) - return { baseDir, filename } - } - - const handleClick = async () => { - if (isModelFormatGPTQ()) { - formData.model_specs = [ - { - model_format: modelFormat, - model_size_in_billions: modelSize, - quantizations: [quantization], - model_id: '', - model_uri: modelUri, - }, - ] - } else if (isModelFormatAWQ()) { - formData.model_specs = [ - { - model_format: modelFormat, - model_size_in_billions: modelSize, - quantizations: [quantization], - model_id: '', - model_uri: modelUri, - }, - ] - } else if (!isModelFormatPytorch()) { - const { baseDir, filename } = getPathComponents(modelUri) - formData.model_specs = [ - { - model_format: modelFormat, - model_size_in_billions: modelSize, - quantizations: [quantization], - model_id: '', - model_file_name_template: filename, - model_uri: baseDir, - }, - ] - } else { - formData.model_specs = [ - { - model_format: modelFormat, - model_size_in_billions: modelSize, - quantizations: ['4-bit', '8-bit', 'none'], - model_id: '', - model_uri: modelUri, - }, - ] - } - - formData.model_family = familyLabel - - if (formData.model_ability.includes('chat')) { - const ps = promptStyles.find((item) => item.name === familyLabel) - if (ps) { - formData.prompt_style = { - style_name: ps.style_name, - system_prompt: ps.system_prompt, - roles: ps.roles, - intra_message_sep: ps.intra_message_sep, - inter_message_sep: ps.inter_message_sep, - stop: ps.stop ?? null, - stop_token_ids: ps.stop_token_ids ?? null, - } - } - } - - if (errorAny) { - setErrorMsg('Please fill in valid value for all fields') - return - } - - try { - const response = await fetcher(endPoint + '/v1/model_registrations/LLM', { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify({ - model: JSON.stringify(formData), - persist: true, - }), - }) - if (!response.ok) { - const errorData = await response.json() // Assuming the server returns error details in JSON format - setErrorMsg( - `Server error: ${response.status} - ${ - errorData.detail || 'Unknown error' - }` - ) - } else { - setSuccessMsg( - 'Model has been registered successfully! Navigate to launch model page to proceed.' 
- ) - } - } catch (error) { - console.error('There was a problem with the fetch operation:', error) - setErrorMsg(error.message || 'An unexpected error occurred.') - } - } - - const toggleLanguage = (lang) => { - if (formData.model_lang.includes(lang)) { - setFormData({ - ...formData, - model_lang: formData.model_lang.filter((l) => l !== lang), - }) - } else { - setFormData({ - ...formData, - model_lang: [...formData.model_lang, lang], - }) - } - } - - const toggleAbility = (ability) => { - setFamilyLabel('') - if (formData.model_ability.includes(ability)) { - setFormData({ - ...formData, - model_ability: formData.model_ability.filter((a) => a !== ability), - }) - } else { - setFormData({ - ...formData, - model_ability: [...formData.model_ability, ability], - }) - } - } - return ( @@ -336,284 +48,7 @@ const RegisterModel = () => { </TabList> </Box> <TabPanel value="1" sx={{ padding: 0 }}> - <Box padding="20px"></Box> - {/* Base Information */} - <FormControl sx={styles.baseFormControl}> - <TextField - label="Model Name" - error={errorModelName} - defaultValue={formData.model_name} - size="small" - helperText="Alphanumeric characters with properly placed hyphens and underscores. Must not match any built-in model names." - onChange={(event) => - setFormData({ ...formData, model_name: event.target.value }) - } - /> - <Box padding="15px"></Box> - - <label - style={{ - paddingLeft: 5, - }} - > - Model Format - </label> - - <RadioGroup - value={modelFormat} - onChange={(e) => { - setModelFormat(e.target.value) - }} - > - <Box sx={styles.checkboxWrapper}> - <Box sx={{ marginLeft: '10px' }}> - <FormControlLabel - value="pytorch" - control={<Radio />} - label="PyTorch" - /> - </Box> - <Box sx={{ marginLeft: '10px' }}> - <FormControlLabel - value="ggmlv3" - control={<Radio />} - label="GGML" - /> - </Box> - <Box sx={{ marginLeft: '10px' }}> - <FormControlLabel - value="ggufv2" - control={<Radio />} - label="GGUF" - /> - </Box> - <Box sx={{ marginLeft: '10px' }}> - <FormControlLabel - value="gptq" - control={<Radio />} - label="GPTQ" - /> - </Box> - <Box sx={{ marginLeft: '10px' }}> - <FormControlLabel - value="awq" - control={<Radio />} - label="AWQ" - /> - </Box> - </Box> - </RadioGroup> - <Box padding="15px"></Box> - - <TextField - error={errorContextLength} - label="Context Length" - value={formData.context_length} - size="small" - onChange={(event) => { - let value = event.target.value - // Remove leading zeros - if (/^0+/.test(value)) { - value = value.replace(/^0+/, '') || '0' - } - // Ensure it's a positive integer, if not set it to the minimum - if (!/^\d+$/.test(value) || parseInt(value) < 0) { - value = '0' - } - // Update with the processed value - setFormData({ - ...formData, - context_length: Number(value), - }) - }} - /> - <Box padding="15px"></Box> - - <TextField - label="Model Size in Billions" - size="small" - error={errorModelSize} - value={modelSize} - onChange={(e) => { - let value = e.target.value - // Remove leading zeros - if (/^0+/.test(value)) { - value = value.replace(/^0+/, '') || '0' - } - // Ensure it's a positive integer, if not set it to the minimum - if (!/^\d+$/.test(value) || parseInt(value) < 0) { - value = '0' - } - setModelSize(Number(value)) - }} - /> - <Box padding="15px"></Box> - - <TextField - label="Model Path" - size="small" - value={modelUri} - onChange={(e) => { - setModelUri(e.target.value) - }} - helperText="For PyTorch, provide the model directory. For GGML/GGUF, provide the model file path." 
- /> - <Box padding="15px"></Box> - - <TextField - label="Quantization (Optional)" - size="small" - value={quantization} - onChange={(e) => { - setQuantization(e.target.value) - }} - helperText="For GPTQ/AWQ models, please be careful to fill in the quantization corresponding to the model you want to register." - /> - <Box padding="15px"></Box> - - <TextField - label="Model Description (Optional)" - error={errorModelDescription} - defaultValue={formData.model_description} - size="small" - onChange={(event) => - setFormData({ - ...formData, - model_description: event.target.value, - }) - } - /> - <Box padding="15px"></Box> - - <label - style={{ - paddingLeft: 5, - color: errorLanguage ? ERROR_COLOR : 'inherit', - }} - > - Model Languages - </label> - <Box sx={styles.checkboxWrapper}> - {SUPPORTED_LANGUAGES.map((lang) => ( - <Box key={lang} sx={{ marginRight: '10px' }}> - <FormControlLabel - control={ - <Checkbox - checked={formData.model_lang.includes(lang)} - onChange={() => toggleLanguage(lang)} - name={lang} - sx={ - errorLanguage - ? { - 'color': ERROR_COLOR, - '&.Mui-checked': { - color: ERROR_COLOR, - }, - } - : {} - } - /> - } - label={SUPPORTED_LANGUAGES_DICT[lang]} - style={{ - paddingLeft: 10, - color: errorLanguage ? ERROR_COLOR : 'inherit', - }} - /> - </Box> - ))} - </Box> - <Box padding="15px"></Box> - - <label - style={{ - paddingLeft: 5, - color: errorAbility ? ERROR_COLOR : 'inherit', - }} - > - Model Abilities - </label> - <Box sx={styles.checkboxWrapper}> - {SUPPORTED_FEATURES.map((ability) => ( - <Box key={ability} sx={{ marginRight: '10px' }}> - <FormControlLabel - control={ - <Checkbox - checked={formData.model_ability.includes( - ability.toLowerCase() - )} - onChange={() => toggleAbility(ability.toLowerCase())} - name={ability} - sx={ - errorAbility - ? { - 'color': ERROR_COLOR, - '&.Mui-checked': { - color: ERROR_COLOR, - }, - } - : {} - } - /> - } - label={ability} - style={{ - paddingLeft: 10, - color: errorAbility ? ERROR_COLOR : 'inherit', - }} - /> - </Box> - ))} - </Box> - <Box padding="15px"></Box> - </FormControl> - - <FormControl sx={styles.baseFormControl}> - <label - style={{ - paddingLeft: 5, - color: errorAbility ? ERROR_COLOR : 'inherit', - }} - > - Model Family - </label> - <FormHelperText> - Please be careful to select the family name corresponding to the - model you want to register. If not found, please choose `other`. 
- </FormHelperText> - <RadioGroup - value={familyLabel} - onChange={(e) => { - setFamilyLabel(e.target.value) - }} - > - <Box sx={styles.checkboxWrapper}> - {getFamilyByAbility().map((v) => ( - <Box sx={{ marginLeft: '10px' }}> - <FormControlLabel value={v} control={<Radio />} label={v} /> - </Box> - ))} - </Box> - </RadioGroup> - <Box padding="15px"></Box> - </FormControl> - - <Box width={'100%'}> - {successMsg !== '' && ( - <Alert severity="success"> - <AlertTitle>Success</AlertTitle> - {successMsg} - </Alert> - )} - <Button - variant="contained" - color="primary" - type="submit" - onClick={handleClick} - > - Register Model - </Button> - </Box> + <RegisterLanguageModel /> </TabPanel> <TabPanel value="2" sx={{ padding: 0 }}> <RegisterEmbeddingModel /> @@ -627,32 +62,3 @@ const RegisterModel = () => { } export default RegisterModel - -const styles = { - baseFormControl: { - width: '100%', - margin: 'normal', - size: 'small', - }, - checkboxWrapper: { - display: 'flex', - flexWrap: 'wrap', - maxWidth: '80%', - }, - labelPaddingLeft: { - paddingLeft: 5, - }, - formControlLabelPaddingLeft: { - paddingLeft: 10, - }, - buttonBox: { - width: '100%', - margin: '20px', - }, - error: { - fontWeight: 'bold', - margin: '5px 0', - padding: '1px', - borderRadius: '5px', - }, -} diff --git a/xinference/web/ui/src/scenes/register_model/register_language.js b/xinference/web/ui/src/scenes/register_model/register_language.js new file mode 100644 index 0000000000..eed7280340 --- /dev/null +++ b/xinference/web/ui/src/scenes/register_model/register_language.js @@ -0,0 +1,620 @@ +import { + Box, + Checkbox, + FormControl, + FormControlLabel, + FormHelperText, + Radio, + RadioGroup, +} from '@mui/material' +import Alert from '@mui/material/Alert' +import AlertTitle from '@mui/material/AlertTitle' +import Button from '@mui/material/Button' +import TextField from '@mui/material/TextField' +import React, { useContext, useEffect, useState } from 'react' +import { useCookies } from 'react-cookie' + +import { ApiContext } from '../../components/apiContext' +import fetcher from '../../components/fetcher' +import { useMode } from '../../theme' + +const SUPPORTED_LANGUAGES_DICT = { en: 'English', zh: 'Chinese' } +const SUPPORTED_FEATURES = ['Generate', 'Chat'] + +// Convert dictionary of supported languages into list +const SUPPORTED_LANGUAGES = Object.keys(SUPPORTED_LANGUAGES_DICT) + +const RegisterLanguageModel = () => { + const ERROR_COLOR = useMode() + const endPoint = useContext(ApiContext).endPoint + const { setErrorMsg } = useContext(ApiContext) + const [successMsg, setSuccessMsg] = useState('') + const [modelFormat, setModelFormat] = useState('pytorch') + const [modelSize, setModelSize] = useState(7) + const [modelUri, setModelUri] = useState('/path/to/llama-2') + const [quantization, setQuantization] = useState('') + const [formData, setFormData] = useState({ + version: 1, + context_length: 2048, + model_name: 'custom-llama-2', + model_lang: ['en'], + model_ability: ['generate'], + model_description: 'This is a custom model description.', + model_family: '', + model_specs: [], + prompt_style: undefined, + }) + const [promptStyles, setPromptStyles] = useState([]) + const [family, setFamily] = useState({ + chat: [], + generate: [], + }) + const [familyLabel, setFamilyLabel] = useState('') + + const [cookie] = useCookies(['token']) + const errorModelName = formData.model_name.trim().length <= 0 + const errorModelDescription = formData.model_description.length < 0 + const errorContextLength = formData.context_length 
=== 0 + const errorLanguage = + formData.model_lang === undefined || formData.model_lang.length === 0 + const errorAbility = + formData.model_ability === undefined || formData.model_ability.length === 0 + const errorModelSize = + formData.model_specs && + formData.model_specs.some((spec) => { + return ( + spec.model_size_in_billions === undefined || + spec.model_size_in_billions === 0 + ) + }) + const errorFamily = familyLabel === '' + const errorAny = + errorModelName || + errorModelDescription || + errorContextLength || + errorLanguage || + errorAbility || + errorModelSize || + errorFamily + + useEffect(() => { + if (cookie.token === '' || cookie.token === undefined) { + return + } + + const getBuiltinFamilies = async () => { + const response = await fetch(endPoint + '/v1/models/families', { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + }, + }) + if (!response.ok) { + const errorData = await response.json() // Assuming the server returns error details in JSON format + setErrorMsg( + `Server error: ${response.status} - ${ + errorData.detail || 'Unknown error' + }` + ) + } else { + const data = await response.json() + data.chat.push('other') + data.generate.push('other') + setFamily(data) + } + } + + const getBuiltInPromptStyles = async () => { + const response = await fetch(endPoint + '/v1/models/prompts', { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + }, + }) + if (!response.ok) { + const errorData = await response.json() // Assuming the server returns error details in JSON format + setErrorMsg( + `Server error: ${response.status} - ${ + errorData.detail || 'Unknown error' + }` + ) + } else { + const data = await response.json() + let res = [] + for (const key in data) { + let v = data[key] + v['name'] = key + res.push(v) + } + setPromptStyles(res) + } + } + // avoid keep requesting backend to get prompts + if (promptStyles.length === 0) { + getBuiltInPromptStyles().catch((error) => { + setErrorMsg( + error.message || + 'An unexpected error occurred when getting builtin prompt styles.' + ) + console.error('Error: ', error) + }) + } + if (family.chat.length === 0) { + getBuiltinFamilies().catch((error) => { + setErrorMsg( + error.message || + 'An unexpected error occurred when getting builtin prompt styles.' 
+ ) + console.error('Error: ', error) + }) + } + }, [cookie.token]) + + const getFamilyByAbility = () => { + if (formData.model_ability.includes('chat')) { + return family.chat + } else { + return family.generate + } + } + + const isModelFormatPytorch = () => { + return modelFormat === 'pytorch' + } + + const isModelFormatGPTQ = () => { + return modelFormat === 'gptq' + } + + const isModelFormatAWQ = () => { + return modelFormat === 'awq' + } + + const getPathComponents = (path) => { + const normalizedPath = path.replace(/\\/g, '/') + const baseDir = normalizedPath.substring(0, normalizedPath.lastIndexOf('/')) + const filename = normalizedPath.substring( + normalizedPath.lastIndexOf('/') + 1 + ) + return { baseDir, filename } + } + + const handleClick = async () => { + if (isModelFormatGPTQ()) { + formData.model_specs = [ + { + model_format: modelFormat, + model_size_in_billions: modelSize, + quantizations: [quantization], + model_id: '', + model_uri: modelUri, + }, + ] + } else if (isModelFormatAWQ()) { + formData.model_specs = [ + { + model_format: modelFormat, + model_size_in_billions: modelSize, + quantizations: [quantization], + model_id: '', + model_uri: modelUri, + }, + ] + } else if (!isModelFormatPytorch()) { + const { baseDir, filename } = getPathComponents(modelUri) + formData.model_specs = [ + { + model_format: modelFormat, + model_size_in_billions: modelSize, + quantizations: [quantization], + model_id: '', + model_file_name_template: filename, + model_uri: baseDir, + }, + ] + } else { + formData.model_specs = [ + { + model_format: modelFormat, + model_size_in_billions: modelSize, + quantizations: ['4-bit', '8-bit', 'none'], + model_id: '', + model_uri: modelUri, + }, + ] + } + + formData.model_family = familyLabel + + if (formData.model_ability.includes('chat')) { + const ps = promptStyles.find((item) => item.name === familyLabel) + if (ps) { + formData.prompt_style = { + style_name: ps.style_name, + system_prompt: ps.system_prompt, + roles: ps.roles, + intra_message_sep: ps.intra_message_sep, + inter_message_sep: ps.inter_message_sep, + stop: ps.stop ?? null, + stop_token_ids: ps.stop_token_ids ?? null, + } + } + } + + if (errorAny) { + setErrorMsg('Please fill in valid value for all fields') + return + } + + try { + const response = await fetcher(endPoint + '/v1/model_registrations/LLM', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + model: JSON.stringify(formData), + persist: true, + }), + }) + if (!response.ok) { + const errorData = await response.json() // Assuming the server returns error details in JSON format + setErrorMsg( + `Server error: ${response.status} - ${ + errorData.detail || 'Unknown error' + }` + ) + } else { + setSuccessMsg( + 'Model has been registered successfully! Navigate to launch model page to proceed.' 
+ ) + } + } catch (error) { + console.error('There was a problem with the fetch operation:', error) + setErrorMsg(error.message || 'An unexpected error occurred.') + } + } + + const toggleLanguage = (lang) => { + if (formData.model_lang.includes(lang)) { + setFormData({ + ...formData, + model_lang: formData.model_lang.filter((l) => l !== lang), + }) + } else { + setFormData({ + ...formData, + model_lang: [...formData.model_lang, lang], + }) + } + } + + const toggleAbility = (ability) => { + setFamilyLabel('') + if (formData.model_ability.includes(ability)) { + setFormData({ + ...formData, + model_ability: formData.model_ability.filter((a) => a !== ability), + }) + } else { + setFormData({ + ...formData, + model_ability: [...formData.model_ability, ability], + }) + } + } + + return ( + <React.Fragment> + <Box padding="20px"></Box> + {/* Base Information */} + <FormControl sx={styles.baseFormControl}> + <TextField + label="Model Name" + error={errorModelName} + defaultValue={formData.model_name} + size="small" + helperText="Alphanumeric characters with properly placed hyphens and underscores. Must not match any built-in model names." + onChange={(event) => + setFormData({ ...formData, model_name: event.target.value }) + } + /> + <Box padding="15px"></Box> + + <label + style={{ + paddingLeft: 5, + }} + > + Model Format + </label> + + <RadioGroup + value={modelFormat} + onChange={(e) => { + setModelFormat(e.target.value) + }} + > + <Box sx={styles.checkboxWrapper}> + <Box sx={{ marginLeft: '10px' }}> + <FormControlLabel + value="pytorch" + control={<Radio />} + label="PyTorch" + /> + </Box> + <Box sx={{ marginLeft: '10px' }}> + <FormControlLabel + value="ggmlv3" + control={<Radio />} + label="GGML" + /> + </Box> + <Box sx={{ marginLeft: '10px' }}> + <FormControlLabel + value="ggufv2" + control={<Radio />} + label="GGUF" + /> + </Box> + <Box sx={{ marginLeft: '10px' }}> + <FormControlLabel + value="gptq" + control={<Radio />} + label="GPTQ" + /> + </Box> + <Box sx={{ marginLeft: '10px' }}> + <FormControlLabel + value="awq" + control={<Radio />} + label="AWQ" + /> + </Box> + </Box> + </RadioGroup> + <Box padding="15px"></Box> + + <TextField + error={errorContextLength} + label="Context Length" + value={formData.context_length} + size="small" + onChange={(event) => { + let value = event.target.value + // Remove leading zeros + if (/^0+/.test(value)) { + value = value.replace(/^0+/, '') || '0' + } + // Ensure it's a positive integer, if not set it to the minimum + if (!/^\d+$/.test(value) || parseInt(value) < 0) { + value = '0' + } + // Update with the processed value + setFormData({ + ...formData, + context_length: Number(value), + }) + }} + /> + <Box padding="15px"></Box> + + <TextField + label="Model Size in Billions" + size="small" + error={errorModelSize} + value={modelSize} + onChange={(e) => { + let value = e.target.value + // Remove leading zeros + if (/^0+/.test(value)) { + value = value.replace(/^0+/, '') || '0' + } + // Ensure it's a positive integer, if not set it to the minimum + if (!/^\d+$/.test(value) || parseInt(value) < 0) { + value = '0' + } + setModelSize(Number(value)) + }} + /> + <Box padding="15px"></Box> + + <TextField + label="Model Path" + size="small" + value={modelUri} + onChange={(e) => { + setModelUri(e.target.value) + }} + helperText="For PyTorch, provide the model directory. For GGML/GGUF, provide the model file path." 
+ /> + <Box padding="15px"></Box> + + <TextField + label="Quantization (Optional)" + size="small" + value={quantization} + onChange={(e) => { + setQuantization(e.target.value) + }} + helperText="For GPTQ/AWQ models, please be careful to fill in the quantization corresponding to the model you want to register." + /> + <Box padding="15px"></Box> + + <TextField + label="Model Description (Optional)" + error={errorModelDescription} + defaultValue={formData.model_description} + size="small" + onChange={(event) => + setFormData({ + ...formData, + model_description: event.target.value, + }) + } + /> + <Box padding="15px"></Box> + + <label + style={{ + paddingLeft: 5, + color: errorLanguage ? ERROR_COLOR : 'inherit', + }} + > + Model Languages + </label> + <Box sx={styles.checkboxWrapper}> + {SUPPORTED_LANGUAGES.map((lang) => ( + <Box key={lang} sx={{ marginRight: '10px' }}> + <FormControlLabel + control={ + <Checkbox + checked={formData.model_lang.includes(lang)} + onChange={() => toggleLanguage(lang)} + name={lang} + sx={ + errorLanguage + ? { + 'color': ERROR_COLOR, + '&.Mui-checked': { + color: ERROR_COLOR, + }, + } + : {} + } + /> + } + label={SUPPORTED_LANGUAGES_DICT[lang]} + style={{ + paddingLeft: 10, + color: errorLanguage ? ERROR_COLOR : 'inherit', + }} + /> + </Box> + ))} + </Box> + <Box padding="15px"></Box> + + <label + style={{ + paddingLeft: 5, + color: errorAbility ? ERROR_COLOR : 'inherit', + }} + > + Model Abilities + </label> + <Box sx={styles.checkboxWrapper}> + {SUPPORTED_FEATURES.map((ability) => ( + <Box key={ability} sx={{ marginRight: '10px' }}> + <FormControlLabel + control={ + <Checkbox + checked={formData.model_ability.includes( + ability.toLowerCase() + )} + onChange={() => toggleAbility(ability.toLowerCase())} + name={ability} + sx={ + errorAbility + ? { + 'color': ERROR_COLOR, + '&.Mui-checked': { + color: ERROR_COLOR, + }, + } + : {} + } + /> + } + label={ability} + style={{ + paddingLeft: 10, + color: errorAbility ? ERROR_COLOR : 'inherit', + }} + /> + </Box> + ))} + </Box> + <Box padding="15px"></Box> + </FormControl> + + <FormControl sx={styles.baseFormControl}> + <label + style={{ + paddingLeft: 5, + color: errorAbility ? ERROR_COLOR : 'inherit', + }} + > + Model Family + </label> + <FormHelperText> + Please be careful to select the family name corresponding to the + model you want to register. If not found, please choose `other`. 
+ </FormHelperText> + <RadioGroup + value={familyLabel} + onChange={(e) => { + setFamilyLabel(e.target.value) + }} + > + <Box sx={styles.checkboxWrapper}> + {getFamilyByAbility().map((v) => ( + <Box sx={{ marginLeft: '10px' }}> + <FormControlLabel value={v} control={<Radio />} label={v} /> + </Box> + ))} + </Box> + </RadioGroup> + <Box padding="15px"></Box> + </FormControl> + + <Box width={'100%'}> + {successMsg !== '' && ( + <Alert severity="success"> + <AlertTitle>Success</AlertTitle> + {successMsg} + </Alert> + )} + <Button + variant="contained" + color="primary" + type="submit" + onClick={handleClick} + > + Register Model + </Button> + </Box> + </React.Fragment> + ) +} + +export default RegisterLanguageModel + +const styles = { + baseFormControl: { + width: '100%', + margin: 'normal', + size: 'small', + }, + checkboxWrapper: { + display: 'flex', + flexWrap: 'wrap', + maxWidth: '80%', + }, + labelPaddingLeft: { + paddingLeft: 5, + }, + formControlLabelPaddingLeft: { + paddingLeft: 10, + }, + buttonBox: { + width: '100%', + margin: '20px', + }, + error: { + fontWeight: 'bold', + margin: '5px 0', + padding: '1px', + borderRadius: '5px', + }, +} From 18f464eeb25d814c59a69604f0d29bfac1fcf577 Mon Sep 17 00:00:00 2001 From: Shi Hui <shihui@hyron.com> Date: Fri, 15 Mar 2024 18:52:08 +0800 Subject: [PATCH 02/21] =?UTF-8?q?WIP:=20=E4=BB=8EHUB=E4=B8=AD=E8=8E=B7?= =?UTF-8?q?=E5=BE=97=E6=A8=A1=E5=9E=8B=E4=BF=A1=E6=81=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- xinference/core/supervisor.py | 62 +- xinference/core/tests/test_utils.py | 200 ++++++ xinference/core/utils.py | 181 ++++- xinference/model/llm/llm_family.py | 8 + .../register_model/register_language.js | 622 ++++++++++-------- 5 files changed, 803 insertions(+), 270 deletions(-) diff --git a/xinference/core/supervisor.py b/xinference/core/supervisor.py index 615eade853..7c8e4739ec 100644 --- a/xinference/core/supervisor.py +++ b/xinference/core/supervisor.py @@ -14,6 +14,7 @@ import asyncio import itertools +import json import time import typing from dataclasses import dataclass @@ -21,6 +22,9 @@ from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, Union import xoscar as xo +from huggingface_hub import HfApi +from huggingface_hub.hf_api import ModelInfo +from typing_extensions import Literal from ..constants import ( XINFERENCE_DISABLE_HEALTH_CHECK, @@ -30,11 +34,14 @@ ) from ..core import ModelActor from ..core.status_guard import InstanceInfo, LaunchStatus +from ..model.llm import GgmlLLMSpecV1 +from ..model.llm.llm_family import HubImportLLMFamilyV1 from .metrics import record_metrics from .resource import GPUStatus, ResourceStatus from .utils import ( build_replica_model_uid, gen_random_string, + get_model_size_from_model_id, is_valid_model_uid, iter_replica_model_uid, log_async, @@ -51,10 +58,8 @@ from ..model.rerank import RerankModelSpec from .worker import WorkerActor - logger = getLogger(__name__) - ASYNC_LAUNCH_TASKS = {} # type: ignore @@ -79,6 +84,7 @@ class ReplicaInfo: class SupervisorActor(xo.StatelessActor): def __init__(self): super().__init__() + self.__hf_api: Optional[HfApi] = None self._worker_address_to_worker: Dict[str, xo.ActorRefType["WorkerActor"]] = {} self._worker_status: Dict[str, WorkerStatus] = {} self._replica_model_uid_to_worker: Dict[ @@ -974,3 +980,55 @@ async def report_worker_status( @staticmethod def record_metrics(name, op, kwargs): record_metrics(name, op, kwargs) + + def __get_hf_api(self): + if self.__hf_api is None: + 
self.__hf_api = HfApi() + return self.__hf_api + + @log_async(logger=logger) + async def get_llm_spec( + self, + model_id: str, + model_format: Literal["pytorch", "ggmlv3", "ggufv2", "gptq", "awq"], + model_hub: str, + ) -> HubImportLLMFamilyV1: + llm_family = HubImportLLMFamilyV1(version=1) + if model_hub == "huggingface": + api = self.__get_hf_api() + + model_info: ModelInfo = await asyncio.wrap_future( + api.run_as_future(api.model_info, model_id) + ) + logger.info(f"Model info: {model_info}") + + if await asyncio.wrap_future( + api.run_as_future(api.file_exists, model_id, "config.json") + ): + config_path = await asyncio.wrap_future( + api.run_as_future(api.hf_hub_download, model_id, "config.json") + ) + with open(config_path) as f: + config = json.load(f) + if "max_position_embeddings" in config: + llm_family.context_length = config["max_position_embeddings"] + + if model_format in ["pytorch", "gptq", "awq"]: + pass + elif model_format in ["ggmlv3", "ggufv2"]: + llm_spec = GgmlLLMSpecV1() + llm_family.model_specs.append(llm_spec) + llm_spec.model_id = model_id + llm_spec.model_format = model_format + llm_spec.model_hub = model_hub + llm_spec.model_size_in_billions = get_model_size_from_model_id(model_id) + + else: + raise ValueError(f"Unsupported model format: {model_format}") + + elif model_hub == "modelscope": + pass + else: + raise ValueError(f"Unsupported model hub: {model_hub}") + + return llm_family diff --git a/xinference/core/tests/test_utils.py b/xinference/core/tests/test_utils.py index ce94e50a35..f4adbaf629 100644 --- a/xinference/core/tests/test_utils.py +++ b/xinference/core/tests/test_utils.py @@ -13,7 +13,12 @@ # limitations under the License. from ..utils import ( + SUPPORTED_QUANTIZATIONS, build_replica_model_uid, + get_llama_cpp_quantization_info, + get_match_quantization_filenames, + get_model_size_from_model_id, + get_prefix_suffix, iter_replica_model_uid, parse_replica_model_uid, ) @@ -29,3 +34,198 @@ def test_replica_model_uid(): all_gen_ids.append(replica_model_uid) assert len(all_gen_ids) == 5 assert len(set(all_gen_ids)) == 5 + + +def test_get_model_size_from_model_id(): + model_id = "froggeric/WestLake-10.7B-v2-GGUF" + model_size = get_model_size_from_model_id(model_id) + assert model_size == "10.7B" + + model_id = "m-a-p/OpenCodeInterpreter-DS-33B" + model_size = get_model_size_from_model_id(model_id) + assert model_size == "33B" + + model_id = "MBZUAI/MobiLlama-05B" + model_size = get_model_size_from_model_id(model_id) + assert model_size == "0.5B" + + model_id = "ibivibiv/alpaca-dragon-72b-v1" + model_size = get_model_size_from_model_id(model_id) + assert model_size == "72B" + + model_id = "ISTA-DASLab/Mixtral-8x7B-Instruct-v0_1-AQLM-2Bit-1x16-hf" + model_size = get_model_size_from_model_id(model_id) + assert model_size == "7B" + + model_id = "internlm/internlm-xcomposer2-vl-7b-4bit" + model_size = get_model_size_from_model_id(model_id) + assert model_size == "7B" + + model_id = "ahxt/LiteLlama-460M-1T" + model_size = get_model_size_from_model_id(model_id) + assert model_size == "0.46B" + + model_id = "Dracones/Midnight-Miqu-70B-v1.0_exl2_2.24bpw" + model_size = get_model_size_from_model_id(model_id) + assert model_size == "70B" + + model_id = "MaziyarPanahi/MixTAO-7Bx2-MoE-v8.1-GGUF" + model_size = get_model_size_from_model_id(model_id) + assert model_size == "7B" + + model_id = "ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf" + model_size = get_model_size_from_model_id(model_id) + assert model_size == "7B" + + model_id = "stabilityai/stablelm-2-zephyr-1_6b" 
+ model_size = get_model_size_from_model_id(model_id) + assert model_size == "1.6B" + + model_id = "Qwen/Qwen1.5-Chat-4bit-GPTQ-72B" + model_size = get_model_size_from_model_id(model_id) + assert model_size == "72B" + + model_id = "m-a-p/OpenCodeInterpreter-3Bee-DS-33B" + model_size = get_model_size_from_model_id(model_id) + assert model_size == "33B" + + model_id = "mlx-community/c4ai-command-r-v01-4bit" + model_size = get_model_size_from_model_id(model_id) + assert model_size == "UNKNOWN" + + model_id = "lemonilia/ShoriRP-v0.75d" + model_size = get_model_size_from_model_id(model_id) + assert model_size == "UNKNOWN" + + model_id = "abc" + try: + get_model_size_from_model_id(model_id) + assert False + except ValueError: + pass + + +def test_get_match_quantization_filenames(): + filenames = [ + "kafkalm-70b-german-v0.1.Q2_K.gguf", + "kafkalm-70b-german-v0.1.Q3_K_L.gguf", + "kafkalm-70b-german-v0.1.Q3_K_M.gguf", + "kafkalm-70b-german-v0.1.Q3_K_S.gguf", + "kafkalm-70b-german-v0.1.Q4_0.gguf", + "kafkalm-70b-german-v0.1.Q4_K_M.gguf", + "kafkalm-70b-german-v0.1.Q4_K_S.gguf", + "kafkalm-70b-german-v0.1.Q5_K_M.gguf", + "kafkalm-70b-german-v0.1.Q5_K_S.gguf", + "kafkalm-70b-german-v0.1.Q6_K.gguf-split-a", + "kafkalm-70b-german-v0.1.Q6_K.gguf-split-b", + "kafkalm-70b-german-v0.1.Q8_0.gguf-split-a", + "kafkalm-70b-german-v0.1.Q8_0.gguf-split-b", + ] + + results = get_match_quantization_filenames(filenames) + assert len(results) == 13 + assert all(x[0][: x[2]] == "kafkalm-70b-german-v0.1." for x in results) + assert all(x[1].upper() in SUPPORTED_QUANTIZATIONS for x in results) + assert results[0][0][results[0][2] + len(results[0][1]) :] == ".gguf" + assert results[-1][0][results[-1][2] + len(results[-1][1]) :] == ".gguf-split-b" + + +def test_get_prefix_suffix(): + names = [ + ".gguf-split-a", + ".gguf-split-b", + ".gguf-split-a", + ".gguf-split-b", + ".gguf-split-c", + ] + prefix, suffix = get_prefix_suffix(names) + assert prefix == ".gguf-split-" + assert suffix == "" + + names = ["-part-a.gguf", "-part-b.gguf", "-part-c.gguf", "-part-a.gguf"] + + prefix, suffix = get_prefix_suffix(names) + assert prefix == "-part-" + assert suffix == ".gguf" + + names = ["-part-1.gguf", "-part-2.gguf", "-part-12.gguf", "-part-2.gguf"] + + prefix, suffix = get_prefix_suffix(names) + assert prefix == "-part-" + assert suffix == ".gguf" + + names = [".gguf", "-part-1.gguf", "-part-2.gguf", "-part-12.gguf", "-part-2.gguf"] + prefix, suffix = get_prefix_suffix(names) + assert prefix == "" + assert suffix == ".gguf" + + names = [ + "-test.gguf", + "-test-part-1.gguf", + "-test-part-2.gguf", + "-test-part-12.gguf", + "-test-part-2.gguf", + ] + prefix, suffix = get_prefix_suffix(names) + assert prefix == "-test" + assert suffix == ".gguf" + + names = ["-part-1.gguf", "-part-1.gguf", "-part-1.gguf"] + prefix, suffix = get_prefix_suffix(names) + assert prefix == "-part-1.gguf" + assert suffix == "" + + prefix, suffix = get_prefix_suffix([]) + assert prefix == "" + assert suffix == "" + + names = ["-only-1.gguf"] + prefix, suffix = get_prefix_suffix(names) + assert prefix == "-only-1.gguf" + assert suffix == "" + + +def test_get_llama_cpp_quantization_info(): + filenames = [ + "kafkalm-70b-german-v0.1.Q2_K.gguf", + "kafkalm-70b-german-v0.1.Q3_K_L.gguf", + "kafkalm-70b-german-v0.1.Q3_K_M.gguf", + "kafkalm-70b-german-v0.1.Q3_K_S.gguf", + "kafkalm-70b-german-v0.1.Q4_0.gguf", + "kafkalm-70b-german-v0.1.Q4_K_M.gguf", + "kafkalm-70b-german-v0.1.Q4_K_S.gguf", + "kafkalm-70b-german-v0.1.Q5_K_M.gguf", + 
"kafkalm-70b-german-v0.1.Q5_K_S.gguf", + "kafkalm-70b-german-v0.1.Q6_K.gguf-split-a", + "kafkalm-70b-german-v0.1.Q6_K.gguf-split-b", + "kafkalm-70b-german-v0.1.Q8_0.gguf-split-a", + "kafkalm-70b-german-v0.1.Q8_0.gguf-split-b", + ] + + tpl1, tpl2 = get_llama_cpp_quantization_info(filenames[:-4], "ggufv2") + assert tpl1 == "kafkalm-70b-german-v0.1.{quantization}.gguf" + assert tpl2 is None + + tpl1, tpl2 = get_llama_cpp_quantization_info(filenames, "ggufv2") + assert tpl1 == "kafkalm-70b-german-v0.1.{quantization}.gguf" + assert tpl2 == "kafkalm-70b-german-v0.1.{quantization}.gguf-split-{part}" + + filenames = [ + "kafkalm-70b-german-v0.1.Q2_K.test.gguf", + "kafkalm-70b-german-v0.1.Q3_K_L.test.gguf", + "kafkalm-70b-german-v0.1.Q3_K_M.test.gguf", + "kafkalm-70b-german-v0.1.Q3_K_S.test.gguf", + "kafkalm-70b-german-v0.1.Q4_0.test.gguf", + "kafkalm-70b-german-v0.1.Q4_K_M.test.gguf", + "kafkalm-70b-german-v0.1.Q4_K_S.test.gguf", + "kafkalm-70b-german-v0.1.Q5_K_M.test.gguf", + "kafkalm-70b-german-v0.1.Q5_K_S.test.gguf", + "kafkalm-70b-german-v0.1.Q6_K.test-split-a.gguf", + "kafkalm-70b-german-v0.1.Q6_K.test-split-b.gguf", + "kafkalm-70b-german-v0.1.Q8_0.test-split-a.gguf", + "kafkalm-70b-german-v0.1.Q8_0.test-split-b.gguf", + ] + + tpl1, tpl2 = get_llama_cpp_quantization_info(filenames, "ggufv2") + assert tpl1 == "kafkalm-70b-german-v0.1.{quantization}.test.gguf" + assert tpl2 == "kafkalm-70b-german-v0.1.{quantization}.test-split-{part}.gguf" diff --git a/xinference/core/utils.py b/xinference/core/utils.py index 0a121f4769..483296ca0c 100644 --- a/xinference/core/utils.py +++ b/xinference/core/utils.py @@ -15,11 +15,13 @@ import logging import os import random +import re import string -from typing import Dict, Generator, List, Tuple, Union +from typing import Dict, Generator, Iterable, List, Optional, Tuple, Union import orjson from pynvml import nvmlDeviceGetCount, nvmlInit, nvmlShutdown +from typing_extensions import Literal from .._compat import BaseModel @@ -191,3 +193,180 @@ def get_nvidia_gpu_info() -> Dict: nvmlShutdown() except: pass + + +def get_model_size_from_model_id(model_id: str) -> str: + """ + Get model size from model_id. + + Args: + model_id: model_id in format of `user/repo` + + Returns: + model size in format of `100B`, if size is in M, divide into 1000 and return as B. + For example, `100M` will be returned as `0.1B`. + + If there is no model size in the repo name, return `UNKNOWN`. + """ + + def resize_to_billion(size: str) -> str: + if size.lower().endswith("m"): + return str(round(int(size[:-1]) / 1000, 2)).rstrip("0") + "B" + if size[0] == "0": + size = size[0] + "." 
+ str(size[1:]) + return size.replace("_", ".").upper() + + split = model_id.split("/") + if len(split) != 2: + raise ValueError(f"Cannot parse model_id: {model_id}") + user, repo = split + segs = repo.split("-") + param_pattern = re.compile(r"\d+(?:[._]\d+)?[bm]", re.I) + partial_matched = "UNKNOWN" + for seg in segs: + if m := param_pattern.search(seg): + if m.start() == 0 and m.end() == len(seg): + return resize_to_billion(seg) + else: + # only match the first partial matched, and do not match `bit` for quantization mode + if ( + partial_matched == "UNKNOWN" + and seg[m.end(0) : m.end(0) + 2].lower() != "it" + ): + partial_matched = m.group(0) + return resize_to_billion(partial_matched) + + +SUPPORTED_QUANTIZATIONS = [ + "Q3_K_S", + "Q3_K_M", + "Q3_K_L", + "Q4_K_S", + "Q4_K_M", + "Q5_K_S", + "Q5_K_M", + "Q6_K", + "F32", + "F16", + "Q4_0", + "Q4_1", + "Q8_0", + "Q5_0", + "Q5_1", + "Q2_K", +] + + +def get_match_quantization_filenames( + filenames: List[str], +) -> List[Tuple[str, str, int]]: + results: List[Tuple[str, str, int]] = [] + for filename in filenames: + for quantization in SUPPORTED_QUANTIZATIONS: + if (index := filename.upper().find(quantization)) != -1: + results.append((filename, quantization, index)) + return results + + +def get_prefix_suffix(names: Iterable[str]) -> Tuple[str, str]: + if len(list(names)) == 0: + return "", "" + + # if all names are the same, or only one name, return the first name as prefix and suffix is empty + if len(set(names)) == 1: + return list(names)[0], "" + + min_len = min(map(len, names)) + name = [n for n in names if len(n) == min_len][0] + + for i in range(min_len): + if len(set(map(lambda x: x[: i + 1], names))) > 1: + prefix = name[:i] + break + else: + prefix = name + + for i in range(min_len): + if len(set(map(lambda x: x[-i - 1 :], names))) > 1: + suffix = name[len(name) - i :] + break + else: + suffix = name + + return prefix, suffix + + +def get_llama_cpp_quantization_info( + filenames: List[str], model_type: Literal["ggmlv3", "ggufv2"] +) -> Tuple[Optional[str], Optional[str]]: + model_file_name_template = None + model_file_name_split_template: Optional[str] = None + if model_type == "ggmlv3": + filenames = [ + filename + for filename in filenames + if filename.lower().endswith(".bin") or "ggml" in filename.lower() + ] + elif model_type == "ggufv2": + filenames = [filename for filename in filenames if ".gguf" in filename] + else: + raise ValueError(f"Unsupported model type: {model_type}") + + matched = get_match_quantization_filenames(filenames) + + if len(matched) == 0: + raise ValueError("Cannot find any quantization files in this") + + prefixes = set() + suffixes = set() + + for filename, quantization, index in matched: + prefixes.add(filename[:index]) + suffixes.add(filename[index + len(quantization) :]) + + if len(prefixes) == 1 and len(suffixes) == 1: + model_file_name_template = prefixes.pop() + "{quantization}" + suffixes.pop() + + elif len(prefixes) == 1 and len(suffixes) > 1: + shortest_suffix = min(suffixes, key=len) + part_prefix, part_suffix = get_prefix_suffix(suffixes) + if shortest_suffix == part_prefix + part_suffix: + model_file_name_template = ( + list(prefixes)[0] + "{quantization}" + shortest_suffix + ) + part_prefix, part_suffix = get_prefix_suffix( + [suffix for suffix in suffixes if suffix != shortest_suffix] + ) + model_file_name_split_template = ( + prefixes.pop() + "{quantization}" + part_prefix + "{part}" + part_suffix + ) + else: + model_file_name_split_template = ( + prefixes.pop() + "{quantization}" + 
part_prefix + "{part}" + part_suffix + ) + + elif len(prefixes) > 1 and len(suffixes) == 1: + shortest_prefix = min(prefixes, key=len) + part_prefix, part_suffix = get_prefix_suffix(prefixes) + if shortest_prefix == part_prefix + part_suffix: + model_file_name_template = ( + shortest_prefix + "{quantization}" + list(suffixes)[0] + ) + part_prefix, part_suffix = get_prefix_suffix( + [prefix for prefix in prefixes if prefix != shortest_prefix] + ) + model_file_name_split_template = ( + part_prefix + + "{quantization}" + + shortest_prefix + + "{part}" + + part_suffix + ) + else: + model_file_name_split_template = ( + prefixes.pop() + "{quantization}" + part_prefix + "{part}" + part_suffix + ) + else: + logger.info("Cannot find a valid template for model file names") + + return model_file_name_template, model_file_name_split_template diff --git a/xinference/model/llm/llm_family.py b/xinference/model/llm/llm_family.py index 15ff0db84c..e2065fe1e4 100644 --- a/xinference/model/llm/llm_family.py +++ b/xinference/model/llm/llm_family.py @@ -122,6 +122,13 @@ class LLMFamilyV1(BaseModel): prompt_style: Optional["PromptStyleV1"] +class HubImportLLMFamilyV1(BaseModel): + version: Literal[1] + context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH + model_specs: List["LLMSpecV1"] + prompt_style: Optional["PromptStyleV1"] + + class CustomLLMFamilyV1(LLMFamilyV1): prompt_style: Optional[Union["PromptStyleV1", str]] # type: ignore @@ -208,6 +215,7 @@ def parse_raw( ] LLMFamilyV1.update_forward_refs() +HubImportLLMFamilyV1.update_forward_refs() CustomLLMFamilyV1.update_forward_refs() diff --git a/xinference/web/ui/src/scenes/register_model/register_language.js b/xinference/web/ui/src/scenes/register_model/register_language.js index eed7280340..606fcecee2 100644 --- a/xinference/web/ui/src/scenes/register_model/register_language.js +++ b/xinference/web/ui/src/scenes/register_model/register_language.js @@ -4,8 +4,11 @@ import { FormControl, FormControlLabel, FormHelperText, + InputLabel, + MenuItem, Radio, RadioGroup, + Select, } from '@mui/material' import Alert from '@mui/material/Alert' import AlertTitle from '@mui/material/AlertTitle' @@ -21,6 +24,12 @@ import { useMode } from '../../theme' const SUPPORTED_LANGUAGES_DICT = { en: 'English', zh: 'Chinese' } const SUPPORTED_FEATURES = ['Generate', 'Chat'] +const SUPPORTED_HUBS_DICT = { huggingface: 'HuggingFace', modelscope: 'ModelScope' } +const SUPPORTED_HUBS = Object.keys(SUPPORTED_HUBS_DICT) + +const SOURCES_DICT = { self_hosted: 'Self Hosted', hub: 'Hub' } +const SOURCES = Object.keys(SOURCES_DICT) + // Convert dictionary of supported languages into list const SUPPORTED_LANGUAGES = Object.keys(SUPPORTED_LANGUAGES_DICT) @@ -32,7 +41,10 @@ const RegisterLanguageModel = () => { const [modelFormat, setModelFormat] = useState('pytorch') const [modelSize, setModelSize] = useState(7) const [modelUri, setModelUri] = useState('/path/to/llama-2') + const [modelId, setModelId] = useState('') const [quantization, setQuantization] = useState('') + const [modelSource, setModelSource] = useState(SOURCES[0]) + const [hub, setHub] = useState(SUPPORTED_HUBS[0]) const [formData, setFormData] = useState({ version: 1, context_length: 2048, @@ -94,7 +106,7 @@ const RegisterLanguageModel = () => { setErrorMsg( `Server error: ${response.status} - ${ errorData.detail || 'Unknown error' - }` + }`, ) } else { const data = await response.json() @@ -116,7 +128,7 @@ const RegisterLanguageModel = () => { setErrorMsg( `Server error: ${response.status} - ${ errorData.detail || 
'Unknown error' - }` + }`, ) } else { const data = await response.json() @@ -134,7 +146,7 @@ const RegisterLanguageModel = () => { getBuiltInPromptStyles().catch((error) => { setErrorMsg( error.message || - 'An unexpected error occurred when getting builtin prompt styles.' + 'An unexpected error occurred when getting builtin prompt styles.', ) console.error('Error: ', error) }) @@ -143,14 +155,14 @@ const RegisterLanguageModel = () => { getBuiltinFamilies().catch((error) => { setErrorMsg( error.message || - 'An unexpected error occurred when getting builtin prompt styles.' + 'An unexpected error occurred when getting builtin prompt styles.', ) console.error('Error: ', error) }) } }, [cookie.token]) - const getFamilyByAbility = () => { + const getFamilyByAbility = () => { if (formData.model_ability.includes('chat')) { return family.chat } else { @@ -174,7 +186,7 @@ const RegisterLanguageModel = () => { const normalizedPath = path.replace(/\\/g, '/') const baseDir = normalizedPath.substring(0, normalizedPath.lastIndexOf('/')) const filename = normalizedPath.substring( - normalizedPath.lastIndexOf('/') + 1 + normalizedPath.lastIndexOf('/') + 1, ) return { baseDir, filename } } @@ -262,11 +274,11 @@ const RegisterLanguageModel = () => { setErrorMsg( `Server error: ${response.status} - ${ errorData.detail || 'Unknown error' - }` + }`, ) } else { setSuccessMsg( - 'Model has been registered successfully! Navigate to launch model page to proceed.' + 'Model has been registered successfully! Navigate to launch model page to proceed.', ) } } catch (error) { @@ -275,6 +287,10 @@ const RegisterLanguageModel = () => { } } + const handleImportModel = async () => { + console.log('import model') + } + const toggleLanguage = (lang) => { if (formData.model_lang.includes(lang)) { setFormData({ @@ -306,285 +322,357 @@ const RegisterLanguageModel = () => { return ( <React.Fragment> - <Box padding="20px"></Box> - {/* Base Information */} - <FormControl sx={styles.baseFormControl}> - <TextField - label="Model Name" - error={errorModelName} - defaultValue={formData.model_name} - size="small" - helperText="Alphanumeric characters with properly placed hyphens and underscores. Must not match any built-in model names." - onChange={(event) => - setFormData({ ...formData, model_name: event.target.value }) - } - /> - <Box padding="15px"></Box> - - <label - style={{ - paddingLeft: 5, - }} - > - Model Format - </label> - - <RadioGroup - value={modelFormat} - onChange={(e) => { - setModelFormat(e.target.value) - }} - > - <Box sx={styles.checkboxWrapper}> - <Box sx={{ marginLeft: '10px' }}> - <FormControlLabel - value="pytorch" - control={<Radio />} - label="PyTorch" - /> - </Box> - <Box sx={{ marginLeft: '10px' }}> - <FormControlLabel - value="ggmlv3" - control={<Radio />} - label="GGML" - /> - </Box> - <Box sx={{ marginLeft: '10px' }}> - <FormControlLabel - value="ggufv2" - control={<Radio />} - label="GGUF" - /> - </Box> - <Box sx={{ marginLeft: '10px' }}> - <FormControlLabel - value="gptq" - control={<Radio />} - label="GPTQ" - /> - </Box> - <Box sx={{ marginLeft: '10px' }}> - <FormControlLabel - value="awq" - control={<Radio />} - label="AWQ" - /> - </Box> + <Box padding="20px"></Box> + {/* Base Information */} + <FormControl sx={styles.baseFormControl}> + <TextField + label="Model Name" + error={errorModelName} + defaultValue={formData.model_name} + size="small" + helperText="Alphanumeric characters with properly placed hyphens and underscores. Must not match any built-in model names." 
+ onChange={(event) => + setFormData({ ...formData, model_name: event.target.value }) + } + /> + <Box padding="15px"></Box> + + <label + style={{ + paddingLeft: 5, + }} + > + Model Format + </label> + + <RadioGroup + value={modelFormat} + onChange={(e) => { + setModelFormat(e.target.value) + }} + > + <Box sx={styles.checkboxWrapper}> + <Box sx={{ marginLeft: '10px' }}> + <FormControlLabel + value="pytorch" + control={<Radio />} + label="PyTorch" + /> + </Box> + <Box sx={{ marginLeft: '10px' }}> + <FormControlLabel + value="ggmlv3" + control={<Radio />} + label="GGML" + /> + </Box> + <Box sx={{ marginLeft: '10px' }}> + <FormControlLabel + value="ggufv2" + control={<Radio />} + label="GGUF" + /> + </Box> + <Box sx={{ marginLeft: '10px' }}> + <FormControlLabel + value="gptq" + control={<Radio />} + label="GPTQ" + /> + </Box> + <Box sx={{ marginLeft: '10px' }}> + <FormControlLabel + value="awq" + control={<Radio />} + label="AWQ" + /> + </Box> + </Box> + </RadioGroup> + <Box padding="15px"></Box> + + <label + style={{ + paddingLeft: 5, + }} + > + Model Source + </label> + + <RadioGroup + value={modelSource} + onChange={(e) => { + setModelSource(e.target.value) + }} + > + <Box sx={styles.checkboxWrapper}> + {SOURCES.map((item) => ( + <Box sx={{ marginLeft: '10px' }}> + <FormControlLabel + value={item} + control={<Radio />} + label={SOURCES_DICT[item]} + /> </Box> - </RadioGroup> - <Box padding="15px"></Box> - - <TextField - error={errorContextLength} - label="Context Length" - value={formData.context_length} - size="small" - onChange={(event) => { - let value = event.target.value - // Remove leading zeros - if (/^0+/.test(value)) { - value = value.replace(/^0+/, '') || '0' - } - // Ensure it's a positive integer, if not set it to the minimum - if (!/^\d+$/.test(value) || parseInt(value) < 0) { - value = '0' - } - // Update with the processed value - setFormData({ - ...formData, - context_length: Number(value), - }) - }} - /> - <Box padding="15px"></Box> - - <TextField - label="Model Size in Billions" - size="small" - error={errorModelSize} - value={modelSize} - onChange={(e) => { - let value = e.target.value - // Remove leading zeros - if (/^0+/.test(value)) { - value = value.replace(/^0+/, '') || '0' - } - // Ensure it's a positive integer, if not set it to the minimum - if (!/^\d+$/.test(value) || parseInt(value) < 0) { - value = '0' - } - setModelSize(Number(value)) - }} - /> - <Box padding="15px"></Box> + ))} + </Box> + </RadioGroup> + <Box padding="15px"></Box> + + {modelSource === 'self_hosted' && + <TextField + label="Model Path" + size="small" + value={modelUri} + onChange={(e) => { + setModelUri(e.target.value) + }} + helperText="For PyTorch, provide the model directory. For GGML/GGUF, provide the model file path." + />} + {modelSource === 'hub' && + <Box sx={styles.checkboxWrapper}> - <TextField - label="Model Path" - size="small" - value={modelUri} - onChange={(e) => { - setModelUri(e.target.value) - }} - helperText="For PyTorch, provide the model directory. For GGML/GGUF, provide the model file path." - /> - <Box padding="15px"></Box> <TextField - label="Quantization (Optional)" + sx={{ width: '400px' }} + label="Model Id" size="small" - value={quantization} + value={modelId} onChange={(e) => { - setQuantization(e.target.value) + setModelId(e.target.value) }} - helperText="For GPTQ/AWQ models, please be careful to fill in the quantization corresponding to the model you want to register." 
+ placeholder="user/repo" /> - <Box padding="15px"></Box> - <TextField - label="Model Description (Optional)" - error={errorModelDescription} - defaultValue={formData.model_description} - size="small" - onChange={(event) => - setFormData({ - ...formData, - model_description: event.target.value, - }) - } - /> - <Box padding="15px"></Box> - - <label - style={{ - paddingLeft: 5, - color: errorLanguage ? ERROR_COLOR : 'inherit', - }} + <FormControl variant="standard" + sx={{ marginLeft: '10px' }}> + <InputLabel id="hub-label">Hub</InputLabel> + <Select + labelId="hub-label" + value={hub} + label="Hub" + onChange={(e) => { + setHub(e.target.value) + }} + > + {SUPPORTED_HUBS.map((item) => ( + <MenuItem value={item}>{SUPPORTED_HUBS_DICT[item]}</MenuItem> + ))} + </Select> + </FormControl> + <Button + sx={{ marginLeft: '10px' }} + variant="contained" + color="primary" + onClick={handleImportModel} > - Model Languages - </label> - <Box sx={styles.checkboxWrapper}> - {SUPPORTED_LANGUAGES.map((lang) => ( - <Box key={lang} sx={{ marginRight: '10px' }}> - <FormControlLabel - control={ - <Checkbox - checked={formData.model_lang.includes(lang)} - onChange={() => toggleLanguage(lang)} - name={lang} - sx={ - errorLanguage - ? { - 'color': ERROR_COLOR, - '&.Mui-checked': { - color: ERROR_COLOR, - }, - } - : {} + Import Model + </Button> + </Box> + } + <Box padding="15px"></Box> + + + <TextField + error={errorContextLength} + label="Context Length" + value={formData.context_length} + size="small" + onChange={(event) => { + let value = event.target.value + // Remove leading zeros + if (/^0+/.test(value)) { + value = value.replace(/^0+/, '') || '0' + } + // Ensure it's a positive integer, if not set it to the minimum + if (!/^\d+$/.test(value) || parseInt(value) < 0) { + value = '0' + } + // Update with the processed value + setFormData({ + ...formData, + context_length: Number(value), + }) + }} + /> + <Box padding="15px"></Box> + + <TextField + label="Model Size in Billions" + size="small" + error={errorModelSize} + value={modelSize} + onChange={(e) => { + let value = e.target.value + // Remove leading zeros + if (/^0+/.test(value)) { + value = value.replace(/^0+/, '') || '0' + } + // Ensure it's a positive integer, if not set it to the minimum + if (!/^\d+$/.test(value) || parseInt(value) < 0) { + value = '0' + } + setModelSize(Number(value)) + }} + /> + <Box padding="15px"></Box> + + + <TextField + label="Quantization (Optional)" + size="small" + value={quantization} + onChange={(e) => { + setQuantization(e.target.value) + }} + helperText="For GPTQ/AWQ models, please be careful to fill in the quantization corresponding to the model you want to register." + /> + <Box padding="15px"></Box> + + <TextField + label="Model Description (Optional)" + error={errorModelDescription} + defaultValue={formData.model_description} + size="small" + onChange={(event) => + setFormData({ + ...formData, + model_description: event.target.value, + }) + } + /> + <Box padding="15px"></Box> + + <label + style={{ + paddingLeft: 5, + color: errorLanguage ? ERROR_COLOR : 'inherit', + }} + > + Model Languages + </label> + <Box sx={styles.checkboxWrapper}> + {SUPPORTED_LANGUAGES.map((lang) => ( + <Box key={lang} sx={{ marginRight: '10px' }}> + <FormControlLabel + control={ + <Checkbox + checked={formData.model_lang.includes(lang)} + onChange={() => toggleLanguage(lang)} + name={lang} + sx={ + errorLanguage + ? 
{ + 'color': ERROR_COLOR, + '&.Mui-checked': { + color: ERROR_COLOR, + }, } - /> + : {} } - label={SUPPORTED_LANGUAGES_DICT[lang]} - style={{ - paddingLeft: 10, - color: errorLanguage ? ERROR_COLOR : 'inherit', - }} /> - </Box> - ))} + } + label={SUPPORTED_LANGUAGES_DICT[lang]} + style={{ + paddingLeft: 10, + color: errorLanguage ? ERROR_COLOR : 'inherit', + }} + /> </Box> - <Box padding="15px"></Box> - - <label - style={{ - paddingLeft: 5, - color: errorAbility ? ERROR_COLOR : 'inherit', - }} - > - Model Abilities - </label> - <Box sx={styles.checkboxWrapper}> - {SUPPORTED_FEATURES.map((ability) => ( - <Box key={ability} sx={{ marginRight: '10px' }}> - <FormControlLabel - control={ - <Checkbox - checked={formData.model_ability.includes( - ability.toLowerCase() - )} - onChange={() => toggleAbility(ability.toLowerCase())} - name={ability} - sx={ - errorAbility - ? { - 'color': ERROR_COLOR, - '&.Mui-checked': { - color: ERROR_COLOR, - }, - } - : {} + ))} + </Box> + <Box padding="15px"></Box> + + <label + style={{ + paddingLeft: 5, + color: errorAbility ? ERROR_COLOR : 'inherit', + }} + > + Model Abilities + </label> + <Box sx={styles.checkboxWrapper}> + {SUPPORTED_FEATURES.map((ability) => ( + <Box key={ability} sx={{ marginRight: '10px' }}> + <FormControlLabel + control={ + <Checkbox + checked={formData.model_ability.includes( + ability.toLowerCase(), + )} + onChange={() => toggleAbility(ability.toLowerCase())} + name={ability} + sx={ + errorAbility + ? { + 'color': ERROR_COLOR, + '&.Mui-checked': { + color: ERROR_COLOR, + }, } - /> + : {} } - label={ability} - style={{ - paddingLeft: 10, - color: errorAbility ? ERROR_COLOR : 'inherit', - }} /> - </Box> - ))} + } + label={ability} + style={{ + paddingLeft: 10, + color: errorAbility ? ERROR_COLOR : 'inherit', + }} + /> </Box> - <Box padding="15px"></Box> - </FormControl> - - <FormControl sx={styles.baseFormControl}> - <label - style={{ - paddingLeft: 5, - color: errorAbility ? ERROR_COLOR : 'inherit', - }} - > - Model Family - </label> - <FormHelperText> - Please be careful to select the family name corresponding to the - model you want to register. If not found, please choose `other`. - </FormHelperText> - <RadioGroup - value={familyLabel} - onChange={(e) => { - setFamilyLabel(e.target.value) - }} - > - <Box sx={styles.checkboxWrapper}> - {getFamilyByAbility().map((v) => ( - <Box sx={{ marginLeft: '10px' }}> - <FormControlLabel value={v} control={<Radio />} label={v} /> - </Box> - ))} + ))} + </Box> + <Box padding="15px"></Box> + </FormControl> + + <FormControl sx={styles.baseFormControl}> + <label + style={{ + paddingLeft: 5, + color: errorAbility ? ERROR_COLOR : 'inherit', + }} + > + Model Family + </label> + <FormHelperText> + Please be careful to select the family name corresponding to the + model you want to register. If not found, please choose `other`. 
+ </FormHelperText> + <RadioGroup + value={familyLabel} + onChange={(e) => { + setFamilyLabel(e.target.value) + }} + > + <Box sx={styles.checkboxWrapper}> + {getFamilyByAbility().map((v) => ( + <Box sx={{ marginLeft: '10px' }}> + <FormControlLabel value={v} control={<Radio />} label={v} /> </Box> - </RadioGroup> - <Box padding="15px"></Box> - </FormControl> - - <Box width={'100%'}> - {successMsg !== '' && ( - <Alert severity="success"> - <AlertTitle>Success</AlertTitle> - {successMsg} - </Alert> - )} - <Button - variant="contained" - color="primary" - type="submit" - onClick={handleClick} - > - Register Model - </Button> + ))} </Box> - </React.Fragment> + </RadioGroup> + <Box padding="15px"></Box> + </FormControl> + + <Box width={'100%'}> + {successMsg !== '' && ( + <Alert severity="success"> + <AlertTitle>Success</AlertTitle> + {successMsg} + </Alert> + )} + <Button + variant="contained" + color="primary" + type="submit" + onClick={handleClick} + > + Register Model + </Button> + </Box> + </React.Fragment> ) } From cc903e1e5f27e14da7866102d5e434c589b34063 Mon Sep 17 00:00:00 2001 From: Shi Hui <shihui@hyron.com> Date: Fri, 15 Mar 2024 23:12:53 +0800 Subject: [PATCH 03/21] add more unit tests for getting model info. --- xinference/core/tests/test_utils.py | 21 +++++++++++++++++++++ xinference/core/utils.py | 6 ++++++ 2 files changed, 27 insertions(+) diff --git a/xinference/core/tests/test_utils.py b/xinference/core/tests/test_utils.py index f4adbaf629..d03e143389 100644 --- a/xinference/core/tests/test_utils.py +++ b/xinference/core/tests/test_utils.py @@ -229,3 +229,24 @@ def test_get_llama_cpp_quantization_info(): tpl1, tpl2 = get_llama_cpp_quantization_info(filenames, "ggufv2") assert tpl1 == "kafkalm-70b-german-v0.1.{quantization}.test.gguf" assert tpl2 == "kafkalm-70b-german-v0.1.{quantization}.test-split-{part}.gguf" + + filenames = [ + "kafkalm-70b-german-v0.1.Q2_K.test.gguf", + "kafkalm-70b-german-v0.1.Q3_K_L.test.gguf", + "kafkalm-70b-german-v0.1.Q3_K_M.test.gguf", + "kafkalm-70b-german-v0.1.Q3_K_S.test.gguf", + "kafkalm-70b-german-v0.1.Q4_0.test.gguf", + "kafkalm-70b-german-v0.1.Q4_K_M.test.gguf", + "kafkalm-70b-german-v0.1.Q4_K_S.test.gguf", + "kafkalm-70b-german-v0.1.Q5_K_M.test.gguf", + "kafkalm-70b-german-v0.1.Q5_K_S.test.gguf", + "kafkalm-70b-german-v0.1.Q6_K.test.gguf-part1of2", + "kafkalm-70b-german-v0.1.Q6_K.test.gguf-part2of2", + "kafkalm-70b-german-v0.1.Q8_0.test.gguf-part1of3", + "kafkalm-70b-german-v0.1.Q8_0.test.gguf-part2of3", + "kafkalm-70b-german-v0.1.Q8_0.test.gguf-part3of3", + ] + + tpl1, tpl2 = get_llama_cpp_quantization_info(filenames, "ggufv2") + assert tpl1 == "kafkalm-70b-german-v0.1.{quantization}.test.gguf" + assert tpl2 == "kafkalm-70b-german-v0.1.{quantization}.test.gguf-part{part}" diff --git a/xinference/core/utils.py b/xinference/core/utils.py index 483296ca0c..a2a4aa7e95 100644 --- a/xinference/core/utils.py +++ b/xinference/core/utils.py @@ -299,6 +299,12 @@ def get_prefix_suffix(names: Iterable[str]) -> Tuple[str, str]: def get_llama_cpp_quantization_info( filenames: List[str], model_type: Literal["ggmlv3", "ggufv2"] ) -> Tuple[Optional[str], Optional[str]]: + """ + Get the model file name template and split template from a list of filenames. + + NOTE: not support multiple quantization files in multi-part zip files. 
+ for example: a-16b.ggmlv3.zip a-16b.ggmlv3.z01 a-16b.ggmlv3.z02 are not supported + """ model_file_name_template = None model_file_name_split_template: Optional[str] = None if model_type == "ggmlv3": From c108eb05249f2d1d7ca94e5f4186136b378f6db7 Mon Sep 17 00:00:00 2001 From: Shi Hui <shihui@hyron.com> Date: Sun, 17 Mar 2024 17:44:01 +0800 Subject: [PATCH 04/21] fix the get model info correctly. --- xinference/core/supervisor.py | 18 ++++- xinference/core/tests/test_utils.py | 120 +++++++++++++++++++++++----- xinference/core/utils.py | 111 ++++++++++++++++++------- 3 files changed, 199 insertions(+), 50 deletions(-) diff --git a/xinference/core/supervisor.py b/xinference/core/supervisor.py index 7c8e4739ec..9d39d1c2b5 100644 --- a/xinference/core/supervisor.py +++ b/xinference/core/supervisor.py @@ -41,6 +41,7 @@ from .utils import ( build_replica_model_uid, gen_random_string, + get_llama_cpp_quantization_info, get_model_size_from_model_id, is_valid_model_uid, iter_replica_model_uid, @@ -662,8 +663,8 @@ async def launch_speculative_llm( model_uid = self._gen_model_uid(model_name) logger.debug( ( - f"Enter launch_speculative_llm, model_uid: %s, model_name: %s, model_size: %s, " - f"draft_model_name: %s, draft_model_size: %s" + "Enter launch_speculative_llm, model_uid: %s, model_name: %s, model_size: %s, " + "draft_model_name: %s, draft_model_size: %s" ), model_uid, model_name, @@ -1023,6 +1024,19 @@ async def get_llm_spec( llm_spec.model_hub = model_hub llm_spec.model_size_in_billions = get_model_size_from_model_id(model_id) + filenames = await asyncio.wrap_future( + api.run_as_future(api.list_repo_files, model_id) + ) + + ( + llm_spec.model_file_name_template, + llm_spec.model_file_name_split_template, + llm_spec.quantizations, + llm_spec.quantization_parts, + ) = get_llama_cpp_quantization_info( + filenames, typing.cast(Literal["ggmlv3", "ggufv2"], model_format) + ) + else: raise ValueError(f"Unsupported model format: {model_format}") diff --git a/xinference/core/tests/test_utils.py b/xinference/core/tests/test_utils.py index d03e143389..696de7c5b6 100644 --- a/xinference/core/tests/test_utils.py +++ b/xinference/core/tests/test_utils.py @@ -39,55 +39,55 @@ def test_replica_model_uid(): def test_get_model_size_from_model_id(): model_id = "froggeric/WestLake-10.7B-v2-GGUF" model_size = get_model_size_from_model_id(model_id) - assert model_size == "10.7B" + assert model_size == 10.7 model_id = "m-a-p/OpenCodeInterpreter-DS-33B" model_size = get_model_size_from_model_id(model_id) - assert model_size == "33B" + assert model_size == 33 model_id = "MBZUAI/MobiLlama-05B" model_size = get_model_size_from_model_id(model_id) - assert model_size == "0.5B" + assert model_size == 0.5 model_id = "ibivibiv/alpaca-dragon-72b-v1" model_size = get_model_size_from_model_id(model_id) - assert model_size == "72B" + assert model_size == 72 model_id = "ISTA-DASLab/Mixtral-8x7B-Instruct-v0_1-AQLM-2Bit-1x16-hf" model_size = get_model_size_from_model_id(model_id) - assert model_size == "7B" + assert model_size == 7 model_id = "internlm/internlm-xcomposer2-vl-7b-4bit" model_size = get_model_size_from_model_id(model_id) - assert model_size == "7B" + assert model_size == 7 model_id = "ahxt/LiteLlama-460M-1T" model_size = get_model_size_from_model_id(model_id) - assert model_size == "0.46B" + assert model_size == 0.46 model_id = "Dracones/Midnight-Miqu-70B-v1.0_exl2_2.24bpw" model_size = get_model_size_from_model_id(model_id) - assert model_size == "70B" + assert model_size == 70 model_id = 
"MaziyarPanahi/MixTAO-7Bx2-MoE-v8.1-GGUF" model_size = get_model_size_from_model_id(model_id) - assert model_size == "7B" + assert model_size == 7 model_id = "ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf" model_size = get_model_size_from_model_id(model_id) - assert model_size == "7B" + assert model_size == 7 model_id = "stabilityai/stablelm-2-zephyr-1_6b" model_size = get_model_size_from_model_id(model_id) - assert model_size == "1.6B" + assert model_size == "1_6" model_id = "Qwen/Qwen1.5-Chat-4bit-GPTQ-72B" model_size = get_model_size_from_model_id(model_id) - assert model_size == "72B" + assert model_size == 72 model_id = "m-a-p/OpenCodeInterpreter-3Bee-DS-33B" model_size = get_model_size_from_model_id(model_id) - assert model_size == "33B" + assert model_size == 33 model_id = "mlx-community/c4ai-command-r-v01-4bit" model_size = get_model_size_from_model_id(model_id) @@ -202,13 +202,45 @@ def test_get_llama_cpp_quantization_info(): "kafkalm-70b-german-v0.1.Q8_0.gguf-split-b", ] - tpl1, tpl2 = get_llama_cpp_quantization_info(filenames[:-4], "ggufv2") + tpl1, tpl2, qs, parts = get_llama_cpp_quantization_info(filenames[:-4], "ggufv2") assert tpl1 == "kafkalm-70b-german-v0.1.{quantization}.gguf" assert tpl2 is None - - tpl1, tpl2 = get_llama_cpp_quantization_info(filenames, "ggufv2") + assert len(qs) == 9 + assert { + "Q2_K", + "Q3_K_L", + "Q3_K_M", + "Q3_K_S", + "Q4_0", + "Q4_K_M", + "Q4_K_S", + "Q5_K_M", + "Q5_K_S", + }.intersection(set(qs)) == set(qs) + assert parts is None + + tpl1, tpl2, qs, parts = get_llama_cpp_quantization_info(filenames, "ggufv2") assert tpl1 == "kafkalm-70b-german-v0.1.{quantization}.gguf" assert tpl2 == "kafkalm-70b-german-v0.1.{quantization}.gguf-split-{part}" + assert len(qs) == 11 + assert { + "Q2_K", + "Q3_K_L", + "Q3_K_M", + "Q3_K_S", + "Q4_0", + "Q4_K_M", + "Q4_K_S", + "Q5_K_M", + "Q5_K_S", + "Q6_K", + "Q8_0", + }.intersection(set(qs)) == set(qs) + assert len(parts) == 2 + assert len(parts["Q6_K"]) == 2 + assert len(parts["Q8_0"]) == 2 + assert parts["Q6_K"][0] == "a" + assert parts["Q8_0"][1] == "b" filenames = [ "kafkalm-70b-german-v0.1.Q2_K.test.gguf", @@ -226,9 +258,25 @@ def test_get_llama_cpp_quantization_info(): "kafkalm-70b-german-v0.1.Q8_0.test-split-b.gguf", ] - tpl1, tpl2 = get_llama_cpp_quantization_info(filenames, "ggufv2") + tpl1, tpl2, qs, parts = get_llama_cpp_quantization_info(filenames, "ggufv2") assert tpl1 == "kafkalm-70b-german-v0.1.{quantization}.test.gguf" assert tpl2 == "kafkalm-70b-german-v0.1.{quantization}.test-split-{part}.gguf" + assert len(qs) == 11 + assert len(parts) == 2 + assert { + "Q2_K", + "Q3_K_L", + "Q3_K_M", + "Q3_K_S", + "Q4_0", + "Q4_K_M", + "Q4_K_S", + "Q5_K_M", + "Q5_K_S", + "Q6_K", + "Q8_0", + }.intersection(set(qs)) == set(qs) + assert parts["Q8_0"][1] == "b" filenames = [ "kafkalm-70b-german-v0.1.Q2_K.test.gguf", @@ -247,6 +295,42 @@ def test_get_llama_cpp_quantization_info(): "kafkalm-70b-german-v0.1.Q8_0.test.gguf-part3of3", ] - tpl1, tpl2 = get_llama_cpp_quantization_info(filenames, "ggufv2") + tpl1, tpl2, qs, parts = get_llama_cpp_quantization_info(filenames, "ggufv2") assert tpl1 == "kafkalm-70b-german-v0.1.{quantization}.test.gguf" assert tpl2 == "kafkalm-70b-german-v0.1.{quantization}.test.gguf-part{part}" + assert len(qs) == 11 + assert { + "Q2_K", + "Q3_K_L", + "Q3_K_M", + "Q3_K_S", + "Q4_0", + "Q4_K_M", + "Q4_K_S", + "Q5_K_M", + "Q5_K_S", + "Q6_K", + "Q8_0", + }.intersection(set(qs)) == set(qs) + assert len(parts) == 2 + assert len(parts["Q8_0"]) == 3 + assert parts["Q8_0"][2] == "3of3" + + filenames = [ 
+ "llama-2-7b-chat.ggmlv3.q2_K.bin", + "llama-2-7b-chat.ggmlv3.q3_K_L.bin", + "llama-2-7b-chat.ggmlv3.q3_K_M.bin", + "llama-2-7b-chat.ggmlv3.q3_K_S.bin", + "llama-2-7b-chat.ggmlv3.q4_0.bin", + "llama-2-7b-chat.ggmlv3.q4_K_M.bin", + "llama-2-7b-chat.ggmlv3.q4_K_S.bin", + "llama-2-7b-chat.ggmlv3.q5_K_M.bin", + "llama-2-7b-chat.ggmlv3.q5_K_S.bin", + ] + + tpl1, tpl2, qs, parts = get_llama_cpp_quantization_info(filenames, "ggmlv3") + + assert tpl1 == "llama-2-7b-chat.ggmlv3.{quantization}.bin" + assert tpl2 is None + assert len(qs) == 9 + assert parts is None diff --git a/xinference/core/utils.py b/xinference/core/utils.py index a2a4aa7e95..35b9e4c668 100644 --- a/xinference/core/utils.py +++ b/xinference/core/utils.py @@ -17,7 +17,7 @@ import random import re import string -from typing import Dict, Generator, Iterable, List, Optional, Tuple, Union +from typing import Dict, Generator, Iterable, List, Optional, Tuple, Union, cast import orjson from pynvml import nvmlDeviceGetCount, nvmlInit, nvmlShutdown @@ -195,7 +195,7 @@ def get_nvidia_gpu_info() -> Dict: pass -def get_model_size_from_model_id(model_id: str) -> str: +def get_model_size_from_model_id(model_id: str) -> Union[str, float, int]: """ Get model size from model_id. @@ -209,12 +209,24 @@ def get_model_size_from_model_id(model_id: str) -> str: If there is no model size in the repo name, return `UNKNOWN`. """ - def resize_to_billion(size: str) -> str: + def resize_to_billion(size: str) -> Union[str, int, float]: + if size == "UNKNOWN": + return size + if size.lower().endswith("m"): - return str(round(int(size[:-1]) / 1000, 2)).rstrip("0") + "B" - if size[0] == "0": - size = size[0] + "." + str(size[1:]) - return size.replace("_", ".").upper() + return round(int(size[:-1]) / 1000, 2) + + size = size[:-1] + if "_" not in size: + if size[0] == "0": + size = size[0] + "." + str(size[1:]) + + if "." in size: + return float(size) + else: + return int(size) + + return size split = model_id.split("/") if len(split) != 2: @@ -260,6 +272,12 @@ def resize_to_billion(size: str) -> str: def get_match_quantization_filenames( filenames: List[str], ) -> List[Tuple[str, str, int]]: + """ + Get the quantization info from filenames. + + Return: + A list of tuples: (filename, quantization, index of the quantization in filename) + """ results: List[Tuple[str, str, int]] = [] for filename in filenames: for quantization in SUPPORTED_QUANTIZATIONS: @@ -269,6 +287,9 @@ def get_match_quantization_filenames( def get_prefix_suffix(names: Iterable[str]) -> Tuple[str, str]: + """ + Get the common prefix and suffix from a list of names. + """ if len(list(names)) == 0: return "", "" @@ -298,15 +319,24 @@ def get_prefix_suffix(names: Iterable[str]) -> Tuple[str, str]: def get_llama_cpp_quantization_info( filenames: List[str], model_type: Literal["ggmlv3", "ggufv2"] -) -> Tuple[Optional[str], Optional[str]]: +) -> Tuple[Optional[str], Optional[str], List[str], Optional[Dict[str, List[str]]]]: """ Get the model file name template and split template from a list of filenames. NOTE: not support multiple quantization files in multi-part zip files. 
for example: a-16b.ggmlv3.zip a-16b.ggmlv3.z01 a-16b.ggmlv3.z02 are not supported + + Return: + model_file_name_template: the model file name with quantization info + model_file_name_split_template: the model file name with quantization info and part index + quantizations: the quantization info + parts: the quantization part index """ model_file_name_template = None model_file_name_split_template: Optional[str] = None + quantizations: List[str] = [] + parts: Optional[Dict[str, List[str]]] = None + if model_type == "ggmlv3": filenames = [ filename @@ -314,7 +344,7 @@ def get_llama_cpp_quantization_info( if filename.lower().endswith(".bin") or "ggml" in filename.lower() ] elif model_type == "ggufv2": - filenames = [filename for filename in filenames if ".gguf" in filename] + filenames = [filename for filename in filenames if ".gguf" in filename.lower()] else: raise ValueError(f"Unsupported model type: {model_type}") @@ -329,11 +359,20 @@ def get_llama_cpp_quantization_info( for filename, quantization, index in matched: prefixes.add(filename[:index]) suffixes.add(filename[index + len(quantization) :]) + if quantization not in quantizations: + quantizations.append(quantization) if len(prefixes) == 1 and len(suffixes) == 1: model_file_name_template = prefixes.pop() + "{quantization}" + suffixes.pop() - - elif len(prefixes) == 1 and len(suffixes) > 1: + return ( + model_file_name_template, + model_file_name_split_template, + quantizations, + parts, + ) + + if len(prefixes) == 1 and len(suffixes) > 1: + parts = {} shortest_suffix = min(suffixes, key=len) part_prefix, part_suffix = get_prefix_suffix(suffixes) if shortest_suffix == part_prefix + part_suffix: @@ -343,15 +382,13 @@ def get_llama_cpp_quantization_info( part_prefix, part_suffix = get_prefix_suffix( [suffix for suffix in suffixes if suffix != shortest_suffix] ) - model_file_name_split_template = ( - prefixes.pop() + "{quantization}" + part_prefix + "{part}" + part_suffix - ) - else: - model_file_name_split_template = ( - prefixes.pop() + "{quantization}" + part_prefix + "{part}" + part_suffix - ) + + model_file_name_split_template = ( + prefixes.pop() + "{quantization}" + part_prefix + "{part}" + part_suffix + ) elif len(prefixes) > 1 and len(suffixes) == 1: + parts = {} shortest_prefix = min(prefixes, key=len) part_prefix, part_suffix = get_prefix_suffix(prefixes) if shortest_prefix == part_prefix + part_suffix: @@ -361,18 +398,32 @@ def get_llama_cpp_quantization_info( part_prefix, part_suffix = get_prefix_suffix( [prefix for prefix in prefixes if prefix != shortest_prefix] ) - model_file_name_split_template = ( - part_prefix - + "{quantization}" - + shortest_prefix - + "{part}" - + part_suffix - ) - else: - model_file_name_split_template = ( - prefixes.pop() + "{quantization}" + part_prefix + "{part}" + part_suffix - ) + + model_file_name_split_template = ( + part_prefix + "{part}" + part_suffix + "{quantization}" + suffixes.pop() + ) else: logger.info("Cannot find a valid template for model file names") - return model_file_name_template, model_file_name_split_template + if model_file_name_split_template is not None: + part_pattern_str = model_file_name_split_template.replace( + "{part}", r"(?P<part>\w+)" + ) + quan_pattern_str = "(?P<quantization>" + f"{'|'.join(quantizations)})" + part_pattern_str = part_pattern_str.replace("{quantization}", quan_pattern_str) + + part_pattern = re.compile(part_pattern_str) + for filename in filenames: + if m := part_pattern.match(filename): + matched_quan = m.group("quantization") + parts = 
cast(Dict[str, List[str]], parts) + if matched_quan not in parts: + parts[matched_quan] = [] + parts[matched_quan].append(m.group("part")) + + return ( + model_file_name_template, + model_file_name_split_template, + quantizations, + parts, + ) From c24a459608d9246ab96e85b9841e92dc11dbc1d9 Mon Sep 17 00:00:00 2001 From: Shi Hui <shihui@hyron.com> Date: Sun, 17 Mar 2024 21:30:48 +0800 Subject: [PATCH 05/21] fix the wrong llm family build logic. --- xinference/core/supervisor.py | 41 ++++++++++++++---------- xinference/core/tests/test_supervisor.py | 13 ++++++++ xinference/model/llm/llm_family.py | 2 +- 3 files changed, 38 insertions(+), 18 deletions(-) create mode 100644 xinference/core/tests/test_supervisor.py diff --git a/xinference/core/supervisor.py b/xinference/core/supervisor.py index 9d39d1c2b5..2082e32c1d 100644 --- a/xinference/core/supervisor.py +++ b/xinference/core/supervisor.py @@ -35,7 +35,7 @@ from ..core import ModelActor from ..core.status_guard import InstanceInfo, LaunchStatus from ..model.llm import GgmlLLMSpecV1 -from ..model.llm.llm_family import HubImportLLMFamilyV1 +from ..model.llm.llm_family import DEFAULT_CONTEXT_LENGTH, HubImportLLMFamilyV1 from .metrics import record_metrics from .resource import GPUStatus, ResourceStatus from .utils import ( @@ -994,15 +994,14 @@ async def get_llm_spec( model_format: Literal["pytorch", "ggmlv3", "ggufv2", "gptq", "awq"], model_hub: str, ) -> HubImportLLMFamilyV1: - llm_family = HubImportLLMFamilyV1(version=1) if model_hub == "huggingface": api = self.__get_hf_api() - model_info: ModelInfo = await asyncio.wrap_future( api.run_as_future(api.model_info, model_id) ) logger.info(f"Model info: {model_info}") + context_length = DEFAULT_CONTEXT_LENGTH if await asyncio.wrap_future( api.run_as_future(api.file_exists, model_id, "config.json") ): @@ -1012,37 +1011,45 @@ async def get_llm_spec( with open(config_path) as f: config = json.load(f) if "max_position_embeddings" in config: - llm_family.context_length = config["max_position_embeddings"] + context_length = config["max_position_embeddings"] if model_format in ["pytorch", "gptq", "awq"]: pass elif model_format in ["ggmlv3", "ggufv2"]: - llm_spec = GgmlLLMSpecV1() - llm_family.model_specs.append(llm_spec) - llm_spec.model_id = model_id - llm_spec.model_format = model_format - llm_spec.model_hub = model_hub - llm_spec.model_size_in_billions = get_model_size_from_model_id(model_id) - filenames = await asyncio.wrap_future( api.run_as_future(api.list_repo_files, model_id) ) ( - llm_spec.model_file_name_template, - llm_spec.model_file_name_split_template, - llm_spec.quantizations, - llm_spec.quantization_parts, + model_file_name_template, + model_file_name_split_template, + quantizations, + quantization_parts, ) = get_llama_cpp_quantization_info( filenames, typing.cast(Literal["ggmlv3", "ggufv2"], model_format) ) + llm_spec = GgmlLLMSpecV1( + model_id=model_id, + model_format=model_format, + model_hub=model_hub, + quantizations=quantizations, + quantization_parts=quantization_parts, + model_size_in_billions=get_model_size_from_model_id(model_id), + model_file_name_template=model_file_name_template, + model_file_name_split_template=model_file_name_split_template, + ) + + return HubImportLLMFamilyV1( + version=1, context_length=context_length, model_specs=[llm_spec] + ) + else: raise ValueError(f"Unsupported model format: {model_format}") elif model_hub == "modelscope": - pass + raise NotImplementedError("modelscope not implemented") else: raise ValueError(f"Unsupported model hub: 
{model_hub}") - return llm_family + raise NotImplementedError() diff --git a/xinference/core/tests/test_supervisor.py b/xinference/core/tests/test_supervisor.py new file mode 100644 index 0000000000..3625ad58f3 --- /dev/null +++ b/xinference/core/tests/test_supervisor.py @@ -0,0 +1,13 @@ +import pytest + +from ..supervisor import SupervisorActor + + +@pytest.mark.asyncio +async def test_get_llm_spec(): + supervisor = SupervisorActor() + llm_family = await supervisor.get_llm_spec( + "TheBloke/Llama-2-7B-Chat-GGML", "ggmlv3", "huggingface" + ) + assert llm_family is not None + assert len(llm_family.model_specs) == 1 diff --git a/xinference/model/llm/llm_family.py b/xinference/model/llm/llm_family.py index e2065fe1e4..d0e744fea0 100644 --- a/xinference/model/llm/llm_family.py +++ b/xinference/model/llm/llm_family.py @@ -125,7 +125,7 @@ class LLMFamilyV1(BaseModel): class HubImportLLMFamilyV1(BaseModel): version: Literal[1] context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH - model_specs: List["LLMSpecV1"] + model_specs: List["LLMSpecV1"] = [] prompt_style: Optional["PromptStyleV1"] From 4433f8510a6e6f746b6a8a5f931bf25add44fb4e Mon Sep 17 00:00:00 2001 From: Shi Hui <shihui@hyron.com> Date: Mon, 18 Mar 2024 10:37:48 +0800 Subject: [PATCH 06/21] fix the bug that cannot process quantization with lower case. --- xinference/core/tests/test_supervisor.py | 67 ++++++++++++++++++++++++ xinference/core/utils.py | 5 +- 2 files changed, 70 insertions(+), 2 deletions(-) diff --git a/xinference/core/tests/test_supervisor.py b/xinference/core/tests/test_supervisor.py index 3625ad58f3..8c98fb241c 100644 --- a/xinference/core/tests/test_supervisor.py +++ b/xinference/core/tests/test_supervisor.py @@ -11,3 +11,70 @@ async def test_get_llm_spec(): ) assert llm_family is not None assert len(llm_family.model_specs) == 1 + assert llm_family.model_specs[0].model_id == "TheBloke/Llama-2-7B-Chat-GGML" + assert llm_family.model_specs[0].model_size_in_billions == 7 + assert llm_family.model_specs[0].model_hub == "huggingface" + assert len(llm_family.model_specs[0].quantizations) == 14 + assert ( + llm_family.model_specs[0].model_file_name_template + == "llama-2-7b-chat.ggmlv3.{quantization}.bin" + ) + assert llm_family.model_specs[0].model_file_name_split_template is None + assert llm_family.model_specs[0].quantization_parts is None + + assert { + "q2_K", + "q3_K_L", + "q3_K_M", + "q3_K_S", + "q4_0", + "q4_1", + "q4_K_M", + "q4_K_S", + "q5_0", + "q5_1", + "q5_K_M", + "q5_K_S", + "q6_K", + "q8_0", + }.intersection(set(llm_family.model_specs[0].quantizations)) == set( + llm_family.model_specs[0].quantizations + ) + + llm_family = await supervisor.get_llm_spec( + "TheBloke/KafkaLM-70B-German-V0.1-GGUF", "ggufv2", "huggingface" + ) + assert llm_family is not None + assert len(llm_family.model_specs) == 1 + assert llm_family.model_specs[0].model_id == "TheBloke/KafkaLM-70B-German-V0.1-GGUF" + assert llm_family.model_specs[0].model_size_in_billions == 70 + assert llm_family.model_specs[0].model_hub == "huggingface" + qs = llm_family.model_specs[0].quantizations + assert len(qs) == 12 + assert ( + llm_family.model_specs[0].model_file_name_template + == "kafkalm-70b-german-v0.1.{quantization}.gguf" + ) + assert ( + llm_family.model_specs[0].model_file_name_split_template + == "kafkalm-70b-german-v0.1.{quantization}.gguf-split-{part}" + ) + parts = llm_family.model_specs[0].quantization_parts + assert parts is not None + assert len(parts) == 2 + assert len(parts["Q8_0"]) == 2 + + assert { + "Q2_K", + "Q3_K_L", + "Q3_K_M", 
+ "Q3_K_S", + "Q4_0", + "Q4_K_M", + "Q4_K_S", + "Q5_0", + "Q5_K_M", + "Q5_K_S", + "Q6_K", + "Q8_0", + }.intersection(set(qs)) == set(qs) diff --git a/xinference/core/utils.py b/xinference/core/utils.py index 35b9e4c668..546a61138a 100644 --- a/xinference/core/utils.py +++ b/xinference/core/utils.py @@ -359,8 +359,9 @@ def get_llama_cpp_quantization_info( for filename, quantization, index in matched: prefixes.add(filename[:index]) suffixes.add(filename[index + len(quantization) :]) - if quantization not in quantizations: - quantizations.append(quantization) + q = filename[index : index + len(quantization)] + if q not in quantizations: + quantizations.append(q) if len(prefixes) == 1 and len(suffixes) == 1: model_file_name_template = prefixes.pop() + "{quantization}" + suffixes.pop() From 53ca7581522667a0ea9cc3ff6dec14ee185a47c1 Mon Sep 17 00:00:00 2001 From: Shi Hui <shihui@hyron.com> Date: Mon, 18 Mar 2024 12:28:08 +0800 Subject: [PATCH 07/21] add support for GGUF, GGML of ModelScope --- xinference/core/supervisor.py | 90 +++++++++++++++++++---- xinference/core/tests/test_supervisor.py | 91 +++++++++++++++++++++++- 2 files changed, 167 insertions(+), 14 deletions(-) diff --git a/xinference/core/supervisor.py b/xinference/core/supervisor.py index 2082e32c1d..1118694d03 100644 --- a/xinference/core/supervisor.py +++ b/xinference/core/supervisor.py @@ -23,7 +23,10 @@ import xoscar as xo from huggingface_hub import HfApi -from huggingface_hub.hf_api import ModelInfo +from modelscope import HubApi +from modelscope.hub.errors import NotExistError +from modelscope.hub.file_download import model_file_download +from requests import HTTPError from typing_extensions import Literal from ..constants import ( @@ -86,6 +89,7 @@ class SupervisorActor(xo.StatelessActor): def __init__(self): super().__init__() self.__hf_api: Optional[HfApi] = None + self.__ms_api: Optional[HubApi] = None self._worker_address_to_worker: Dict[str, xo.ActorRefType["WorkerActor"]] = {} self._worker_status: Dict[str, WorkerStatus] = {} self._replica_model_uid_to_worker: Dict[ @@ -987,6 +991,11 @@ def __get_hf_api(self): self.__hf_api = HfApi() return self.__hf_api + def __get_ms_api(self): + if self.__ms_api is None: + self.__ms_api = HubApi() + return self.__ms_api + @log_async(logger=logger) async def get_llm_spec( self, @@ -994,19 +1003,23 @@ async def get_llm_spec( model_format: Literal["pytorch", "ggmlv3", "ggufv2", "gptq", "awq"], model_hub: str, ) -> HubImportLLMFamilyV1: + hf_api = self.__get_hf_api() + context_length = DEFAULT_CONTEXT_LENGTH + if model_hub == "huggingface": - api = self.__get_hf_api() - model_info: ModelInfo = await asyncio.wrap_future( - api.run_as_future(api.model_info, model_id) + repo_exists = await asyncio.wrap_future( + hf_api.run_as_future(hf_api.repo_exists, model_id) ) - logger.info(f"Model info: {model_info}") + if not repo_exists: + raise ValueError(f"Model {model_id} does not exist") - context_length = DEFAULT_CONTEXT_LENGTH if await asyncio.wrap_future( - api.run_as_future(api.file_exists, model_id, "config.json") + hf_api.run_as_future(hf_api.file_exists, model_id, "config.json") ): config_path = await asyncio.wrap_future( - api.run_as_future(api.hf_hub_download, model_id, "config.json") + hf_api.run_as_future( + hf_api.hf_hub_download, model_id, "config.json" + ) ) with open(config_path) as f: config = json.load(f) @@ -1014,10 +1027,10 @@ async def get_llm_spec( context_length = config["max_position_embeddings"] if model_format in ["pytorch", "gptq", "awq"]: - pass + raise 
NotImplementedError("pytorch, gptq and awq not implemented yet") elif model_format in ["ggmlv3", "ggufv2"]: filenames = await asyncio.wrap_future( - api.run_as_future(api.list_repo_files, model_id) + hf_api.run_as_future(hf_api.list_repo_files, model_id) ) ( @@ -1048,8 +1061,59 @@ async def get_llm_spec( raise ValueError(f"Unsupported model format: {model_format}") elif model_hub == "modelscope": - raise NotImplementedError("modelscope not implemented") + ms_api = self.__get_ms_api() + try: + await asyncio.wrap_future( + hf_api.run_as_future(ms_api.get_model, model_id) + ) + except NotExistError: + raise ValueError(f"Model {model_id} does not exist") + except HTTPError: + raise ValueError(f"Model {model_id} does not exist") + + try: + config_path = await asyncio.wrap_future( + hf_api.run_as_future(model_file_download, model_id, "config.json") + ) + if config_path is not None: + with open(config_path) as f: + config = json.load(f) + if "max_position_embeddings" in config: + context_length = config["max_position_embeddings"] + except NotExistError: + logger.warning(f"Model {model_id} does not have config file") + + if model_format in ["pytorch", "gptq", "awq"]: + raise NotImplementedError("pytorch, gptq and awq not implemented yet") + elif model_format in ["ggmlv3", "ggufv2"]: + file_infos = await asyncio.wrap_future( + hf_api.run_as_future(ms_api.get_model_files, model_id) + ) + filenames = [info["Path"] for info in file_infos] + ( + model_file_name_template, + model_file_name_split_template, + quantizations, + quantization_parts, + ) = get_llama_cpp_quantization_info( + filenames, typing.cast(Literal["ggmlv3", "ggufv2"], model_format) + ) + + llm_spec = GgmlLLMSpecV1( + model_id=model_id, + model_format=model_format, + model_hub=model_hub, + quantizations=quantizations, + quantization_parts=quantization_parts, + model_size_in_billions=get_model_size_from_model_id(model_id), + model_file_name_template=model_file_name_template, + model_file_name_split_template=model_file_name_split_template, + ) + + return HubImportLLMFamilyV1( + version=1, context_length=context_length, model_specs=[llm_spec] + ) + else: + raise ValueError(f"Unsupported model format: {model_format}") else: raise ValueError(f"Unsupported model hub: {model_hub}") - - raise NotImplementedError() diff --git a/xinference/core/tests/test_supervisor.py b/xinference/core/tests/test_supervisor.py index 8c98fb241c..f26976d10b 100644 --- a/xinference/core/tests/test_supervisor.py +++ b/xinference/core/tests/test_supervisor.py @@ -4,7 +4,7 @@ @pytest.mark.asyncio -async def test_get_llm_spec(): +async def test_get_llm_spec_hf(): supervisor = SupervisorActor() llm_family = await supervisor.get_llm_spec( "TheBloke/Llama-2-7B-Chat-GGML", "ggmlv3", "huggingface" @@ -78,3 +78,92 @@ async def test_get_llm_spec(): "Q6_K", "Q8_0", }.intersection(set(qs)) == set(qs) + + try: + llm_family = await supervisor.get_llm_spec( + "Nobody/No_This_Repo", "ggufv2", "huggingface" + ) + assert False + except ValueError as e: + assert str(e) == "Model Nobody/No_This_Repo does not exist" + + +@pytest.mark.asyncio +async def test_get_llm_spec_ms(): + supervisor = SupervisorActor() + llm_family = await supervisor.get_llm_spec( + "Xorbits/Llama-2-7B-Chat-GGML", "ggmlv3", "modelscope" + ) + assert llm_family is not None + assert len(llm_family.model_specs) == 1 + assert llm_family.model_specs[0].model_id == "Xorbits/Llama-2-7B-Chat-GGML" + assert llm_family.model_specs[0].model_size_in_billions == 7 + assert llm_family.model_specs[0].model_hub == "modelscope" + 
assert len(llm_family.model_specs[0].quantizations) == 14 + assert ( + llm_family.model_specs[0].model_file_name_template + == "llama-2-7b-chat.ggmlv3.{quantization}.bin" + ) + assert llm_family.model_specs[0].model_file_name_split_template is None + assert llm_family.model_specs[0].quantization_parts is None + + assert { + "q2_K", + "q3_K_L", + "q3_K_M", + "q3_K_S", + "q4_0", + "q4_1", + "q4_K_M", + "q4_K_S", + "q5_0", + "q5_1", + "q5_K_M", + "q5_K_S", + "q6_K", + "q8_0", + }.intersection(set(llm_family.model_specs[0].quantizations)) == set( + llm_family.model_specs[0].quantizations + ) + + llm_family = await supervisor.get_llm_spec( + "qwen/Qwen1.5-72B-Chat-GGUF", "ggufv2", "modelscope" + ) + assert llm_family is not None + assert len(llm_family.model_specs) == 1 + assert llm_family.model_specs[0].model_id == "qwen/Qwen1.5-72B-Chat-GGUF" + assert llm_family.model_specs[0].model_size_in_billions == 72 + assert llm_family.model_specs[0].model_hub == "modelscope" + qs = llm_family.model_specs[0].quantizations + assert len(qs) == 8 + assert ( + llm_family.model_specs[0].model_file_name_template + == "qwen1_5-72b-chat-{quantization}.gguf" + ) + assert ( + llm_family.model_specs[0].model_file_name_split_template + == "qwen1_5-72b-chat-{quantization}.gguf.{part}" + ) + parts = llm_family.model_specs[0].quantization_parts + assert parts is not None + assert len(parts) == 6 + assert len(parts["q8_0"]) == 3 + + assert { + "q2_k", + "q3_k_m", + "q4_0", + "q4_k_m", + "q5_0", + "q5_k_m", + "q6_k", + "q8_0", + }.intersection(set(qs)) == set(qs) + + try: + llm_family = await supervisor.get_llm_spec( + "Nobody/No_This_Repo", "ggufv2", "modelscope" + ) + assert False + except ValueError as e: + assert str(e) == "Model Nobody/No_This_Repo does not exist" From 5dc38a150588f3e5c225e4b9f8130a7dac95d8f0 Mon Sep 17 00:00:00 2001 From: Shi Hui <shihui@hyron.com> Date: Mon, 18 Mar 2024 15:32:32 +0800 Subject: [PATCH 08/21] add async runner to run function async, and add model hub utility to get model hub infos. --- xinference/model/llm/tests/test_utils.py | 159 ++++++++++++++++++++++- xinference/model/llm/utils.py | 95 ++++++++++++++ xinference/tests/__init__.py | 13 ++ xinference/tests/test_utils.py | 43 ++++++ xinference/utils.py | 24 ++++ 5 files changed, 333 insertions(+), 1 deletion(-) create mode 100644 xinference/tests/__init__.py create mode 100644 xinference/tests/test_utils.py diff --git a/xinference/model/llm/tests/test_utils.py b/xinference/model/llm/tests/test_utils.py index d5e40d7561..b3e37224db 100644 --- a/xinference/model/llm/tests/test_utils.py +++ b/xinference/model/llm/tests/test_utils.py @@ -11,10 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import os + +import pytest from ....types import ChatCompletionMessage from ..llm_family import PromptStyleV1 -from ..utils import ChatModelMixin +from ..utils import ChatModelMixin, ModelHubUtil def test_prompt_style_add_colon_single(): @@ -421,3 +424,157 @@ def test_is_valid_model_name(): assert not is_valid_model_name("foo/bar") assert not is_valid_model_name(" ") assert not is_valid_model_name("") + + +@pytest.fixture +def model_hub_util(): + return ModelHubUtil() + + +def test__hf_api(model_hub_util): + assert model_hub_util._hf_api is not None + + +def test__ms_api(model_hub_util): + assert model_hub_util._ms_api is not None + + +def test_repo_exists(model_hub_util): + assert model_hub_util.repo_exists( + "TheBloke/KafkaLM-70B-German-V0.1-GGUF", "huggingface" + ) + assert not model_hub_util.repo_exists("Nobody/No_This_Repo", "huggingface") + try: + model_hub_util.repo_exists("Nobody/No_This_Repo", "unknown_hub") + assert False + except ValueError: + assert True + + assert model_hub_util.repo_exists("qwen/Qwen1.5-72B-Chat-GGUF", "modelscope") + assert not model_hub_util.repo_exists("Nobody/No_This_Repo", "modelscope") + try: + model_hub_util.repo_exists("Nobody/No_This_Repo", "unknown_hub") + assert False + except ValueError: + assert True + + +@pytest.mark.asyncio +async def test_a_repo_exists(model_hub_util): + assert await model_hub_util.a_repo_exists( + "TheBloke/KafkaLM-70B-German-V0.1-GGUF", "huggingface" + ) + assert not await model_hub_util.a_repo_exists("Nobody/No_This_Repo", "huggingface") + try: + model_hub_util.repo_exists("Nobody/No_This_Repo", "unknown_hub") + assert False + except ValueError: + assert True + + assert await model_hub_util.a_repo_exists( + "qwen/Qwen1.5-72B-Chat-GGUF", "modelscope" + ) + assert not await model_hub_util.a_repo_exists("Nobody/No_This_Repo", "modelscope") + try: + await model_hub_util.a_repo_exists("Nobody/No_This_Repo", "unknown_hub") + assert False + except ValueError: + assert True + + +def test_get_config_path(model_hub_util): + p = model_hub_util.get_config_path( + "TheBloke/KafkaLM-70B-German-V0.1-GGUF", "huggingface" + ) + assert p is not None + assert os.path.isfile(p) + + assert model_hub_util.get_config_path("Nobody/No_This_Repo", "huggingface") is None + + p = model_hub_util.get_config_path("qwen/Qwen1.5-72B-Chat-GGUF", "modelscope") + assert p is None + + p = model_hub_util.get_config_path("deepseek-ai/deepseek-vl-7b-chat", "modelscope") + assert p is not None + assert os.path.isfile(p) + + assert model_hub_util.get_config_path("Nobody/No_This_Repo", "modelscope") is None + + +@pytest.mark.asyncio +async def test_a_get_config_path_async(model_hub_util): + p = await model_hub_util.a_get_config_path( + "TheBloke/KafkaLM-70B-German-V0.1-GGUF", "huggingface" + ) + assert p is not None + assert os.path.isfile(p) + + assert ( + await model_hub_util.a_get_config_path("Nobody/No_This_Repo", "huggingface") + is None + ) + + p = await model_hub_util.a_get_config_path( + "qwen/Qwen1.5-72B-Chat-GGUF", "modelscope" + ) + assert p is None + + p = await model_hub_util.a_get_config_path( + "deepseek-ai/deepseek-vl-7b-chat", "modelscope" + ) + assert p is not None + assert os.path.isfile(p) + + assert ( + await model_hub_util.a_get_config_path("Nobody/No_This_Repo", "modelscope") + is None + ) + + +def test_list_repo_files(model_hub_util): + files = model_hub_util.list_repo_files( + "TheBloke/KafkaLM-70B-German-V0.1-GGUF", "huggingface" + ) + assert len(files) == 20 + + files = model_hub_util.list_repo_files( + "deepseek-ai/deepseek-vl-7b-chat", 
"modelscope" + ) + assert len(files) == 12 # the `.gitattributes` file is not included + + try: + model_hub_util.list_repo_files("Nobody/No_This_Repo", "huggingface") + assert False + except ValueError as e: + assert str(e) == "Repository Nobody/No_This_Repo not found." + + try: + model_hub_util.list_repo_files("Nobody/No_This_Repo", "modelscope") + assert False + except ValueError as e: + assert str(e) == "Repository Nobody/No_This_Repo not found." + + +@pytest.mark.asyncio +async def test_a_list_repo_files(model_hub_util): + files = await model_hub_util.a_list_repo_files( + "TheBloke/KafkaLM-70B-German-V0.1-GGUF", "huggingface" + ) + assert len(files) == 20 + + files = await model_hub_util.a_list_repo_files( + "deepseek-ai/deepseek-vl-7b-chat", "modelscope" + ) + assert len(files) == 12 # the `.gitattributes` file is not included + + try: + await model_hub_util.a_list_repo_files("Nobody/No_This_Repo", "huggingface") + assert False + except ValueError as e: + assert str(e) == "Repository Nobody/No_This_Repo not found." + + try: + await model_hub_util.a_list_repo_files("Nobody/No_This_Repo", "modelscope") + assert False + except ValueError as e: + assert str(e) == "Repository Nobody/No_This_Repo not found." diff --git a/xinference/model/llm/utils.py b/xinference/model/llm/utils.py index 9474c8191b..51eb67f441 100644 --- a/xinference/model/llm/utils.py +++ b/xinference/model/llm/utils.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import asyncio.futures import functools import json import logging @@ -19,6 +20,14 @@ import uuid from typing import AsyncGenerator, Dict, Iterator, List, Optional, Tuple, cast +from huggingface_hub import HfApi +from huggingface_hub.utils import RepositoryNotFoundError +from modelscope import HubApi +from modelscope.hub.errors import NotExistError +from modelscope.hub.file_download import model_file_download +from requests import HTTPError +from typing_extensions import Literal + from ...types import ( SPECIAL_TOOL_PROMPT, ChatCompletion, @@ -27,6 +36,7 @@ Completion, CompletionChunk, ) +from ...utils import AsyncRunner from .llm_family import ( GgmlLLMSpecV1, LLMFamilyV1, @@ -655,3 +665,88 @@ def get_model_version( llm_family: LLMFamilyV1, llm_spec: LLMSpecV1, quantization: str ) -> str: return f"{llm_family.model_name}--{llm_spec.model_size_in_billions}B--{llm_spec.model_format}--{quantization}" + + +MODEL_HUB = Literal["huggingface", "modelscope"] + + +class ModelHubUtil(object): + def __init__(self): + self.__hf_api: Optional[HfApi] = None + self.__ms_api: Optional[HubApi] = None + self.__async_runner = AsyncRunner() + + @property + def _hf_api(self) -> HfApi: + if self.__hf_api is None: + self.__hf_api = HfApi() + return self.__hf_api + + @property + def _ms_api(self) -> HubApi: + if self.__ms_api is None: + self.__ms_api = HubApi() + return self.__ms_api + + def repo_exists(self, model_id: str, hub: MODEL_HUB) -> bool: + if hub == "huggingface": + return self._hf_api.repo_exists(model_id) + elif hub == "modelscope": + try: + self._ms_api.get_model(model_id) + return True + except (NotExistError, HTTPError): + return False + else: + raise ValueError("Unsupported model hub") + + async def a_repo_exists( + self, model_id: str, hub: MODEL_HUB + ) -> asyncio.Future[bool]: + return await self.__async_runner.async_run(self.repo_exists, model_id, hub) + + def get_config_path(self, model_id: str, hub: MODEL_HUB) -> 
Optional[str]: + if hub == "huggingface": + try: + return self._hf_api.hf_hub_download(model_id, "config.json") + except (ValueError, HTTPError) as e: + logging.error(e) + return None + elif hub == "modelscope": + try: + return model_file_download(model_id, "config.json") + except (NotExistError, HTTPError) as e: + logging.error(e) + return None + + async def a_get_config_path( + self, model_id: str, hub: MODEL_HUB + ) -> asyncio.Future[Optional[str]]: + return await self.__async_runner.async_run(self.get_config_path, model_id, hub) + + def list_repo_files(self, model_id: str, hub: MODEL_HUB) -> List[str]: + """ + List all files in the model repo. + + Notice: ModelScope does not return the hidden files which start with dot, + however, HuggingFace does. + """ + if hub == "huggingface": + try: + return self._hf_api.list_repo_files(model_id) + except RepositoryNotFoundError: + raise ValueError(f"Repository {model_id} not found.") + elif hub == "modelscope": + try: + return [ + entry["Path"] for entry in self._ms_api.get_model_files(model_id) + ] + except HTTPError: + raise ValueError(f"Repository {model_id} not found.") + else: + raise ValueError("Unsupported model hub") + + async def a_list_repo_files( + self, model_id: str, hub: MODEL_HUB + ) -> asyncio.Future[List[str]]: + return await self.__async_runner.async_run(self.list_repo_files, model_id, hub) diff --git a/xinference/tests/__init__.py b/xinference/tests/__init__.py new file mode 100644 index 0000000000..37f6558d95 --- /dev/null +++ b/xinference/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/xinference/tests/test_utils.py b/xinference/tests/test_utils.py new file mode 100644 index 0000000000..54b6057f96 --- /dev/null +++ b/xinference/tests/test_utils.py @@ -0,0 +1,43 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
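+
+# Unit tests for the AsyncRunner helper added in xinference/utils.py:
+# run_as_future() submits a callable to a lazily created single-thread
+# ThreadPoolExecutor and returns a concurrent.futures.Future, while
+# async_run() wraps that future with asyncio.wrap_future so it can be
+# awaited from coroutine code.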
+from concurrent.futures import Future + +import pytest +from typing_extensions import Coroutine + +from ..utils import AsyncRunner + + +@pytest.fixture +def async_runner(): + return AsyncRunner() + + +def test__thread_pool(async_runner): + assert async_runner._thread_pool is not None + + +def test_run_as_future(async_runner): + future = async_runner.run_as_future(lambda: 1) + assert isinstance(future, Future) + assert future.result() == 1 + + +def test_async_run(async_runner): + assert isinstance(async_runner.async_run(lambda: 1), Coroutine) + + +@pytest.mark.asyncio +async def test_async_run_a(async_runner): + assert await async_runner.async_run(lambda: 1) == 1 diff --git a/xinference/utils.py b/xinference/utils.py index 5b3741c222..dcc2f66b3c 100644 --- a/xinference/utils.py +++ b/xinference/utils.py @@ -11,10 +11,34 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import asyncio +from concurrent.futures import Future +from concurrent.futures.thread import ThreadPoolExecutor import torch +from typing_extensions import Callable, Optional, TypeVar def cuda_count(): # even if install torch cpu, this interface would return 0. return torch.cuda.device_count() + + +R = TypeVar("R") # Return type + + +class AsyncRunner(object): + def __init__(self): + self.__thread_pool: Optional[ThreadPoolExecutor] = None + + @property + def _thread_pool(self): + if self.__thread_pool is None: + self.__thread_pool = ThreadPoolExecutor(max_workers=1) + return self.__thread_pool + + def run_as_future(self, fn: Callable[..., R], *args, **kwargs) -> Future[R]: + return self._thread_pool.submit(fn, *args, **kwargs) + + async def async_run(self, fn: Callable[..., R], *args, **kwargs) -> R: + return await asyncio.wrap_future(self.run_as_future(fn, *args, **kwargs)) From 3991ae610c5af9d1b56c6b7bcadf91030b16c7eb Mon Sep 17 00:00:00 2001 From: Shi Hui <shihui@hyron.com> Date: Mon, 18 Mar 2024 15:46:50 +0800 Subject: [PATCH 09/21] Did the code refactoring to get gguf and ggml model info. 
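
The hub-specific branches in get_llm_spec collapse into a single code path: ModelHubUtil hides the
HuggingFace/ModelScope differences, and its a_* methods run the blocking hub calls on AsyncRunner's
worker thread so the supervisor's event loop is never blocked. For a GGUF repo the resolution flow
now reads roughly like the sketch below (illustrative only; the repo id is the one exercised in the
unit tests, and network access to the hub is assumed):

    import asyncio

    from xinference.core.utils import get_llama_cpp_quantization_info
    from xinference.model.llm.utils import ModelHubUtil

    async def sketch():
        util = ModelHubUtil()
        model_id, hub = "TheBloke/KafkaLM-70B-German-V0.1-GGUF", "huggingface"
        if not await util.a_repo_exists(model_id, hub):
            raise ValueError(f"Model {model_id} does not exist")
        # config.json is optional for GGUF/GGML repos; None falls back to the
        # default context length
        config_path = await util.a_get_config_path(model_id, hub)
        filenames = await util.a_list_repo_files(model_id, hub)
        tpl, split_tpl, quantizations, parts = get_llama_cpp_quantization_info(
            filenames, "ggufv2"
        )
        print(config_path, tpl, split_tpl, quantizations, parts)

    asyncio.run(sketch())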
--- xinference/core/supervisor.py | 174 ++++++++++------------------------ xinference/model/llm/utils.py | 13 +-- 2 files changed, 53 insertions(+), 134 deletions(-) diff --git a/xinference/core/supervisor.py b/xinference/core/supervisor.py index 1118694d03..6a0faccb9f 100644 --- a/xinference/core/supervisor.py +++ b/xinference/core/supervisor.py @@ -22,12 +22,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, Union import xoscar as xo -from huggingface_hub import HfApi -from modelscope import HubApi -from modelscope.hub.errors import NotExistError -from modelscope.hub.file_download import model_file_download -from requests import HTTPError -from typing_extensions import Literal +from typing_extensions import Literal, cast from ..constants import ( XINFERENCE_DISABLE_HEALTH_CHECK, @@ -39,6 +34,7 @@ from ..core.status_guard import InstanceInfo, LaunchStatus from ..model.llm import GgmlLLMSpecV1 from ..model.llm.llm_family import DEFAULT_CONTEXT_LENGTH, HubImportLLMFamilyV1 +from ..model.llm.utils import MODEL_HUB, ModelHubUtil from .metrics import record_metrics from .resource import GPUStatus, ResourceStatus from .utils import ( @@ -88,8 +84,7 @@ class ReplicaInfo: class SupervisorActor(xo.StatelessActor): def __init__(self): super().__init__() - self.__hf_api: Optional[HfApi] = None - self.__ms_api: Optional[HubApi] = None + self._model_hub_util = ModelHubUtil() self._worker_address_to_worker: Dict[str, xo.ActorRefType["WorkerActor"]] = {} self._worker_status: Dict[str, WorkerStatus] = {} self._replica_model_uid_to_worker: Dict[ @@ -986,16 +981,6 @@ async def report_worker_status( def record_metrics(name, op, kwargs): record_metrics(name, op, kwargs) - def __get_hf_api(self): - if self.__hf_api is None: - self.__hf_api = HfApi() - return self.__hf_api - - def __get_ms_api(self): - if self.__ms_api is None: - self.__ms_api = HubApi() - return self.__ms_api - @log_async(logger=logger) async def get_llm_spec( self, @@ -1003,117 +988,58 @@ async def get_llm_spec( model_format: Literal["pytorch", "ggmlv3", "ggufv2", "gptq", "awq"], model_hub: str, ) -> HubImportLLMFamilyV1: - hf_api = self.__get_hf_api() - context_length = DEFAULT_CONTEXT_LENGTH - - if model_hub == "huggingface": - repo_exists = await asyncio.wrap_future( - hf_api.run_as_future(hf_api.repo_exists, model_id) - ) - if not repo_exists: - raise ValueError(f"Model {model_id} does not exist") - - if await asyncio.wrap_future( - hf_api.run_as_future(hf_api.file_exists, model_id, "config.json") - ): - config_path = await asyncio.wrap_future( - hf_api.run_as_future( - hf_api.hf_hub_download, model_id, "config.json" - ) - ) - with open(config_path) as f: - config = json.load(f) - if "max_position_embeddings" in config: - context_length = config["max_position_embeddings"] - - if model_format in ["pytorch", "gptq", "awq"]: - raise NotImplementedError("pytorch, gptq and awq not implemented yet") - elif model_format in ["ggmlv3", "ggufv2"]: - filenames = await asyncio.wrap_future( - hf_api.run_as_future(hf_api.list_repo_files, model_id) - ) - - ( - model_file_name_template, - model_file_name_split_template, - quantizations, - quantization_parts, - ) = get_llama_cpp_quantization_info( - filenames, typing.cast(Literal["ggmlv3", "ggufv2"], model_format) - ) + if model_hub not in ["huggingface", "modelscope"]: + raise ValueError(f"Unsupported model hub: {model_hub}") - llm_spec = GgmlLLMSpecV1( - model_id=model_id, - model_format=model_format, - model_hub=model_hub, - quantizations=quantizations, - 
quantization_parts=quantization_parts, - model_size_in_billions=get_model_size_from_model_id(model_id), - model_file_name_template=model_file_name_template, - model_file_name_split_template=model_file_name_split_template, - ) + model_hub = cast(MODEL_HUB, model_hub) - return HubImportLLMFamilyV1( - version=1, context_length=context_length, model_specs=[llm_spec] - ) + context_length = DEFAULT_CONTEXT_LENGTH - else: - raise ValueError(f"Unsupported model format: {model_format}") + repo_exists = await self._model_hub_util.a_repo_exists( + model_id, + model_hub, + ) + if not repo_exists: + raise ValueError(f"Model {model_id} does not exist") + + if config_path := await self._model_hub_util.a_get_config_path( + model_id, model_hub + ): + with open(config_path) as f: + config = json.load(f) + if "max_position_embeddings" in config: + context_length = config["max_position_embeddings"] + + if model_format in ["pytorch", "gptq", "awq"]: + raise NotImplementedError("pytorch, gptq and awq not implemented yet") + elif model_format in ["ggmlv3", "ggufv2"]: + filenames = await self._model_hub_util.a_list_repo_files( + model_id, model_hub + ) - elif model_hub == "modelscope": - ms_api = self.__get_ms_api() - try: - await asyncio.wrap_future( - hf_api.run_as_future(ms_api.get_model, model_id) - ) - except NotExistError: - raise ValueError(f"Model {model_id} does not exist") - except HTTPError: - raise ValueError(f"Model {model_id} does not exist") + ( + model_file_name_template, + model_file_name_split_template, + quantizations, + quantization_parts, + ) = get_llama_cpp_quantization_info( + filenames, typing.cast(Literal["ggmlv3", "ggufv2"], model_format) + ) - try: - config_path = await asyncio.wrap_future( - hf_api.run_as_future(model_file_download, model_id, "config.json") - ) - if config_path is not None: - with open(config_path) as f: - config = json.load(f) - if "max_position_embeddings" in config: - context_length = config["max_position_embeddings"] - except NotExistError: - logger.warning(f"Model {model_id} does not have config file") - - if model_format in ["pytorch", "gptq", "awq"]: - raise NotImplementedError("pytorch, gptq and awq not implemented yet") - elif model_format in ["ggmlv3", "ggufv2"]: - file_infos = await asyncio.wrap_future( - hf_api.run_as_future(ms_api.get_model_files, model_id) - ) - filenames = [info["Path"] for info in file_infos] - ( - model_file_name_template, - model_file_name_split_template, - quantizations, - quantization_parts, - ) = get_llama_cpp_quantization_info( - filenames, typing.cast(Literal["ggmlv3", "ggufv2"], model_format) - ) + llm_spec = GgmlLLMSpecV1( + model_id=model_id, + model_format=model_format, + model_hub=model_hub, + quantizations=quantizations, + quantization_parts=quantization_parts, + model_size_in_billions=get_model_size_from_model_id(model_id), + model_file_name_template=model_file_name_template, + model_file_name_split_template=model_file_name_split_template, + ) - llm_spec = GgmlLLMSpecV1( - model_id=model_id, - model_format=model_format, - model_hub=model_hub, - quantizations=quantizations, - quantization_parts=quantization_parts, - model_size_in_billions=get_model_size_from_model_id(model_id), - model_file_name_template=model_file_name_template, - model_file_name_split_template=model_file_name_split_template, - ) + return HubImportLLMFamilyV1( + version=1, context_length=context_length, model_specs=[llm_spec] + ) - return HubImportLLMFamilyV1( - version=1, context_length=context_length, model_specs=[llm_spec] - ) - else: - raise 
ValueError(f"Unsupported model format: {model_format}") else: - raise ValueError(f"Unsupported model hub: {model_hub}") + raise ValueError(f"Unsupported model format: {model_format}") diff --git a/xinference/model/llm/utils.py b/xinference/model/llm/utils.py index 51eb67f441..9888509ddb 100644 --- a/xinference/model/llm/utils.py +++ b/xinference/model/llm/utils.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import asyncio.futures import functools import json import logging @@ -700,9 +699,7 @@ def repo_exists(self, model_id: str, hub: MODEL_HUB) -> bool: else: raise ValueError("Unsupported model hub") - async def a_repo_exists( - self, model_id: str, hub: MODEL_HUB - ) -> asyncio.Future[bool]: + async def a_repo_exists(self, model_id: str, hub: MODEL_HUB) -> bool: return await self.__async_runner.async_run(self.repo_exists, model_id, hub) def get_config_path(self, model_id: str, hub: MODEL_HUB) -> Optional[str]: @@ -719,9 +716,7 @@ def get_config_path(self, model_id: str, hub: MODEL_HUB) -> Optional[str]: logging.error(e) return None - async def a_get_config_path( - self, model_id: str, hub: MODEL_HUB - ) -> asyncio.Future[Optional[str]]: + async def a_get_config_path(self, model_id: str, hub: MODEL_HUB) -> Optional[str]: return await self.__async_runner.async_run(self.get_config_path, model_id, hub) def list_repo_files(self, model_id: str, hub: MODEL_HUB) -> List[str]: @@ -746,7 +741,5 @@ def list_repo_files(self, model_id: str, hub: MODEL_HUB) -> List[str]: else: raise ValueError("Unsupported model hub") - async def a_list_repo_files( - self, model_id: str, hub: MODEL_HUB - ) -> asyncio.Future[List[str]]: + async def a_list_repo_files(self, model_id: str, hub: MODEL_HUB) -> List[str]: return await self.__async_runner.async_run(self.list_repo_files, model_id, hub) From c0936e870ecc022b547070002404075179ac435d Mon Sep 17 00:00:00 2001 From: Shi Hui <shihui@hyron.com> Date: Mon, 18 Mar 2024 16:45:10 +0800 Subject: [PATCH 10/21] when model size is decimals, replace dot with underscore. 
--- xinference/core/tests/test_utils.py | 6 +++--- xinference/core/utils.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/xinference/core/tests/test_utils.py b/xinference/core/tests/test_utils.py index 696de7c5b6..ad9364c394 100644 --- a/xinference/core/tests/test_utils.py +++ b/xinference/core/tests/test_utils.py @@ -39,7 +39,7 @@ def test_replica_model_uid(): def test_get_model_size_from_model_id(): model_id = "froggeric/WestLake-10.7B-v2-GGUF" model_size = get_model_size_from_model_id(model_id) - assert model_size == 10.7 + assert model_size == "10_7" model_id = "m-a-p/OpenCodeInterpreter-DS-33B" model_size = get_model_size_from_model_id(model_id) @@ -47,7 +47,7 @@ def test_get_model_size_from_model_id(): model_id = "MBZUAI/MobiLlama-05B" model_size = get_model_size_from_model_id(model_id) - assert model_size == 0.5 + assert model_size == "0_5" model_id = "ibivibiv/alpaca-dragon-72b-v1" model_size = get_model_size_from_model_id(model_id) @@ -63,7 +63,7 @@ def test_get_model_size_from_model_id(): model_id = "ahxt/LiteLlama-460M-1T" model_size = get_model_size_from_model_id(model_id) - assert model_size == 0.46 + assert model_size == "0_46" model_id = "Dracones/Midnight-Miqu-70B-v1.0_exl2_2.24bpw" model_size = get_model_size_from_model_id(model_id) diff --git a/xinference/core/utils.py b/xinference/core/utils.py index 546a61138a..d7430c7857 100644 --- a/xinference/core/utils.py +++ b/xinference/core/utils.py @@ -214,7 +214,7 @@ def resize_to_billion(size: str) -> Union[str, int, float]: return size if size.lower().endswith("m"): - return round(int(size[:-1]) / 1000, 2) + return str(round(int(size[:-1]) / 1000, 2)).replace(".", "_") size = size[:-1] if "_" not in size: @@ -222,7 +222,7 @@ def resize_to_billion(size: str) -> Union[str, int, float]: size = size[0] + "." + str(size[1:]) if "." 
in size: - return float(size) + return size.replace(".", "_") else: return int(size) From 5081c48f01489d1e06e1947a275c0b9e24780ddd Mon Sep 17 00:00:00 2001 From: Shi Hui <shihui@hyron.com> Date: Mon, 18 Mar 2024 17:04:01 +0800 Subject: [PATCH 11/21] when model size is unknown, not return "UNKNOWN", but 0 --- xinference/core/tests/test_utils.py | 4 ++-- xinference/core/utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/xinference/core/tests/test_utils.py b/xinference/core/tests/test_utils.py index ad9364c394..c2f740cc2f 100644 --- a/xinference/core/tests/test_utils.py +++ b/xinference/core/tests/test_utils.py @@ -91,11 +91,11 @@ def test_get_model_size_from_model_id(): model_id = "mlx-community/c4ai-command-r-v01-4bit" model_size = get_model_size_from_model_id(model_id) - assert model_size == "UNKNOWN" + assert model_size == 0 model_id = "lemonilia/ShoriRP-v0.75d" model_size = get_model_size_from_model_id(model_id) - assert model_size == "UNKNOWN" + assert model_size == 0 model_id = "abc" try: diff --git a/xinference/core/utils.py b/xinference/core/utils.py index d7430c7857..5cf41e332a 100644 --- a/xinference/core/utils.py +++ b/xinference/core/utils.py @@ -211,7 +211,7 @@ def get_model_size_from_model_id(model_id: str) -> Union[str, float, int]: def resize_to_billion(size: str) -> Union[str, int, float]: if size == "UNKNOWN": - return size + return 0 if size.lower().endswith("m"): return str(round(int(size[:-1]) / 1000, 2)).replace(".", "_") From f0bd9afd49287059b879c469c7a1d1a86dd96457 Mon Sep 17 00:00:00 2001 From: Shi Hui <shihui@hyron.com> Date: Mon, 18 Mar 2024 17:05:08 +0800 Subject: [PATCH 12/21] add support of pytorch and awq format to fetch model info from hub. --- xinference/core/supervisor.py | 26 ++++++++--- xinference/core/tests/test_supervisor.py | 56 ++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 5 deletions(-) diff --git a/xinference/core/supervisor.py b/xinference/core/supervisor.py index 6a0faccb9f..2fc4a4388d 100644 --- a/xinference/core/supervisor.py +++ b/xinference/core/supervisor.py @@ -33,7 +33,11 @@ from ..core import ModelActor from ..core.status_guard import InstanceInfo, LaunchStatus from ..model.llm import GgmlLLMSpecV1 -from ..model.llm.llm_family import DEFAULT_CONTEXT_LENGTH, HubImportLLMFamilyV1 +from ..model.llm.llm_family import ( + DEFAULT_CONTEXT_LENGTH, + HubImportLLMFamilyV1, + PytorchLLMSpecV1, +) from ..model.llm.utils import MODEL_HUB, ModelHubUtil from .metrics import record_metrics from .resource import GPUStatus, ResourceStatus @@ -1010,9 +1014,7 @@ async def get_llm_spec( if "max_position_embeddings" in config: context_length = config["max_position_embeddings"] - if model_format in ["pytorch", "gptq", "awq"]: - raise NotImplementedError("pytorch, gptq and awq not implemented yet") - elif model_format in ["ggmlv3", "ggufv2"]: + if model_format in ["ggmlv3", "ggufv2"]: filenames = await self._model_hub_util.a_list_repo_files( model_id, model_hub ) @@ -1040,6 +1042,20 @@ async def get_llm_spec( return HubImportLLMFamilyV1( version=1, context_length=context_length, model_specs=[llm_spec] ) - + elif model_format in ["pytorch", "awq"]: + llm_spec = PytorchLLMSpecV1( + model_id=model_id, + model_format=model_format, + model_hub=model_hub, + model_size_in_billions=get_model_size_from_model_id(model_id), + quantizations=( + ["4-bit, 8-bit, none"] if model_format == "pytorch" else ["Int4"] + ), + ) + return HubImportLLMFamilyV1( + version=1, context_length=context_length, model_specs=[llm_spec] + ) + elif 
model_format == "gptq": + raise NotImplementedError("gptq is not implemented yet") else: raise ValueError(f"Unsupported model format: {model_format}") diff --git a/xinference/core/tests/test_supervisor.py b/xinference/core/tests/test_supervisor.py index f26976d10b..e5817ebc82 100644 --- a/xinference/core/tests/test_supervisor.py +++ b/xinference/core/tests/test_supervisor.py @@ -167,3 +167,59 @@ async def test_get_llm_spec_ms(): assert False except ValueError as e: assert str(e) == "Model Nobody/No_This_Repo does not exist" + + +@pytest.mark.asyncio +async def test_get_llm_spec_2(): + supervisor = SupervisorActor() + llm_family = await supervisor.get_llm_spec( + "Qwen/Qwen1.5-1.8B", "pytorch", "huggingface" + ) + assert llm_family is not None + assert len(llm_family.model_specs) == 1 + pytorch_qs = {"4-bit, 8-bit, none"} + assert ( + pytorch_qs.intersection(llm_family.model_specs[0].quantizations) == pytorch_qs + ) + + assert llm_family.model_specs[0].model_size_in_billions == "1_8" + assert llm_family.context_length == 32768 + + llm_family = await supervisor.get_llm_spec( + "qwen/Qwen-14B-Chat", "pytorch", "modelscope" + ) + assert llm_family is not None + assert len(llm_family.model_specs) == 1 + pytorch_qs = {"4-bit, 8-bit, none"} + assert ( + pytorch_qs.intersection(llm_family.model_specs[0].quantizations) == pytorch_qs + ) + + assert llm_family.model_specs[0].model_size_in_billions == 14 + assert llm_family.context_length == 8192 + + llm_family = await supervisor.get_llm_spec( + "qwen/Qwen1.5-7B-Chat-AWQ", "awq", "modelscope" + ) + assert llm_family is not None + assert len(llm_family.model_specs) == 1 + pytorch_qs = {"Int4"} + assert ( + pytorch_qs.intersection(llm_family.model_specs[0].quantizations) == pytorch_qs + ) + + assert llm_family.model_specs[0].model_size_in_billions == 7 + assert llm_family.context_length == 32768 + + llm_family = await supervisor.get_llm_spec( + "casperhansen/mixtral-instruct-awq", "awq", "huggingface" + ) + assert llm_family is not None + assert len(llm_family.model_specs) == 1 + pytorch_qs = {"Int4"} + assert ( + pytorch_qs.intersection(llm_family.model_specs[0].quantizations) == pytorch_qs + ) + + assert llm_family.model_specs[0].model_size_in_billions == 0 + assert llm_family.context_length == 32768 From 46e1626dfb24d62720cd0ab85641ebb51d0768c9 Mon Sep 17 00:00:00 2001 From: Shi Hui <shihui@hyron.com> Date: Mon, 18 Mar 2024 17:54:16 +0800 Subject: [PATCH 13/21] add support of pytorch and awq format to fetch model info from hub. 
--- xinference/core/supervisor.py | 8 +++++--- xinference/core/tests/test_supervisor.py | 24 ++++++++++++------------ 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/xinference/core/supervisor.py b/xinference/core/supervisor.py index 2fc4a4388d..0874b59504 100644 --- a/xinference/core/supervisor.py +++ b/xinference/core/supervisor.py @@ -986,10 +986,10 @@ def record_metrics(name, op, kwargs): record_metrics(name, op, kwargs) @log_async(logger=logger) - async def get_llm_spec( + async def get_llm_family_from_hub( self, model_id: str, - model_format: Literal["pytorch", "ggmlv3", "ggufv2", "gptq", "awq"], + model_format: str, model_hub: str, ) -> HubImportLLMFamilyV1: if model_hub not in ["huggingface", "modelscope"]: @@ -1049,7 +1049,9 @@ async def get_llm_spec( model_hub=model_hub, model_size_in_billions=get_model_size_from_model_id(model_id), quantizations=( - ["4-bit, 8-bit, none"] if model_format == "pytorch" else ["Int4"] + ["4-bit", "8-bit", "none"] + if model_format == "pytorch" + else ["Int4"] ), ) return HubImportLLMFamilyV1( diff --git a/xinference/core/tests/test_supervisor.py b/xinference/core/tests/test_supervisor.py index e5817ebc82..e35910dd3a 100644 --- a/xinference/core/tests/test_supervisor.py +++ b/xinference/core/tests/test_supervisor.py @@ -6,7 +6,7 @@ @pytest.mark.asyncio async def test_get_llm_spec_hf(): supervisor = SupervisorActor() - llm_family = await supervisor.get_llm_spec( + llm_family = await supervisor.get_llm_family_from_hub( "TheBloke/Llama-2-7B-Chat-GGML", "ggmlv3", "huggingface" ) assert llm_family is not None @@ -41,7 +41,7 @@ async def test_get_llm_spec_hf(): llm_family.model_specs[0].quantizations ) - llm_family = await supervisor.get_llm_spec( + llm_family = await supervisor.get_llm_family_from_hub( "TheBloke/KafkaLM-70B-German-V0.1-GGUF", "ggufv2", "huggingface" ) assert llm_family is not None @@ -80,7 +80,7 @@ async def test_get_llm_spec_hf(): }.intersection(set(qs)) == set(qs) try: - llm_family = await supervisor.get_llm_spec( + llm_family = await supervisor.get_llm_family_from_hub( "Nobody/No_This_Repo", "ggufv2", "huggingface" ) assert False @@ -91,7 +91,7 @@ async def test_get_llm_spec_hf(): @pytest.mark.asyncio async def test_get_llm_spec_ms(): supervisor = SupervisorActor() - llm_family = await supervisor.get_llm_spec( + llm_family = await supervisor.get_llm_family_from_hub( "Xorbits/Llama-2-7B-Chat-GGML", "ggmlv3", "modelscope" ) assert llm_family is not None @@ -126,7 +126,7 @@ async def test_get_llm_spec_ms(): llm_family.model_specs[0].quantizations ) - llm_family = await supervisor.get_llm_spec( + llm_family = await supervisor.get_llm_family_from_hub( "qwen/Qwen1.5-72B-Chat-GGUF", "ggufv2", "modelscope" ) assert llm_family is not None @@ -161,7 +161,7 @@ async def test_get_llm_spec_ms(): }.intersection(set(qs)) == set(qs) try: - llm_family = await supervisor.get_llm_spec( + llm_family = await supervisor.get_llm_family_from_hub( "Nobody/No_This_Repo", "ggufv2", "modelscope" ) assert False @@ -172,12 +172,12 @@ async def test_get_llm_spec_ms(): @pytest.mark.asyncio async def test_get_llm_spec_2(): supervisor = SupervisorActor() - llm_family = await supervisor.get_llm_spec( + llm_family = await supervisor.get_llm_family_from_hub( "Qwen/Qwen1.5-1.8B", "pytorch", "huggingface" ) assert llm_family is not None assert len(llm_family.model_specs) == 1 - pytorch_qs = {"4-bit, 8-bit, none"} + pytorch_qs = {"4-bit", "8-bit", "none"} assert ( pytorch_qs.intersection(llm_family.model_specs[0].quantizations) == pytorch_qs ) @@ -185,12 
+185,12 @@ async def test_get_llm_spec_2(): assert llm_family.model_specs[0].model_size_in_billions == "1_8" assert llm_family.context_length == 32768 - llm_family = await supervisor.get_llm_spec( + llm_family = await supervisor.get_llm_family_from_hub( "qwen/Qwen-14B-Chat", "pytorch", "modelscope" ) assert llm_family is not None assert len(llm_family.model_specs) == 1 - pytorch_qs = {"4-bit, 8-bit, none"} + pytorch_qs = {"4-bit", "8-bit", "none"} assert ( pytorch_qs.intersection(llm_family.model_specs[0].quantizations) == pytorch_qs ) @@ -198,7 +198,7 @@ async def test_get_llm_spec_2(): assert llm_family.model_specs[0].model_size_in_billions == 14 assert llm_family.context_length == 8192 - llm_family = await supervisor.get_llm_spec( + llm_family = await supervisor.get_llm_family_from_hub( "qwen/Qwen1.5-7B-Chat-AWQ", "awq", "modelscope" ) assert llm_family is not None @@ -211,7 +211,7 @@ async def test_get_llm_spec_2(): assert llm_family.model_specs[0].model_size_in_billions == 7 assert llm_family.context_length == 32768 - llm_family = await supervisor.get_llm_spec( + llm_family = await supervisor.get_llm_family_from_hub( "casperhansen/mixtral-instruct-awq", "awq", "huggingface" ) assert llm_family is not None From 97dbae6b7282c8d7cfa7a3fd71a4403642804e25 Mon Sep 17 00:00:00 2001 From: Shi Hui <shihui@hyron.com> Date: Mon, 18 Mar 2024 18:09:08 +0800 Subject: [PATCH 14/21] add the rest api for fetch model info from model hub --- xinference/api/restful_api.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py index 67cee3e4a2..77eab02e75 100644 --- a/xinference/api/restful_api.py +++ b/xinference/api/restful_api.py @@ -489,6 +489,16 @@ def serve(self, logging_conf: Optional[dict] = None): else None ), ) + self._router.add_api_route( + "/v1/model_registrations/{model_type}/{model_hub}/{model_format}/{user}/{repo}", + self.get_model_info_from_hub, + methods=["GET"], + dependencies=( + [Security(self._auth_service, scopes=["models:register"])] + if self.is_authenticated() + else None + ), + ) # Clear the global Registry for the MetricsMiddleware, or # the MetricsMiddleware will register duplicated metrics if the port @@ -1502,6 +1512,16 @@ async def get_cluster_version(self) -> JSONResponse: logger.error(e, exc_info=True) raise HTTPException(status_code=500, detail=str(e)) + async def get_model_info_from_hub( + self, model_type: str, model_hub: str, model_format: str, user: str, repo: str + ) -> JSONResponse: + if model_type != "LLM": + raise HTTPException(status_code=400, detail="only LLM model type supported") + llm_family_info = await ( + await self._get_supervisor_ref() + ).get_llm_family_from_hub(f"{user}/{repo}", model_format, model_hub) + return JSONResponse(content=llm_family_info) + def run( supervisor_address: str, From b853970555cdd320b828b7484db36c34fcf8d56f Mon Sep 17 00:00:00 2001 From: Shi Hui <shihui@hyron.com> Date: Tue, 19 Mar 2024 10:13:11 +0800 Subject: [PATCH 15/21] refactor the pytest case with pytest.raises to catch exception. 
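pytest.raises both shortens the tests and checks the exception message with a regex match instead of a manual try/except with an assert False fallthrough. A self-contained sketch of the before/after pattern, using a stand-in function rather than the real hub utilities:

    import pytest

    def lookup(model_id: str) -> str:
        # Stand-in for the hub lookup; always rejects unknown repositories.
        raise ValueError(f"Model {model_id} does not exist")

    def test_lookup_old_style():
        # Before: manual control flow around the expected exception.
        try:
            lookup("Nobody/No_This_Repo")
            assert False
        except ValueError as e:
            assert str(e) == "Model Nobody/No_This_Repo does not exist"

    def test_lookup_new_style():
        # After: pytest.raises asserts the type and matches the message.
        with pytest.raises(ValueError, match="Model Nobody/No_This_Repo does not exist"):
            lookup("Nobody/No_This_Repo")
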
--- xinference/core/tests/test_supervisor.py | 14 +++------ xinference/core/tests/test_utils.py | 6 ++-- xinference/model/llm/tests/test_utils.py | 40 +++++------------------- 3 files changed, 14 insertions(+), 46 deletions(-) diff --git a/xinference/core/tests/test_supervisor.py b/xinference/core/tests/test_supervisor.py index e35910dd3a..bec69cfaf0 100644 --- a/xinference/core/tests/test_supervisor.py +++ b/xinference/core/tests/test_supervisor.py @@ -79,13 +79,10 @@ async def test_get_llm_spec_hf(): "Q8_0", }.intersection(set(qs)) == set(qs) - try: - llm_family = await supervisor.get_llm_family_from_hub( + with pytest.raises(ValueError, match="Model Nobody/No_This_Repo does not exist"): + await supervisor.get_llm_family_from_hub( "Nobody/No_This_Repo", "ggufv2", "huggingface" ) - assert False - except ValueError as e: - assert str(e) == "Model Nobody/No_This_Repo does not exist" @pytest.mark.asyncio @@ -160,13 +157,10 @@ async def test_get_llm_spec_ms(): "q8_0", }.intersection(set(qs)) == set(qs) - try: - llm_family = await supervisor.get_llm_family_from_hub( + with pytest.raises(ValueError, match="Model Nobody/No_This_Repo does not exist"): + await supervisor.get_llm_family_from_hub( "Nobody/No_This_Repo", "ggufv2", "modelscope" ) - assert False - except ValueError as e: - assert str(e) == "Model Nobody/No_This_Repo does not exist" @pytest.mark.asyncio diff --git a/xinference/core/tests/test_utils.py b/xinference/core/tests/test_utils.py index c2f740cc2f..82d324ab02 100644 --- a/xinference/core/tests/test_utils.py +++ b/xinference/core/tests/test_utils.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import pytest from ..utils import ( SUPPORTED_QUANTIZATIONS, @@ -98,11 +99,8 @@ def test_get_model_size_from_model_id(): assert model_size == 0 model_id = "abc" - try: + with pytest.raises(ValueError, match=r"Cannot parse model_id: .+"): get_model_size_from_model_id(model_id) - assert False - except ValueError: - pass def test_get_match_quantization_filenames(): diff --git a/xinference/model/llm/tests/test_utils.py b/xinference/model/llm/tests/test_utils.py index b3e37224db..d8daef217a 100644 --- a/xinference/model/llm/tests/test_utils.py +++ b/xinference/model/llm/tests/test_utils.py @@ -444,19 +444,13 @@ def test_repo_exists(model_hub_util): "TheBloke/KafkaLM-70B-German-V0.1-GGUF", "huggingface" ) assert not model_hub_util.repo_exists("Nobody/No_This_Repo", "huggingface") - try: + with pytest.raises(ValueError, match="Unsupported model hub"): model_hub_util.repo_exists("Nobody/No_This_Repo", "unknown_hub") - assert False - except ValueError: - assert True assert model_hub_util.repo_exists("qwen/Qwen1.5-72B-Chat-GGUF", "modelscope") assert not model_hub_util.repo_exists("Nobody/No_This_Repo", "modelscope") - try: + with pytest.raises(ValueError, match="Unsupported model hub"): model_hub_util.repo_exists("Nobody/No_This_Repo", "unknown_hub") - assert False - except ValueError: - assert True @pytest.mark.asyncio @@ -465,21 +459,15 @@ async def test_a_repo_exists(model_hub_util): "TheBloke/KafkaLM-70B-German-V0.1-GGUF", "huggingface" ) assert not await model_hub_util.a_repo_exists("Nobody/No_This_Repo", "huggingface") - try: + with pytest.raises(ValueError, match="Unsupported model hub"): model_hub_util.repo_exists("Nobody/No_This_Repo", "unknown_hub") - assert False - except ValueError: - assert True assert await model_hub_util.a_repo_exists( "qwen/Qwen1.5-72B-Chat-GGUF", "modelscope" ) assert not await model_hub_util.a_repo_exists("Nobody/No_This_Repo", "modelscope") - try: + with pytest.raises(ValueError, match="Unsupported model hub"): await model_hub_util.a_repo_exists("Nobody/No_This_Repo", "unknown_hub") - assert False - except ValueError: - assert True def test_get_config_path(model_hub_util): @@ -542,17 +530,11 @@ def test_list_repo_files(model_hub_util): ) assert len(files) == 12 # the `.gitattributes` file is not included - try: + with pytest.raises(ValueError, match="Repository Nobody/No_This_Repo not found."): model_hub_util.list_repo_files("Nobody/No_This_Repo", "huggingface") - assert False - except ValueError as e: - assert str(e) == "Repository Nobody/No_This_Repo not found." - try: + with pytest.raises(ValueError, match="Repository Nobody/No_This_Repo not found."): model_hub_util.list_repo_files("Nobody/No_This_Repo", "modelscope") - assert False - except ValueError as e: - assert str(e) == "Repository Nobody/No_This_Repo not found." @pytest.mark.asyncio @@ -567,14 +549,8 @@ async def test_a_list_repo_files(model_hub_util): ) assert len(files) == 12 # the `.gitattributes` file is not included - try: + with pytest.raises(ValueError, match="Repository Nobody/No_This_Repo not found."): await model_hub_util.a_list_repo_files("Nobody/No_This_Repo", "huggingface") - assert False - except ValueError as e: - assert str(e) == "Repository Nobody/No_This_Repo not found." - try: + with pytest.raises(ValueError, match="Repository Nobody/No_This_Repo not found."): await model_hub_util.a_list_repo_files("Nobody/No_This_Repo", "modelscope") - assert False - except ValueError as e: - assert str(e) == "Repository Nobody/No_This_Repo not found." 
From b354cec5059f5ab51ce14951ce6e1348d028736a Mon Sep 17 00:00:00 2001 From: Shi Hui <shihui@hyron.com> Date: Tue, 19 Mar 2024 16:34:34 +0800 Subject: [PATCH 16/21] fix the but that '0.x' form model id cannot be processed correctly --- xinference/core/tests/test_utils.py | 4 ++++ xinference/core/utils.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/xinference/core/tests/test_utils.py b/xinference/core/tests/test_utils.py index 82d324ab02..c225aa12f5 100644 --- a/xinference/core/tests/test_utils.py +++ b/xinference/core/tests/test_utils.py @@ -90,6 +90,10 @@ def test_get_model_size_from_model_id(): model_size = get_model_size_from_model_id(model_id) assert model_size == 33 + model_id = "qwen/Qwen1.5-0.5B-Chat" + model_size = get_model_size_from_model_id(model_id) + assert model_size == "0_5" + model_id = "mlx-community/c4ai-command-r-v01-4bit" model_size = get_model_size_from_model_id(model_id) assert model_size == 0 diff --git a/xinference/core/utils.py b/xinference/core/utils.py index 5cf41e332a..e146931d86 100644 --- a/xinference/core/utils.py +++ b/xinference/core/utils.py @@ -218,7 +218,7 @@ def resize_to_billion(size: str) -> Union[str, int, float]: size = size[:-1] if "_" not in size: - if size[0] == "0": + if size[0] == "0" and "." not in size: size = size[0] + "." + str(size[1:]) if "." in size: From e5ba02d60827a05d15b5ab4ccaf5f702e39af714 Mon Sep 17 00:00:00 2001 From: Shi Hui <shihui@hyron.com> Date: Tue, 19 Mar 2024 16:41:24 +0800 Subject: [PATCH 17/21] update caniuse-lite --- xinference/web/ui/package-lock.json | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xinference/web/ui/package-lock.json b/xinference/web/ui/package-lock.json index 4ae5037245..ab3397754e 100644 --- a/xinference/web/ui/package-lock.json +++ b/xinference/web/ui/package-lock.json @@ -6924,9 +6924,9 @@ } }, "node_modules/caniuse-lite": { - "version": "1.0.30001515", - "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001515.tgz", - "integrity": "sha512-eEFDwUOZbE24sb+Ecsx3+OvNETqjWIdabMy52oOkIgcUtAsQifjUG9q4U9dgTHJM2mfk4uEPxc0+xuFdJ629QA==", + "version": "1.0.30001599", + "resolved": "https://mirrors.cloud.tencent.com/npm/caniuse-lite/-/caniuse-lite-1.0.30001599.tgz", + "integrity": "sha512-LRAQHZ4yT1+f9LemSMeqdMpMxZcc4RMWdj4tiFe3G8tNkWK+E58g+/tzotb5cU6TbcVJLr4fySiAW7XmxQvZQA==", "funding": [ { "type": "opencollective", From 9b774febba68440cc971042a880331a6a66a66cf Mon Sep 17 00:00:00 2001 From: Shi Hui <shihui@hyron.com> Date: Tue, 19 Mar 2024 17:30:11 +0800 Subject: [PATCH 18/21] Finish LLM custom model register from hub. 
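The Import Model button in register_language.js now pre-fills the form by calling the route added in PATCH 14. A rough Python equivalent of that request; the endpoint address and repository id below are placeholders, adjust both to your deployment:

    import requests

    endpoint = "http://127.0.0.1:9997"  # placeholder Xinference address
    hub, model_format = "huggingface", "ggufv2"
    model_id = "TheBloke/KafkaLM-70B-German-V0.1-GGUF"  # placeholder user/repo

    resp = requests.get(
        f"{endpoint}/v1/model_registrations/LLM/{hub}/{model_format}/{model_id}"
    )
    resp.raise_for_status()
    family = resp.json()

    # The UI reads these fields to populate context length, quantizations, etc.
    print(family["context_length"])
    print(family["model_specs"][0]["quantizations"])
    print(family["model_specs"][0]["model_file_name_template"])
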
--- xinference/api/restful_api.py | 19 +- xinference/model/llm/llm_family.py | 5 +- .../register_model/register_language.js | 241 ++++++++++++++---- 3 files changed, 215 insertions(+), 50 deletions(-) diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py index 77eab02e75..41849312aa 100644 --- a/xinference/api/restful_api.py +++ b/xinference/api/restful_api.py @@ -1515,11 +1515,20 @@ async def get_cluster_version(self) -> JSONResponse: async def get_model_info_from_hub( self, model_type: str, model_hub: str, model_format: str, user: str, repo: str ) -> JSONResponse: - if model_type != "LLM": - raise HTTPException(status_code=400, detail="only LLM model type supported") - llm_family_info = await ( - await self._get_supervisor_ref() - ).get_llm_family_from_hub(f"{user}/{repo}", model_format, model_hub) + try: + if model_type != "LLM": + raise HTTPException( + status_code=400, detail="only LLM model type supported currently" + ) + llm_family_info = await ( + await self._get_supervisor_ref() + ).get_llm_family_from_hub(f"{user}/{repo}", model_format, model_hub) + except ValueError as re: + logger.error(re, exc_info=True) + raise HTTPException(status_code=400, detail=str(re)) + except Exception as e: + logger.error(e, exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) return JSONResponse(content=llm_family_info) diff --git a/xinference/model/llm/llm_family.py b/xinference/model/llm/llm_family.py index d0e744fea0..da34566ed7 100644 --- a/xinference/model/llm/llm_family.py +++ b/xinference/model/llm/llm_family.py @@ -542,7 +542,10 @@ def _generate_model_file_names( ) need_merge = False - if llm_spec.quantization_parts is None: + if ( + llm_spec.quantization_parts is None + or quantization not in llm_spec.quantization_parts + ): file_names.append(final_file_name) elif quantization is not None and quantization in llm_spec.quantization_parts: parts = llm_spec.quantization_parts[quantization] diff --git a/xinference/web/ui/src/scenes/register_model/register_language.js b/xinference/web/ui/src/scenes/register_model/register_language.js index 606fcecee2..2ac956af79 100644 --- a/xinference/web/ui/src/scenes/register_model/register_language.js +++ b/xinference/web/ui/src/scenes/register_model/register_language.js @@ -39,10 +39,13 @@ const RegisterLanguageModel = () => { const { setErrorMsg } = useContext(ApiContext) const [successMsg, setSuccessMsg] = useState('') const [modelFormat, setModelFormat] = useState('pytorch') + const [modelFileNameTemplate, setModelFileNameTemplate] = useState('') + const [modelFileNameSplitTemplate, setModelFileNameSplitTemplate] = useState('') const [modelSize, setModelSize] = useState(7) const [modelUri, setModelUri] = useState('/path/to/llama-2') const [modelId, setModelId] = useState('') const [quantization, setQuantization] = useState('') + const [quantizationParts, setQuantizationParts] = useState('') const [modelSource, setModelSource] = useState(SOURCES[0]) const [hub, setHub] = useState(SUPPORTED_HUBS[0]) const [formData, setFormData] = useState({ @@ -80,6 +83,11 @@ const RegisterLanguageModel = () => { ) }) const errorFamily = familyLabel === '' + const errorModelId = modelSource === 'hub' && modelId.search('\\w+/\\w+') === -1 + const errorModelFileNameTemplate = modelSource === 'hub' && ['ggufv2', 'ggmlv3'].includes(modelFormat) && + modelFileNameTemplate.trim().length <= 0 + const errorQuantizationParts = modelSource === 'hub' && ['ggufv2', 'ggmlv3'].includes(modelFormat) && + modelFileNameSplitTemplate.trim().length > 0 && 
quantizationParts.trim().length <= 0 const errorAny = errorModelName || errorModelDescription || @@ -87,7 +95,10 @@ const RegisterLanguageModel = () => { errorLanguage || errorAbility || errorModelSize || - errorFamily + errorFamily || + errorModelId || + errorModelFileNameTemplate || + errorQuantizationParts useEffect(() => { if (cookie.token === '' || cookie.token === undefined) { @@ -192,48 +203,93 @@ const RegisterLanguageModel = () => { } const handleClick = async () => { - if (isModelFormatGPTQ()) { - formData.model_specs = [ - { - model_format: modelFormat, - model_size_in_billions: modelSize, - quantizations: [quantization], - model_id: '', - model_uri: modelUri, - }, - ] - } else if (isModelFormatAWQ()) { - formData.model_specs = [ - { - model_format: modelFormat, - model_size_in_billions: modelSize, - quantizations: [quantization], - model_id: '', - model_uri: modelUri, - }, - ] - } else if (!isModelFormatPytorch()) { - const { baseDir, filename } = getPathComponents(modelUri) - formData.model_specs = [ - { - model_format: modelFormat, - model_size_in_billions: modelSize, - quantizations: [quantization], - model_id: '', - model_file_name_template: filename, - model_uri: baseDir, - }, - ] - } else { - formData.model_specs = [ - { - model_format: modelFormat, - model_size_in_billions: modelSize, - quantizations: ['4-bit', '8-bit', 'none'], - model_id: '', - model_uri: modelUri, - }, - ] + if (modelSource === 'self_hosted') { + if (isModelFormatGPTQ()) { + formData.model_specs = [ + { + model_format: modelFormat, + model_size_in_billions: modelSize, + quantizations: [quantization], + model_id: '', + model_uri: modelUri, + }, + ] + } else if (isModelFormatAWQ()) { + formData.model_specs = [ + { + model_format: modelFormat, + model_size_in_billions: modelSize, + quantizations: [quantization], + model_id: '', + model_uri: modelUri, + }, + ] + } else if (!isModelFormatPytorch()) { + const { baseDir, filename } = getPathComponents(modelUri) + formData.model_specs = [ + { + model_format: modelFormat, + model_size_in_billions: modelSize, + quantizations: [quantization], + model_id: '', + model_file_name_template: filename, + model_uri: baseDir, + }, + ] + } else { + formData.model_specs = [ + { + model_format: modelFormat, + model_size_in_billions: modelSize, + quantizations: ['4-bit', '8-bit', 'none'], + model_id: '', + model_uri: modelUri, + }, + ] + } + } else if (modelSource === 'hub') { + const quantization_array = quantization.split(',') + if (isModelFormatGPTQ() || isModelFormatAWQ()) { + formData.model_specs = [ + { + model_format: modelFormat, + model_size_in_billions: modelSize, + quantizations: quantization_array, + model_hub: hub, + model_id: modelId, + model_uri: null, + }, + ] + } else if (!isModelFormatPytorch()) { + const qParts = quantizationParts.length > 0 ? JSON.parse(quantizationParts) : null + let splitTemplate = modelFileNameSplitTemplate.trim() + splitTemplate = splitTemplate.length > 0 ? 
splitTemplate : null + + formData.model_specs = [ + { + model_format: modelFormat, + model_size_in_billions: modelSize, + model_file_name_template: modelFileNameTemplate, + model_file_name_split_template: splitTemplate, + quantizations: quantization_array, + quantization_parts: qParts, + model_hub: hub, + model_id: modelId, + model_uri: null, + }, + ] + } else { + formData.model_specs = [ + { + model_format: modelFormat, + model_size_in_billions: modelSize, + quantizations: ['4-bit', '8-bit', 'none'], + model_hub: hub, + model_id: modelId, + model_uri: null, + }, + ] + } } formData.model_family = familyLabel @@ -288,7 +344,63 @@ const RegisterLanguageModel = () => { } const handleImportModel = async () => { - console.log('import model') + if (errorModelId) { + setErrorMsg('Please fill in valid value for Model Id') + return + } + const response = await fetcher(endPoint + + `/v1/model_registrations/LLM/${hub}/${modelFormat}/${modelId}`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + }, + }) + if (!response.ok) { + const errorData = await response.json() // Assuming the server returns error details in JSON format + setErrorMsg( + `Server error: ${response.status} - ${errorData.detail || 'Unknown error'}`, + ) + } else { + const body = await response.json() + console.log('response', body) + if ('context_length' in body && body['context_length'] > 0) { + setFormData({ + ...formData, + context_length: Number(body['context_length']), + }) + } + + /** + * @type {object[]} + */ + const modelSpecs = body['model_specs'] + if (modelSpecs.length === 0) { + return + } + const modelSpec = modelSpecs[0] + + const modelSize = modelSpec['model_size_in_billions'] + setModelSize(modelSize) + + if (['ggufv2', 'ggmlv3'].includes(modelFormat)) { + + const modelFileNameTemplate = modelSpec['model_file_name_template'] + setModelFileNameTemplate(modelFileNameTemplate) + + const quantizations = modelSpec['quantizations'] + setQuantization(quantizations.join(',')) + + /** + * @type {string | null} + */ + const modelFileNameSplitTemplate = modelSpec['model_file_name_split_template'] + if (modelFileNameSplitTemplate !== null && modelFileNameSplitTemplate.trim() !== '') { + setModelFileNameSplitTemplate(modelFileNameSplitTemplate) + const parts = JSON.stringify(modelSpec['quantization_parts']) + setQuantizationParts(parts) + } + } + } } const toggleLanguage = (lang) => { @@ -437,6 +549,7 @@ const RegisterLanguageModel = () => { sx={{ width: '400px' }} label="Model Id" size="small" + error={errorModelId} value={modelId} onChange={(e) => { setModelId(e.target.value) @@ -517,6 +630,29 @@ const RegisterLanguageModel = () => { /> <Box padding="15px"></Box> + {modelSource === 'hub' && ['ggufv2', 'ggmlv3'].includes(modelFormat) && + <> + <TextField + label="Model File Name Template" + size="small" + value={modelFileNameTemplate} + onChange={(e) => { + setModelFileNameTemplate(e.target.value) + }} + error={errorModelFileNameTemplate} + /> + <Box padding="15px"></Box> + <TextField + label="Model File Name Split Template (Optional)" + size="small" + value={modelFileNameSplitTemplate} + onChange={(e) => { + setModelFileNameSplitTemplate(e.target.value) + }} + /> + <Box padding="15px"></Box> + </> + } <TextField label="Quantization (Optional)" @@ -529,6 +665,23 @@ const RegisterLanguageModel = () => { /> <Box padding="15px"></Box> + {modelSource === 'hub' && ['ggufv2', 'ggmlv3'].includes(modelFormat) && + modelFileNameSplitTemplate.trim().length > 0 && + <> + <TextField + label="Quantization Parts 
(Optional)" + size="small" + value={quantizationParts} + error={errorQuantizationParts} + onChange={(e) => { + setQuantizationParts(e.target.value.trim()) + }} + helperText="If there is more than 1 quantization parts, separated by commas" + /> + <Box padding="15px"></Box> + </> + } + <TextField label="Model Description (Optional)" error={errorModelDescription} From c0a636c489b54acbe8f7e8681b59a15bd803ceab Mon Sep 17 00:00:00 2001 From: Shi Hui <shihui@hyron.com> Date: Tue, 2 Apr 2024 17:10:01 +0800 Subject: [PATCH 19/21] Finish embedding custom model register from hub. --- xinference/api/restful_api.py | 23 ++- xinference/core/supervisor.py | 40 ++++ xinference/core/tests/test_supervisor.py | 21 ++ xinference/model/embedding/custom.py | 3 +- .../model/embedding/tests/test_utils.py | 30 +++ xinference/model/embedding/utils.py | 16 ++ .../register_model/register_embedding.js | 192 +++++++++++++++--- 7 files changed, 291 insertions(+), 34 deletions(-) create mode 100644 xinference/model/embedding/tests/test_utils.py diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py index c98cad4888..454f40780b 100644 --- a/xinference/api/restful_api.py +++ b/xinference/api/restful_api.py @@ -1521,20 +1521,27 @@ async def get_model_info_from_hub( self, model_type: str, model_hub: str, model_format: str, user: str, repo: str ) -> JSONResponse: try: - if model_type != "LLM": - raise HTTPException( - status_code=400, detail="only LLM model type supported currently" - ) - llm_family_info = await ( - await self._get_supervisor_ref() - ).get_llm_family_from_hub(f"{user}/{repo}", model_format, model_hub) + if model_type == "LLM": + llm_family_info = await ( + await self._get_supervisor_ref() + ).get_llm_family_from_hub(f"{user}/{repo}", model_format, model_hub) + return JSONResponse(content=llm_family_info) + if model_type == "embedding": + embed_spec = await ( + await self._get_supervisor_ref() + ).get_embedding_spec_from_hub(f"{user}/{repo}", model_hub) + return JSONResponse(content=embed_spec) except ValueError as re: logger.error(re, exc_info=True) raise HTTPException(status_code=400, detail=str(re)) except Exception as e: logger.error(e, exc_info=True) raise HTTPException(status_code=500, detail=str(e)) - return JSONResponse(content=llm_family_info) + + raise HTTPException( + status_code=400, + detail="only LLM and embedding model type supported currently", + ) def run( diff --git a/xinference/core/supervisor.py b/xinference/core/supervisor.py index f1a56a0a9b..e7d3bbc5ac 100644 --- a/xinference/core/supervisor.py +++ b/xinference/core/supervisor.py @@ -32,6 +32,8 @@ ) from ..core import ModelActor from ..core.status_guard import InstanceInfo, LaunchStatus +from ..model.embedding import CustomEmbeddingModelSpec +from ..model.embedding.utils import get_language_from_model_id from ..model.llm import GgmlLLMSpecV1 from ..model.llm.llm_family import ( DEFAULT_CONTEXT_LENGTH, @@ -62,6 +64,7 @@ from ..model.rerank import RerankModelSpec from .worker import WorkerActor + logger = getLogger(__name__) ASYNC_LAUNCH_TASKS = {} # type: ignore @@ -1092,3 +1095,40 @@ async def get_llm_family_from_hub( raise NotImplementedError("gptq is not implemented yet") else: raise ValueError(f"Unsupported model format: {model_format}") + + @log_async(logger=logger) + async def get_embedding_spec_from_hub( + self, model_id: str, model_hub: str + ) -> CustomEmbeddingModelSpec: + if model_hub not in ["huggingface", "modelscope"]: + raise ValueError(f"Unsupported model hub: {model_hub}") + + model_hub = 
cast(MODEL_HUB, model_hub) + + repo_exists = await self._model_hub_util.a_repo_exists( + model_id, + model_hub, + ) + + if not repo_exists: + raise ValueError(f"Model {model_id} does not exist") + + max_tokens = 512 + dimensions = 768 + if config_path := await self._model_hub_util.a_get_config_path( + model_id, model_hub + ): + with open(config_path) as f: + config = json.load(f) + if "max_position_embeddings" in config: + max_tokens = config["max_position_embeddings"] + if "hidden_size" in config: + dimensions = config["hidden_size"] + return CustomEmbeddingModelSpec( + model_name=model_id.split("/")[-1], + model_id=model_id, + max_tokens=max_tokens, + dimensions=dimensions, + model_hub=model_hub, + language=[get_language_from_model_id(model_id)], + ) diff --git a/xinference/core/tests/test_supervisor.py b/xinference/core/tests/test_supervisor.py index bec69cfaf0..09001fe07a 100644 --- a/xinference/core/tests/test_supervisor.py +++ b/xinference/core/tests/test_supervisor.py @@ -217,3 +217,24 @@ async def test_get_llm_spec_2(): assert llm_family.model_specs[0].model_size_in_billions == 0 assert llm_family.context_length == 32768 + + +@pytest.mark.asyncio +async def test_get_embedding_spec_from_hub(): + supervisor = SupervisorActor() + embedding_spec = await supervisor.get_embedding_spec_from_hub( + "BAAI/bge-large-zh-v1.5", "huggingface" + ) + assert embedding_spec is not None + assert embedding_spec.model_name == "bge-large-zh-v1.5" + assert embedding_spec.model_id == "BAAI/bge-large-zh-v1.5" + + embedding_spec = await supervisor.get_embedding_spec_from_hub( + "bensonpeng/bge-large-en-v1.5", "modelscope" + ) + + assert embedding_spec is not None + assert embedding_spec.model_name == "bge-large-en-v1.5" + assert embedding_spec.model_id == "bensonpeng/bge-large-en-v1.5" + assert embedding_spec.max_tokens == 512 + assert embedding_spec.dimensions == 1024 diff --git a/xinference/model/embedding/custom.py b/xinference/model/embedding/custom.py index 8e311bbd7d..83ce51bfdb 100644 --- a/xinference/model/embedding/custom.py +++ b/xinference/model/embedding/custom.py @@ -63,7 +63,8 @@ def register_embedding(model_spec: CustomEmbeddingModelSpec, persist: bool): if persist: # We only validate model URL when persist is True. 
model_uri = model_spec.model_uri - if model_uri and not is_valid_model_uri(model_uri): + model_id = model_spec.model_id + if model_id is None and model_uri and not is_valid_model_uri(model_uri): raise ValueError(f"Invalid model URI {model_uri}.") persist_path = os.path.join( diff --git a/xinference/model/embedding/tests/test_utils.py b/xinference/model/embedding/tests/test_utils.py new file mode 100644 index 0000000000..c5d2bead19 --- /dev/null +++ b/xinference/model/embedding/tests/test_utils.py @@ -0,0 +1,30 @@ +from ..utils import get_language_from_model_id + + +def test_get_language_from_model_id(): + model_id = "BAAI/bge-large-zh-v1.5" + assert get_language_from_model_id(model_id) == "zh" + + model_id = "BAAI/bge-large-base-v1.5" + assert get_language_from_model_id(model_id) == "en" + + model_id = "google-bert/bert-base-multilingual-cased" + assert get_language_from_model_id(model_id) == "zh" + + model_id = "jinaai/jina-embeddings-v2-base-en" + assert get_language_from_model_id(model_id) == "en" + + model_id = "jinaai/jina-embeddings-v2-base-es" + # now only support zh and en, if it is not chinese, then en, even the language is specified as es + assert get_language_from_model_id(model_id) == "en" + + model_id = "bge-large-zh-v1.5" + # wrong model id will cause the en is returned + assert get_language_from_model_id(model_id) == "en" + + model_id = "BAAI/newtype/bge-large-zh-v1.5" + # wrong model id format, return en + assert get_language_from_model_id(model_id) == "en" + + model_id = "" + assert get_language_from_model_id(model_id) == "en" diff --git a/xinference/model/embedding/utils.py b/xinference/model/embedding/utils.py index 8b63e6eb5f..547ec33171 100644 --- a/xinference/model/embedding/utils.py +++ b/xinference/model/embedding/utils.py @@ -11,8 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from logging import getLogger + from .core import EmbeddingModelSpec def get_model_version(embedding_model: EmbeddingModelSpec) -> str: return f"{embedding_model.model_name}--{embedding_model.max_tokens}--{embedding_model.dimensions}" + + +def get_language_from_model_id(model_id: str) -> str: + split = model_id.split("/") + if len(split) != 2: + logger = getLogger(__name__) + logger.error(f"Invalid model_id: {model_id}, return the default en language") + return "en" + model_id = split[-1] + segments = model_id.split("-") + for seg in segments: + if seg.lower() in ["zh", "cn", "chinese", "multilingual"]: + return "zh" + return "en" diff --git a/xinference/web/ui/src/scenes/register_model/register_embedding.js b/xinference/web/ui/src/scenes/register_model/register_embedding.js index ac7ab8d4ae..7faca81731 100644 --- a/xinference/web/ui/src/scenes/register_model/register_embedding.js +++ b/xinference/web/ui/src/scenes/register_model/register_embedding.js @@ -1,4 +1,14 @@ -import { Box, Checkbox, FormControl, FormControlLabel } from '@mui/material' +import { + Box, + Checkbox, + FormControl, + FormControlLabel, + InputLabel, + MenuItem, + Radio, + RadioGroup, + Select, +} from '@mui/material' import Alert from '@mui/material/Alert' import AlertTitle from '@mui/material/AlertTitle' import Button from '@mui/material/Button' @@ -13,20 +23,33 @@ const SUPPORTED_LANGUAGES_DICT = { en: 'English', zh: 'Chinese' } // Convert dictionary of supported languages into list const SUPPORTED_LANGUAGES = Object.keys(SUPPORTED_LANGUAGES_DICT) + +const SUPPORTED_HUBS_DICT = { huggingface: 'HuggingFace', modelscope: 'ModelScope' } +const SUPPORTED_HUBS = Object.keys(SUPPORTED_HUBS_DICT) + +const SOURCES_DICT = { self_hosted: 'Self Hosted', hub: 'Hub' } +const SOURCES = Object.keys(SOURCES_DICT) + const RegisterEmbeddingModel = () => { const ERROR_COLOR = useMode() const endPoint = useContext(ApiContext).endPoint const { setErrorMsg } = useContext(ApiContext) const [successMsg, setSuccessMsg] = useState('') + const [modelSource, setModelSource] = useState(SOURCES[0]) + const [hub, setHub] = useState(SUPPORTED_HUBS[0]) + const [modelId, setModelId] = useState('') const [formData, setFormData] = useState({ model_name: 'custom-embedding', dimensions: 768, max_tokens: 512, language: ['en'], model_uri: '/path/to/embedding-model', + model_id: null, + model_hub: null, }) const errorModelName = formData.model_name.trim().length <= 0 + const errorModelId = modelSource === 'hub' && modelId.search('\\w+/\\w+') === -1 const errorDimensions = formData.dimensions < 0 const errorMaxTokens = formData.max_tokens < 0 const errorLanguage = @@ -34,13 +57,27 @@ const RegisterEmbeddingModel = () => { const handleClick = async () => { const errorAny = - errorModelName || errorDimensions || errorMaxTokens || errorLanguage + errorModelName || errorDimensions || errorMaxTokens || errorLanguage || errorModelId if (errorAny) { setErrorMsg('Please fill in valid value for all fields') return } + let myFormData = {} + if (modelSource === 'self_hosted') { + myFormData = { + ...formData, + model_hub: null, + model_id: null, + } + } else { + myFormData = { + ...formData, + model_uri: null, + } + } + console.log(myFormData) try { const response = await fetcher( endPoint + '/v1/model_registrations/embedding', @@ -50,21 +87,21 @@ const RegisterEmbeddingModel = () => { 'Content-Type': 'application/json', }, body: JSON.stringify({ - model: JSON.stringify(formData), + model: JSON.stringify(myFormData), persist: true, }), - } + }, ) if (!response.ok) { 
const errorData = await response.json() // Assuming the server returns error details in JSON format setErrorMsg( `Server error: ${response.status} - ${ errorData.detail || 'Unknown error' - }` + }`, ) } else { setSuccessMsg( - 'Model has been registered successfully! Navigate to launch model page to proceed.' + 'Model has been registered successfully! Navigate to launch model page to proceed.', ) } } catch (error) { @@ -87,6 +124,40 @@ const RegisterEmbeddingModel = () => { } } + const handleImportModel = async () => { + if (errorModelId) { + setErrorMsg('Please fill in valid value for Model Id') + return + } + const response = await fetcher(endPoint + + `/v1/model_registrations/embedding/${hub}/_/${modelId}`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + }, + }) + + if (!response.ok) { + const errorData = await response.json() // Assuming the server returns error details in JSON format + setErrorMsg( + `Server error: ${response.status} - ${ + errorData.detail || 'Unknown error' + }`, + ) + } else { + const data = await response.json() + setFormData({ + ...formData, + dimensions: data.dimensions, + max_tokens: data.max_tokens, + language: data.language, + model_hub: hub, + model_id: modelId, + }) + + } + } + return ( <React.Fragment> <Box padding="20px"></Box> @@ -104,6 +175,91 @@ const RegisterEmbeddingModel = () => { /> <Box padding="15px"></Box> + <label + style={{ + paddingLeft: 5, + }} + > + Model Source + </label> + + <RadioGroup + value={modelSource} + onChange={(e) => { + setModelSource(e.target.value) + }} + > + <Box sx={styles.checkboxWrapper}> + {SOURCES.map((item) => ( + <Box sx={{ marginLeft: '10px' }}> + <FormControlLabel + value={item} + control={<Radio />} + label={SOURCES_DICT[item]} + /> + </Box> + ))} + </Box> + </RadioGroup> + <Box padding="15px"></Box> + + {modelSource === 'self_hosted' && + <TextField + label="Model Path" + size="small" + value={formData.model_uri} + onChange={(e) => { + setFormData({ + ...formData, + model_uri: e.target.value, + }) + }} + helperText="Provide the model directory path." + />} + + {modelSource === 'hub' && + <Box sx={styles.checkboxWrapper}> + + <TextField + sx={{ width: '400px' }} + label="Model Id" + size="small" + error={errorModelId} + value={modelId} + onChange={(e) => { + setModelId(e.target.value) + }} + placeholder="user/repo" + /> + + <FormControl variant="standard" + sx={{ marginLeft: '10px' }}> + <InputLabel id="hub-label">Hub</InputLabel> + <Select + labelId="hub-label" + value={hub} + label="Hub" + onChange={(e) => { + setHub(e.target.value) + }} + > + {SUPPORTED_HUBS.map((item) => ( + <MenuItem value={item}>{SUPPORTED_HUBS_DICT[item]}</MenuItem> + ))} + </Select> + </FormControl> + <Button + sx={{ marginLeft: '10px' }} + variant="contained" + color="primary" + onClick={handleImportModel} + > + Import Model + </Button> + </Box> + } + <Box padding="15px"></Box> + <TextField error={errorDimensions} label="Dimensions" @@ -132,20 +288,6 @@ const RegisterEmbeddingModel = () => { /> <Box padding="15px"></Box> - <TextField - label="Model Path" - size="small" - value={formData.model_uri} - onChange={(e) => { - setFormData({ - ...formData, - model_uri: e.target.value, - }) - }} - helperText="Provide the model directory path." - /> - <Box padding="15px"></Box> - <label style={{ paddingLeft: 5, @@ -166,11 +308,11 @@ const RegisterEmbeddingModel = () => { sx={ errorLanguage ? 
{ - 'color': ERROR_COLOR, - '&.Mui-checked': { - color: ERROR_COLOR, - }, - } + 'color': ERROR_COLOR, + '&.Mui-checked': { + color: ERROR_COLOR, + }, + } : {} } /> From be6e780c0d27c65e4364a79920cbb92acc0703c5 Mon Sep 17 00:00:00 2001 From: Shi Hui <shihui@hyron.com> Date: Tue, 2 Apr 2024 18:04:53 +0800 Subject: [PATCH 20/21] Finish rerank custom model register from hub. --- .../register_model/register_embedding.js | 2 +- .../scenes/register_model/register_rerank.js | 178 +++++++++++++++--- 2 files changed, 158 insertions(+), 22 deletions(-) diff --git a/xinference/web/ui/src/scenes/register_model/register_embedding.js b/xinference/web/ui/src/scenes/register_model/register_embedding.js index 7faca81731..87531813ed 100644 --- a/xinference/web/ui/src/scenes/register_model/register_embedding.js +++ b/xinference/web/ui/src/scenes/register_model/register_embedding.js @@ -64,7 +64,7 @@ const RegisterEmbeddingModel = () => { return } - let myFormData = {} + let myFormData if (modelSource === 'self_hosted') { myFormData = { ...formData, diff --git a/xinference/web/ui/src/scenes/register_model/register_rerank.js b/xinference/web/ui/src/scenes/register_model/register_rerank.js index 075b35ff9d..a2d9b2ad00 100644 --- a/xinference/web/ui/src/scenes/register_model/register_rerank.js +++ b/xinference/web/ui/src/scenes/register_model/register_rerank.js @@ -1,4 +1,14 @@ -import { Box, Checkbox, FormControl, FormControlLabel } from '@mui/material' +import { + Box, + Checkbox, + FormControl, + FormControlLabel, + InputLabel, + MenuItem, + Radio, + RadioGroup, + Select, +} from '@mui/material' import Alert from '@mui/material/Alert' import AlertTitle from '@mui/material/AlertTitle' import Button from '@mui/material/Button' @@ -12,24 +22,36 @@ import { useMode } from '../../theme' const SUPPORTED_LANGUAGES_DICT = { en: 'English', zh: 'Chinese' } // Convert dictionary of supported languages into list const SUPPORTED_LANGUAGES = Object.keys(SUPPORTED_LANGUAGES_DICT) +const SUPPORTED_HUBS_DICT = { huggingface: 'HuggingFace', modelscope: 'ModelScope' } +const SUPPORTED_HUBS = Object.keys(SUPPORTED_HUBS_DICT) + +const SOURCES_DICT = { self_hosted: 'Self Hosted', hub: 'Hub' } +const SOURCES = Object.keys(SOURCES_DICT) const RegisterRerankModel = () => { const ERROR_COLOR = useMode() const endPoint = useContext(ApiContext).endPoint const { setErrorMsg } = useContext(ApiContext) const [successMsg, setSuccessMsg] = useState('') + const [modelSource, setModelSource] = useState(SOURCES[0]) + const [hub, setHub] = useState(SUPPORTED_HUBS[0]) + const [modelId, setModelId] = useState('') + const [formData, setFormData] = useState({ model_name: 'custom-rerank', language: ['en'], model_uri: '/path/to/rerank-model', + model_id: null, + model_hub: null, }) const errorModelName = formData.model_name.trim().length <= 0 + const errorModelId = modelSource === 'hub' && modelId.search('\\w+/\\w+') === -1 const errorLanguage = formData.language === undefined || formData.language.length === 0 const handleClick = async () => { - const errorAny = errorModelName || errorLanguage + const errorAny = errorModelName || errorLanguage || errorModelId if (errorAny) { setErrorMsg('Please fill in valid value for all fields') @@ -37,6 +59,21 @@ const RegisterRerankModel = () => { } try { + let myFormData + if (modelSource === 'hub') { + myFormData = { + ...formData, + model_id: modelId, + model_hub: hub, + model_uri: null, + } + } else { + myFormData = { + ...formData, + model_id: null, + model_hub: null, + } + } const response = await fetcher( 
endPoint + '/v1/model_registrations/rerank', { @@ -45,21 +82,21 @@ const RegisterRerankModel = () => { 'Content-Type': 'application/json', }, body: JSON.stringify({ - model: JSON.stringify(formData), + model: JSON.stringify(myFormData), persist: true, }), - } + }, ) if (!response.ok) { const errorData = await response.json() // Assuming the server returns error details in JSON format setErrorMsg( `Server error: ${response.status} - ${ errorData.detail || 'Unknown error' - }` + }`, ) } else { setSuccessMsg( - 'Model has been registered successfully! Navigate to launch model page to proceed.' + 'Model has been registered successfully! Navigate to launch model page to proceed.', ) } } catch (error) { @@ -82,6 +119,33 @@ const RegisterRerankModel = () => { } } + const handleImportModel = async () => { + if (errorModelId) { + setErrorMsg('Please fill in valid value for Model Id') + return + } + + const split = modelId.split('/') + if (split.length !== 2) { + setErrorMsg('Please fill in valid value for Model Id') + return + } + + const repo_name = split[1] + const repo_split = repo_name.split(/[-_]/) + let lang = 'en' + for (const seg of repo_split) { + if (['zh', 'cn', 'chinese'].includes(seg.toLowerCase())) { + lang = 'zh' + break + } + } + setFormData({ + ...formData, + language: [lang], + }) + } + return ( <React.Fragment> <Box padding="20px"></Box> @@ -99,18 +163,90 @@ const RegisterRerankModel = () => { /> <Box padding="15px"></Box> - <TextField - label="Model Path" - size="small" - value={formData.model_uri} + + <label + style={{ + paddingLeft: 5, + }} + > + Model Source + </label> + + <RadioGroup + value={modelSource} onChange={(e) => { - setFormData({ - ...formData, - model_uri: e.target.value, - }) + setModelSource(e.target.value) }} - helperText="Provide the model directory path." - /> + > + <Box sx={styles.checkboxWrapper}> + {SOURCES.map((item) => ( + <Box sx={{ marginLeft: '10px' }}> + <FormControlLabel + value={item} + control={<Radio />} + label={SOURCES_DICT[item]} + /> + </Box> + ))} + </Box> + </RadioGroup> + <Box padding="15px"></Box> + + + {modelSource === 'self_hosted' && + <TextField + label="Model Path" + size="small" + value={formData.model_uri} + onChange={(e) => { + setFormData({ + ...formData, + model_uri: e.target.value, + }) + }} + helperText="Provide the model directory path." + />} + {modelSource === 'hub' && + <Box sx={styles.checkboxWrapper}> + + <TextField + sx={{ width: '400px' }} + label="Model Id" + size="small" + error={errorModelId} + value={modelId} + onChange={(e) => { + setModelId(e.target.value) + }} + placeholder="user/repo" + /> + + <FormControl variant="standard" + sx={{ marginLeft: '10px' }}> + <InputLabel id="hub-label">Hub</InputLabel> + <Select + labelId="hub-label" + value={hub} + label="Hub" + onChange={(e) => { + setHub(e.target.value) + }} + > + {SUPPORTED_HUBS.map((item) => ( + <MenuItem value={item}>{SUPPORTED_HUBS_DICT[item]}</MenuItem> + ))} + </Select> + </FormControl> + <Button + sx={{ marginLeft: '10px' }} + variant="contained" + color="primary" + onClick={handleImportModel} + > + Import Model + </Button> + </Box> + } <Box padding="15px"></Box> <label @@ -133,11 +269,11 @@ const RegisterRerankModel = () => { sx={ errorLanguage ? 
{ - 'color': ERROR_COLOR, - '&.Mui-checked': { - color: ERROR_COLOR, - }, - } + 'color': ERROR_COLOR, + '&.Mui-checked': { + color: ERROR_COLOR, + }, + } : {} } /> From dc8eb985621c6126f5b6162f035efb69e3437130 Mon Sep 17 00:00:00 2001 From: Shi Hui <shihui@hyron.com> Date: Wed, 3 Apr 2024 14:29:13 +0800 Subject: [PATCH 21/21] fix the bug that Future with generic type will cause error when it is run on python < 3.9 --- xinference/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xinference/utils.py b/xinference/utils.py index dcc2f66b3c..bc36f01a5d 100644 --- a/xinference/utils.py +++ b/xinference/utils.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import asyncio from concurrent.futures import Future from concurrent.futures.thread import ThreadPoolExecutor
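Note on the last patch: on Python 3.8 a generic return annotation such as Future[bool] is evaluated at function definition time and raises "TypeError: 'type' object is not subscriptable", whereas from __future__ import annotations (PEP 563) leaves annotations unevaluated. A minimal illustrative module, not code from the patch:

    from __future__ import annotations  # postpone evaluation of all annotations

    from concurrent.futures import Future

    def submit_probe() -> Future[bool]:
        # Without the __future__ import, Python 3.8 would fail at this def,
        # because concurrent.futures.Future only became subscriptable in 3.9.
        fut = Future()
        fut.set_result(True)
        return fut

    print(submit_probe().result())  # True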