Skip to content

Commit

Permalink
Add fill mode support and generally improve inpainting
Browse files Browse the repository at this point in the history
  • Loading branch information
hafriedlander committed Jul 15, 2023
1 parent 179edff commit cd3e6ab
Show file tree
Hide file tree
Showing 13 changed files with 705 additions and 413 deletions.
2 changes: 1 addition & 1 deletion api-interfaces
1 change: 0 additions & 1 deletion gyre/config/models/samhq.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
- id: "samhq"
task: "mask-predict"
enabled: True
default: True
class: "SamHQPipeline"
model: "@empty"
overrides:
Expand Down
292 changes: 149 additions & 143 deletions gyre/generated/generation_pb2.py

Large diffs are not rendered by default.

80 changes: 74 additions & 6 deletions gyre/generated/generation_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,29 @@ POINT_BACKGROUND: LOIPointLabel.ValueType # 0
POINT_FOREGROUND: LOIPointLabel.ValueType # 1
global___LOIPointLabel = LOIPointLabel

class _InpaintFillMode:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType

class _InpaintFillModeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_InpaintFillMode.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
INPAINT_FILL_AUTO: _InpaintFillMode.ValueType # 0
INPAINT_FILL_NONE: _InpaintFillMode.ValueType # 1
INPAINT_FILL_SHUFFLE: _InpaintFillMode.ValueType # 2
INPAINT_FILL_REPEAT: _InpaintFillMode.ValueType # 3
INPAINT_FILL_AI: _InpaintFillMode.ValueType # 4
INPAINT_FILL_NOISE: _InpaintFillMode.ValueType # 5

class InpaintFillMode(_InpaintFillMode, metaclass=_InpaintFillModeEnumTypeWrapper): ...

INPAINT_FILL_AUTO: InpaintFillMode.ValueType # 0
INPAINT_FILL_NONE: InpaintFillMode.ValueType # 1
INPAINT_FILL_SHUFFLE: InpaintFillMode.ValueType # 2
INPAINT_FILL_REPEAT: InpaintFillMode.ValueType # 3
INPAINT_FILL_AI: InpaintFillMode.ValueType # 4
INPAINT_FILL_NOISE: InpaintFillMode.ValueType # 5
global___InpaintFillMode = InpaintFillMode

class _DiffusionSampler:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
Expand Down Expand Up @@ -1032,6 +1055,23 @@ class ImageAdjustment_MaskReuse(google.protobuf.message.Message):

global___ImageAdjustment_MaskReuse = ImageAdjustment_MaskReuse

@typing_extensions.final
class ImageAdjustment_MaskSoftDilate(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

SIGMA_FIELD_NUMBER: builtins.int
sigma: builtins.int
def __init__(
self,
*,
sigma: builtins.int | None = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["_sigma", b"_sigma", "sigma", b"sigma"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["_sigma", b"_sigma", "sigma", b"sigma"]) -> None: ...
def WhichOneof(self, oneof_group: typing_extensions.Literal["_sigma", b"_sigma"]) -> typing_extensions.Literal["sigma"] | None: ...

global___ImageAdjustment_MaskSoftDilate = ImageAdjustment_MaskSoftDilate

@typing_extensions.final
class ImageAdjustment(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
Expand All @@ -1056,6 +1096,7 @@ class ImageAdjustment(google.protobuf.message.Message):
SHUFFLE_FIELD_NUMBER: builtins.int
MASK_PREDICT_FIELD_NUMBER: builtins.int
MASK_REUSE_FIELD_NUMBER: builtins.int
MASK_SOFT_DILATE_FIELD_NUMBER: builtins.int
ENGINE_ID_FIELD_NUMBER: builtins.int
@property
def blur(self) -> global___ImageAdjustment_Gaussian: ...
Expand Down Expand Up @@ -1097,6 +1138,8 @@ class ImageAdjustment(google.protobuf.message.Message):
def mask_predict(self) -> global___ImageAdjustment_MaskPredict: ...
@property
def mask_reuse(self) -> global___ImageAdjustment_MaskReuse: ...
@property
def mask_soft_dilate(self) -> global___ImageAdjustment_MaskSoftDilate: ...
engine_id: builtins.str
def __init__(
self,
Expand All @@ -1121,14 +1164,15 @@ class ImageAdjustment(google.protobuf.message.Message):
shuffle: global___ImageAdjustment_Shuffle | None = ...,
mask_predict: global___ImageAdjustment_MaskPredict | None = ...,
mask_reuse: global___ImageAdjustment_MaskReuse | None = ...,
mask_soft_dilate: global___ImageAdjustment_MaskSoftDilate | None = ...,
engine_id: builtins.str | None = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["_engine_id", b"_engine_id", "adjustment", b"adjustment", "autoscale", b"autoscale", "background_removal", b"background_removal", "blur", b"blur", "canny_edge", b"canny_edge", "channels", b"channels", "crop", b"crop", "depth", b"depth", "edge_detection", b"edge_detection", "engine_id", b"engine_id", "invert", b"invert", "keypose", b"keypose", "levels", b"levels", "mask_predict", b"mask_predict", "mask_reuse", b"mask_reuse", "normal", b"normal", "openpose", b"openpose", "palletize", b"palletize", "quantize", b"quantize", "rescale", b"rescale", "segmentation", b"segmentation", "shuffle", b"shuffle"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["_engine_id", b"_engine_id", "adjustment", b"adjustment", "autoscale", b"autoscale", "background_removal", b"background_removal", "blur", b"blur", "canny_edge", b"canny_edge", "channels", b"channels", "crop", b"crop", "depth", b"depth", "edge_detection", b"edge_detection", "engine_id", b"engine_id", "invert", b"invert", "keypose", b"keypose", "levels", b"levels", "mask_predict", b"mask_predict", "mask_reuse", b"mask_reuse", "normal", b"normal", "openpose", b"openpose", "palletize", b"palletize", "quantize", b"quantize", "rescale", b"rescale", "segmentation", b"segmentation", "shuffle", b"shuffle"]) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["_engine_id", b"_engine_id", "adjustment", b"adjustment", "autoscale", b"autoscale", "background_removal", b"background_removal", "blur", b"blur", "canny_edge", b"canny_edge", "channels", b"channels", "crop", b"crop", "depth", b"depth", "edge_detection", b"edge_detection", "engine_id", b"engine_id", "invert", b"invert", "keypose", b"keypose", "levels", b"levels", "mask_predict", b"mask_predict", "mask_reuse", b"mask_reuse", "mask_soft_dilate", b"mask_soft_dilate", "normal", b"normal", "openpose", b"openpose", "palletize", b"palletize", "quantize", b"quantize", "rescale", b"rescale", "segmentation", b"segmentation", "shuffle", b"shuffle"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["_engine_id", b"_engine_id", "adjustment", b"adjustment", "autoscale", b"autoscale", "background_removal", b"background_removal", "blur", b"blur", "canny_edge", b"canny_edge", "channels", b"channels", "crop", b"crop", "depth", b"depth", "edge_detection", b"edge_detection", "engine_id", b"engine_id", "invert", b"invert", "keypose", b"keypose", "levels", b"levels", "mask_predict", b"mask_predict", "mask_reuse", b"mask_reuse", "mask_soft_dilate", b"mask_soft_dilate", "normal", b"normal", "openpose", b"openpose", "palletize", b"palletize", "quantize", b"quantize", "rescale", b"rescale", "segmentation", b"segmentation", "shuffle", b"shuffle"]) -> None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_engine_id", b"_engine_id"]) -> typing_extensions.Literal["engine_id"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["adjustment", b"adjustment"]) -> typing_extensions.Literal["blur", "invert", "levels", "channels", "rescale", "crop", "depth", "canny_edge", "edge_detection", "segmentation", "keypose", "openpose", "normal", "background_removal", "autoscale", "palletize", "quantize", "shuffle", "mask_predict", "mask_reuse"] | None: ...
def WhichOneof(self, oneof_group: typing_extensions.Literal["adjustment", b"adjustment"]) -> typing_extensions.Literal["blur", "invert", "levels", "channels", "rescale", "crop", "depth", "canny_edge", "edge_detection", "segmentation", "keypose", "openpose", "normal", "background_removal", "autoscale", "palletize", "quantize", "shuffle", "mask_predict", "mask_reuse", "mask_soft_dilate"] | None: ...

global___ImageAdjustment = ImageAdjustment

Expand Down Expand Up @@ -1536,6 +1580,23 @@ class LocationsOfInterest(google.protobuf.message.Message):

global___LocationsOfInterest = LocationsOfInterest

@typing_extensions.final
class InpaintParameters(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

FILL_MODE_FIELD_NUMBER: builtins.int
fill_mode: global___InpaintFillMode.ValueType
def __init__(
self,
*,
fill_mode: global___InpaintFillMode.ValueType | None = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["_fill_mode", b"_fill_mode", "fill_mode", b"fill_mode"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["_fill_mode", b"_fill_mode", "fill_mode", b"fill_mode"]) -> None: ...
def WhichOneof(self, oneof_group: typing_extensions.Literal["_fill_mode", b"_fill_mode"]) -> typing_extensions.Literal["fill_mode"] | None: ...

global___InpaintParameters = InpaintParameters

@typing_extensions.final
class PromptParameters(google.protobuf.message.Message):
"""A set of parameters for each individual Prompt."""
Expand All @@ -1548,6 +1609,7 @@ class PromptParameters(google.protobuf.message.Message):
TOKEN_OVERRIDES_FIELD_NUMBER: builtins.int
CLIP_LAYER_FIELD_NUMBER: builtins.int
HINT_PRIORITY_FIELD_NUMBER: builtins.int
INPAINT_PARAMETERS_FIELD_NUMBER: builtins.int
init: builtins.bool
weight: builtins.float
@property
Expand All @@ -1561,7 +1623,10 @@ class PromptParameters(google.protobuf.message.Message):
0 _or_ 1 == final, 2 = penultimate, 3 = next
"""
hint_priority: global___HintPriority.ValueType
"""Soecify the application mode for hints"""
"""Specify the application mode for hints"""
@property
def inpaint_parameters(self) -> global___InpaintParameters:
"""Specify the inpaint controls for inpainting"""
def __init__(
self,
*,
Expand All @@ -1571,16 +1636,19 @@ class PromptParameters(google.protobuf.message.Message):
token_overrides: collections.abc.Iterable[global___TokenOverride] | None = ...,
clip_layer: builtins.int | None = ...,
hint_priority: global___HintPriority.ValueType | None = ...,
inpaint_parameters: global___InpaintParameters | None = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["_clip_layer", b"_clip_layer", "_hint_priority", b"_hint_priority", "_init", b"_init", "_weight", b"_weight", "clip_layer", b"clip_layer", "hint_priority", b"hint_priority", "init", b"init", "weight", b"weight"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["_clip_layer", b"_clip_layer", "_hint_priority", b"_hint_priority", "_init", b"_init", "_weight", b"_weight", "clip_layer", b"clip_layer", "hint_priority", b"hint_priority", "init", b"init", "named_weights", b"named_weights", "token_overrides", b"token_overrides", "weight", b"weight"]) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["_clip_layer", b"_clip_layer", "_hint_priority", b"_hint_priority", "_init", b"_init", "_inpaint_parameters", b"_inpaint_parameters", "_weight", b"_weight", "clip_layer", b"clip_layer", "hint_priority", b"hint_priority", "init", b"init", "inpaint_parameters", b"inpaint_parameters", "weight", b"weight"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["_clip_layer", b"_clip_layer", "_hint_priority", b"_hint_priority", "_init", b"_init", "_inpaint_parameters", b"_inpaint_parameters", "_weight", b"_weight", "clip_layer", b"clip_layer", "hint_priority", b"hint_priority", "init", b"init", "inpaint_parameters", b"inpaint_parameters", "named_weights", b"named_weights", "token_overrides", b"token_overrides", "weight", b"weight"]) -> None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_clip_layer", b"_clip_layer"]) -> typing_extensions.Literal["clip_layer"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_hint_priority", b"_hint_priority"]) -> typing_extensions.Literal["hint_priority"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_init", b"_init"]) -> typing_extensions.Literal["init"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_inpaint_parameters", b"_inpaint_parameters"]) -> typing_extensions.Literal["inpaint_parameters"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_weight", b"_weight"]) -> typing_extensions.Literal["weight"] | None: ...

global___PromptParameters = PromptParameters
Expand Down
35 changes: 35 additions & 0 deletions gyre/generated/stablecabal.openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -1807,6 +1807,9 @@
"maskReuse": {
"$ref": "#/components/schemas/ImageAdjustment_MaskReuse"
},
"maskSoftDilate": {
"$ref": "#/components/schemas/ImageAdjustment_MaskSoftDilate"
},
"engineId": {
"type": "string"
}
Expand Down Expand Up @@ -1983,6 +1986,15 @@
},
"title": "Reuse the most recently predicted mask"
},
"ImageAdjustment_MaskSoftDilate": {
"type": "object",
"properties": {
"sigma": {
"type": "string",
"format": "uint64"
}
}
},
"ImageAdjustment_Normal": {
"type": "object",
"properties": {
Expand Down Expand Up @@ -2128,6 +2140,26 @@
}
}
},
"InpaintFillMode": {
"type": "string",
"enum": [
"INPAINT_FILL_AUTO",
"INPAINT_FILL_NONE",
"INPAINT_FILL_SHUFFLE",
"INPAINT_FILL_REPEAT",
"INPAINT_FILL_AI",
"INPAINT_FILL_NOISE"
],
"default": "INPAINT_FILL_AUTO"
},
"InpaintParameters": {
"type": "object",
"properties": {
"fillMode": {
"$ref": "#/components/schemas/InpaintFillMode"
}
}
},
"LOIPoint": {
"type": "object",
"properties": {
Expand Down Expand Up @@ -2371,6 +2403,9 @@
},
"hintPriority": {
"$ref": "#/components/schemas/HintPriority"
},
"inpaintParameters": {
"$ref": "#/components/schemas/InpaintParameters"
}
},
"description": "A set of parameters for each individual Prompt."
Expand Down
38 changes: 37 additions & 1 deletion gyre/generated/stablecabal.swagger.json
Original file line number Diff line number Diff line change
Expand Up @@ -1655,6 +1655,9 @@
"maskReuse": {
"$ref": "#/definitions/ImageAdjustment_MaskReuse"
},
"maskSoftDilate": {
"$ref": "#/definitions/ImageAdjustment_MaskSoftDilate"
},
"engineId": {
"type": "string"
}
Expand Down Expand Up @@ -1831,6 +1834,15 @@
},
"title": "Reuse the most recently predicted mask"
},
"ImageAdjustment_MaskSoftDilate": {
"type": "object",
"properties": {
"sigma": {
"type": "string",
"format": "uint64"
}
}
},
"ImageAdjustment_Normal": {
"type": "object",
"properties": {
Expand Down Expand Up @@ -1978,6 +1990,26 @@
}
}
},
"InpaintFillMode": {
"type": "string",
"enum": [
"INPAINT_FILL_AUTO",
"INPAINT_FILL_NONE",
"INPAINT_FILL_SHUFFLE",
"INPAINT_FILL_REPEAT",
"INPAINT_FILL_AI",
"INPAINT_FILL_NOISE"
],
"default": "INPAINT_FILL_AUTO"
},
"InpaintParameters": {
"type": "object",
"properties": {
"fillMode": {
"$ref": "#/definitions/InpaintFillMode"
}
}
},
"LOIPoint": {
"type": "object",
"properties": {
Expand Down Expand Up @@ -2221,7 +2253,11 @@
},
"hintPriority": {
"$ref": "#/definitions/HintPriority",
"title": "Soecify the application mode for hints"
"title": "Specify the application mode for hints"
},
"inpaintParameters": {
"$ref": "#/definitions/InpaintParameters",
"title": "Specify the inpaint controls for inpainting"
}
},
"description": "A set of parameters for each individual Prompt."
Expand Down
58 changes: 58 additions & 0 deletions gyre/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2131,6 +2131,8 @@ def _return_pipeline_to_pool(self, slot):
pipeline = slot.pipeline

# Deactivate and remove it from the slot
slot.pipeline.subslot.deactivate()
slot.pipeline.subslot = None
slot.pipeline.deactivate()
slot.pipeline = None

Expand All @@ -2154,6 +2156,7 @@ def _get_pipeline_from_pool(self, slot, id):

# Assign the pipeline to the slot and activate
slot.pipeline = pipeline
slot.pipeline.subslot = SubSlot(self, slot)
slot.pipeline.activate(slot.device)

return pipeline
Expand Down Expand Up @@ -2204,6 +2207,7 @@ def with_engine(self, id=None, task=None):
if not slot.pipeline:
existing = False
slot.pipeline = self._build_pipeline_for_engine(spec)
slot.pipeline.subslot = SubSlot(self, slot)
slot.pipeline.activate(slot.device)

if self._ram_monitor:
Expand All @@ -2219,3 +2223,57 @@ def with_engine(self, id=None, task=None):
self._device_queue.put(slot)

# All done


class SubSlot:
def __init__(self, manager, slot):
self.manager = manager
self.superslot = slot
self.device = slot.device
self.pipeline = None

def deactivate(self):
if self.pipeline:
self.manager._return_pipeline_to_pool(self)

@contextmanager
def __call__(self, id=None, task=None):
# TODO: This is all duplicated from with_engine

if id is None:
id = self.manager._defaults[task if task else "generate"]

if id is None:
raise EngineNotFoundError("No engine ID provided and no default is set.")

# Get the engine spec
spec = self.manager._find_spec(id=id)
if not spec or not spec.enabled:
raise EngineNotFoundError(f"Engine ID {id} doesn't exist or isn't enabled.")

if task is not None and task != spec.task:
raise ValueError(f"Engine ID {id} is for task '{spec.task}' not '{task}'")

try:
# Get pipeline (create if all pipelines for the id are busy)

# If a pipeline is already active on this device slot, check if it's the right
# one. If not, deactivate it and clear it
if self.pipeline and self.pipeline.id != id:
self.manager._return_pipeline_to_pool(self)

# If there's no pipeline on this device slot yet, find it (creating it
# if all the existing pipelines are busy)
if not self.pipeline:
self.manager._get_pipeline_from_pool(self, id)

if not self.pipeline:
self.pipeline = self.manager._build_pipeline_for_engine(spec)
self.pipeline.with_subengine = SubSlot(self.manager, self)
self.pipeline.subslot = SubSlot(self.manager, self)
self.pipeline.activate(self.device)

# Do the work
yield self.pipeline
finally:
pass
Loading

0 comments on commit cd3e6ab

Please sign in to comment.