Skip to content

Commit

Permalink
Add experimental GLIGEN node
Browse files Browse the repository at this point in the history
  • Loading branch information
kijai committed Apr 29, 2024
1 parent c48cd8b commit 6669c93
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 5 deletions.
1 change: 1 addition & 0 deletions __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
"LoadResAdapterNormalization": LoadResAdapterNormalization,
"Superprompt": Superprompt,
"GLIGENTextBoxApplyBatch": GLIGENTextBoxApplyBatch,
"GLIGENTextBoxApplyBatchCoords": GLIGENTextBoxApplyBatchCoords,
"Intrinsic_lora_sampling": Intrinsic_lora_sampling,

}
Expand Down
118 changes: 115 additions & 3 deletions nodes/curve_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def INPUT_TYPES(cls):
}
}

RETURN_TYPES = ("MASK", "STRING", "FLOAT")
RETURN_TYPES = ("MASK", "STRING", "FLOAT", "INT")
RETURN_NAMES = ("mask", "string", "float", "count")
FUNCTION = "splinedata"
CATEGORY = "KJNodes/experimental"
DESCRIPTION = """
Expand Down Expand Up @@ -126,7 +127,7 @@ def splinedata(self, mask_width, mask_height, coordinates, float_output_type, in
masks_out = torch.stack(mask_tensors)
masks_out = masks_out.repeat(repeat_output, 1, 1, 1)
masks_out = masks_out.mean(dim=-1)
return (masks_out, str(coordinates), out_floats,)
return (masks_out, str(coordinates), out_floats, len(out_floats))

class CreateShapeMaskOnPath:

Expand Down Expand Up @@ -485,4 +486,115 @@ def INPUT_TYPES(s):
"""
def customsigmas(self, float_list):
return torch.tensor(float_list, dtype=torch.float32),
return torch.tensor(float_list, dtype=torch.float32),

class GLIGENTextBoxApplyBatchCoords:
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning_to": ("CONDITIONING", ),
"latents": ("LATENT", ),
"clip": ("CLIP", ),
"gligen_textbox_model": ("GLIGEN", ),
"coordinates": ("STRING", {"forceInput": True}),
"text": ("STRING", {"multiline": True}),
"width": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 8}),
"height": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 8}),
},
}
RETURN_TYPES = ("CONDITIONING", "IMAGE", )
FUNCTION = "append"
CATEGORY = "KJNodes/experimental"
DESCRIPTION = """
Experimental, does not function yet as ComfyUI base changes are needed
"""

def append(self, latents, coordinates, conditioning_to, clip, gligen_textbox_model, text, width, height):
coordinates = json.loads(coordinates.replace("'", '"'))
coordinates = [(coord['x'], coord['y']) for coord in coordinates]

batch_size = sum(tensor.size(0) for tensor in latents.values())
assert len(coordinates) == batch_size, "The number of coordinates does not match the number of latents"
c = []
cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled=True)

image_height = latents['samples'].shape[-1] * 8
image_width = latents['samples'].shape[-2] * 8
plot_image_tensor = self.plot_coordinates_to_tensor(coordinates, image_height, image_width, height)

for t in conditioning_to:
n = [t[0], t[1].copy()]

position_params_batch = [[] for _ in range(batch_size)] # Initialize a list of empty lists for each batch item

for i in range(batch_size):
x_position, y_position = coordinates[i]
position_param = (cond_pooled, height // 8, width // 8, y_position // 8, x_position // 8)
position_params_batch[i].append(position_param) # Append position_param to the correct sublist

prev = []
if "gligen" in n[1]:
prev = n[1]['gligen'][2]
else:
prev = [[] for _ in range(batch_size)]
# Concatenate prev and position_params_batch, ensuring both are lists of lists
# and each sublist corresponds to a batch item
combined_position_params = [prev_item + batch_item for prev_item, batch_item in zip(prev, position_params_batch)]
n[1]['gligen'] = ("position", gligen_textbox_model, combined_position_params)
c.append(n)

return (c, plot_image_tensor,)

def plot_coordinates_to_tensor(self, coordinates, height, width, box_size):
import matplotlib
matplotlib.use('Agg')
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas

# Convert coordinates to separate x and y lists
#x_coords, y_coords = zip(*coordinates)

fig, ax = matplotlib.pyplot.subplots(figsize=(width/100, height/100), dpi=100)
#ax.scatter(x_coords, y_coords, color='yellow', label='_nolegend_')

# Draw a box at each coordinate
for x, y in coordinates:
rect = matplotlib.patches.Rectangle((x - box_size/2, y - box_size/2), box_size, box_size,
linewidth=1, edgecolor='green', facecolor='none', alpha=0.5)
ax.add_patch(rect)

# Draw arrows from one point to another to indicate direction
for i in range(len(coordinates) - 1):
x1, y1 = coordinates[i]
x2, y2 = coordinates[i + 1]
ax.annotate("", xy=(x2, y2), xytext=(x1, y1),
arrowprops=dict(arrowstyle="->",
linestyle="-",
lw=1,
color='orange',
mutation_scale=10))
matplotlib.pyplot.rcParams['text.color'] = '#999999'
fig.patch.set_facecolor('#353535')
ax.set_facecolor('#353535')
ax.grid(color='#999999', linestyle='-', linewidth=0.5)
ax.set_xlabel('x', color='#999999')
ax.set_ylabel('y', color='#999999')
for text in ax.get_xticklabels() + ax.get_yticklabels():
text.set_color('#999999')
ax.set_title('Gligen positions')
ax.set_xlabel('X Coordinate')
ax.set_ylabel('Y Coordinate')
ax.legend().remove()
ax.set_xlim(0, width) # Set the x-axis to match the input latent width
ax.set_ylim(height, 0) # Set the y-axis to match the input latent height, with (0,0) at top-left
# Adjust the margins of the subplot
matplotlib.pyplot.subplots_adjust(left=0.08, right=0.95, bottom=0.05, top=0.95, wspace=0.2, hspace=0.2)
canvas = FigureCanvas(fig)
canvas.draw()
matplotlib.pyplot.close(fig)

width, height = fig.get_size_inches() * fig.get_dpi()

image_np = np.frombuffer(canvas.tostring_rgb(), dtype='uint8').reshape(int(height), int(width), 3)
image_tensor = torch.from_numpy(image_np).float() / 255.0
image_tensor = image_tensor.unsqueeze(0)

return image_tensor
4 changes: 2 additions & 2 deletions web/js/spline_editor.js
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ app.registerExtension({
}
});

this.setSize([550, 900]);
this.setSize([550, 920]);
this.resizable = false;
this.splineEditor.parentEl = document.createElement("div");
this.splineEditor.parentEl.className = "spline-editor";
Expand All @@ -190,7 +190,7 @@ app.registerExtension({

chainCallback(this, "onGraphConfigured", function() {
createSplineEditor(this);
this.setSize([550, 900]);
this.setSize([550, 920]);
});

}); // onAfterGraphConfigured
Expand Down

0 comments on commit 6669c93

Please sign in to comment.