Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 154 additions & 40 deletions examples/pipelines/providers/azure_openai_manifold_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def __init__(self):
}
)
self.set_pipelines()
pass

def set_pipelines(self):
    """Rebuild ``self.pipelines`` from the semicolon-separated valve strings.

    ``AZURE_OPENAI_MODELS`` holds deployment ids and
    ``AZURE_OPENAI_MODEL_NAMES`` their display names, in matching order.
    """
    models = self.valves.AZURE_OPENAI_MODELS.split(";")
    # NOTE(review): this line is hidden behind a collapsed diff region;
    # inferred from the use of `model_names` below — confirm the valve name.
    model_names = self.valves.AZURE_OPENAI_MODEL_NAMES.split(";")
    # Pair each deployment id with its display name; zip() silently
    # ignores unmatched trailing entries on either side.
    self.pipelines = [
        {"id": model, "name": name} for model, name in zip(models, model_names)
    ]
    print(f"azure_openai_manifold_pipeline - models: {self.pipelines}")

async def on_valves_updated(self):
    """Re-derive the pipeline list whenever the valve settings change.

    The diff showed the call duplicated (removed + re-added line); a
    single invocation is sufficient.
    """
    self.set_pipelines()

async def on_startup(self):
    """Lifecycle hook invoked once when the server boots."""
    print(f"on_startup:{__name__}")

async def on_shutdown(self):
    """Lifecycle hook invoked once when the server is stopped."""
    print(f"on_shutdown:{__name__}")

def pipe(
    self,
    user_message: str,
    model_id: str,
    messages: List[dict],
    body: dict,
) -> Union[str, Generator[str, None, None], Iterator[str]]:
    """Forward a chat-completion request to an Azure OpenAI deployment.

    o1-family deployments do not support server-side streaming, so for
    those a blocking request is made and streaming is simulated by
    chunking the response text; all other deployments are proxied as-is.

    Returns the parsed JSON response (non-streaming), an iterator over
    raw response lines (streaming / fake streaming), or an error string.
    """
    print(f"pipe:{__name__}")
    print(messages)
    print(user_message)

    headers = {
        # NOTE(review): the "api-key" line is hidden in a collapsed diff
        # region; this is the standard Azure OpenAI auth header — confirm.
        "api-key": self.valves.AZURE_OPENAI_API_KEY,
        "Content-Type": "application/json",
    }

    # Chat Completions endpoint for this deployment.
    url = (
        f"{self.valves.AZURE_OPENAI_ENDPOINT}/openai/deployments/"
        f"{model_id}/chat/completions?api-version={self.valves.AZURE_OPENAI_API_VERSION}"
    )

    # (1) Parameters accepted by regular (non-o1) deployments.
    # BUG FIX: the original set contained the typo "funcions", which
    # silently dropped the legacy "functions" parameter from requests.
    allowed_params_default = {
        "messages", "temperature", "role", "content", "contentPart",
        "contentPartImage", "enhancements", "dataSources", "n", "stream",
        "stop", "max_tokens", "presence_penalty", "frequency_penalty",
        "logit_bias", "user", "function_call", "functions", "tools",
        "tool_choice", "top_p", "log_probs", "top_logprobs",
        "response_format", "seed",
    }

    # (2) Parameters accepted by o1-family deployments (no streaming,
    # max_completion_tokens instead of max_tokens).
    allowed_params_o1 = {
        "model", "messages", "top_p", "n", "max_completion_tokens",
        "presence_penalty", "frequency_penalty", "logit_bias", "user",
    }

    def is_o1_model(m: str) -> bool:
        # Heuristic — adjust to your deployment naming pattern.
        # NOTE(review): the startswith("o") fallback also matches
        # non-o1 names such as "omni-..."; confirm against your models.
        return "o1" in m or m.startswith("o")

    # The API expects "user" to be a string; callers sometimes pass a
    # dict such as {"id": ...}. Robustness: .get() only exists on dicts,
    # so fall back to str() for any other type.
    if "user" in body and not isinstance(body["user"], str):
        user = body["user"]
        body["user"] = user.get("id", str(user)) if isinstance(user, dict) else str(user)

    if is_o1_model(model_id):
        # o1 models cannot stream: drop "stream", make a blocking call,
        # then fake streaming by yielding the text in small chunks.
        body.pop("stream", None)
        filtered_body = {k: v for k, v in body.items() if k in allowed_params_o1}

        # Log which fields were dropped, as a single line.
        if len(body) != len(filtered_body):
            dropped_keys = set(body.keys()) - set(filtered_body.keys())
            print(f"Dropped params: {', '.join(dropped_keys)}")

        r = None  # initialized so the except-branch check is well-defined
        try:
            r = requests.post(
                url=url,
                json=filtered_body,
                headers=headers,
                stream=False,
            )
            r.raise_for_status()
            data = r.json()

            # Usual shape: data["choices"][0]["message"]["content"];
            # fall back to the raw payload text if the shape differs.
            try:
                content = data["choices"][0]["message"]["content"]
            except (KeyError, IndexError, TypeError):
                content = str(data)

            def fake_stream() -> Generator[str, None, None]:
                # Yield the full text in fixed-size chunks to simulate
                # server-side streaming for the caller.
                chunk_size = 30
                for i in range(0, len(content), chunk_size):
                    yield content[i : i + chunk_size]

            return fake_stream()
        except Exception as e:
            if r is not None:
                return f"Error: {e} ({r.text})"
            return f"Error: {e}"

    # Regular deployments: filter params and honour the stream flag.
    filtered_body = {k: v for k, v in body.items() if k in allowed_params_default}
    if len(body) != len(filtered_body):
        dropped_keys = set(body.keys()) - set(filtered_body.keys())
        print(f"Dropped params: {', '.join(dropped_keys)}")

    r = None  # initialized so the except-branch check is well-defined
    try:
        r = requests.post(
            url=url,
            json=filtered_body,
            headers=headers,
            stream=True,
        )
        r.raise_for_status()

        if filtered_body.get("stream"):
            # Real streaming: hand back the raw line iterator.
            return r.iter_lines()
        # Non-streaming: return the parsed JSON body.
        return r.json()
    except Exception as e:
        if r is not None:
            return f"Error: {e} ({r.text})"
        return f"Error: {e}"