jbilcke-hf HF staff commited on
Commit
4ae2e41
1 Parent(s): b0f34ee

look's like we're good

Browse files
Files changed (1) hide show
  1. src/app/interface/generate/index.tsx +209 -55
src/app/interface/generate/index.tsx CHANGED
@@ -8,11 +8,13 @@ import { cn } from "@/lib/utils"
8
  import { headingFont } from "@/app/interface/fonts"
9
  import { useCharacterLimit } from "@/lib/useCharacterLimit"
10
  import { generateAnimation } from "@/app/server/actions/animation"
11
- import { postToCommunity } from "@/app/server/actions/community"
12
  import { useCountdown } from "@/lib/useCountdown"
13
  import { Countdown } from "../countdown"
14
  import { getSDXLModels } from "@/app/server/actions/models"
15
- import { SDXLModel } from "@/types"
 
 
16
 
17
  export function Generate() {
18
  const router = useRouter()
@@ -35,11 +37,13 @@ export function Generate() {
35
  const [showModels, setShowModels] = useState(true)
36
  // useEffect(() => { runsRef.current = runs }, [runs])
37
 
 
 
38
  console.log("runs:", runs)
39
  const { progressPercent, remainingTimeInSec } = useCountdown({
40
  isActive: isLocked,
41
  timerId: runs, // everytime we change this, the timer will reset
42
- durationInSec: 45,
43
  onEnd: () => {}
44
  })
45
 
@@ -78,7 +82,7 @@ export function Generate() {
78
  })
79
 
80
  startTransition(async () => {
81
- const huggingFaceLora = selectedModel ? selectedModel.repo : "KappaNeuro/studio-ghibli-style"
82
  const triggerWord = selectedModel ? selectedModel.trigger_word : "Studio Ghibli Style"
83
 
84
  // now you got a read/write object
@@ -112,9 +116,9 @@ export function Generate() {
112
 
113
  // now you got a read/write object
114
  const current = new URLSearchParams(Array.from(searchParams.entries()))
115
- current.set("postId", post.postId)
116
- current.set("prompt", post.prompt)
117
- current.set("model", post.model)
118
  const search = current.toString()
119
  router.push(`${pathname}${search ? `?${search}` : ""}`)
120
  } catch (err) {
@@ -133,46 +137,135 @@ export function Generate() {
133
  const models = await getSDXLModels()
134
  setModels(models)
135
 
136
- let defaultModel = models.find(model => model.title.toLowerCase().includes("ghibli")) || models[0]
 
137
  if (defaultModel) {
138
  setSelectedModel(defaultModel)
139
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  })
141
  }, [])
142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  return (
144
- <div className={cn(
 
 
145
  `fixed inset-0 w-screen h-screen`,
146
  `flex flex-col items-center justify-center`,
147
- `transition-all duration-300 ease-in-out`,
148
- // panel === "play" ? "opacity-1 translate-x-0" : "opacity-0 translate-x-[-1000px] pointer-events-none"
149
- )}>
 
150
  {isLocked ? <Countdown
151
  progressPercent={progressPercent}
152
  remainingTimeInSec={remainingTimeInSec}
153
  /> : null}
154
- <div className={cn(
155
- `flex flex-col md:flex-row`,
 
 
156
  `w-full md:max-w-4xl lg:max-w-5xl xl:max-w-6xl max-h-[80vh]`,
157
- `space-y-3 md:space-y-0 md:space-x-6`,
158
- `transition-all duration-300 ease-in-out`,
159
  )}>
160
- <div
161
- ref={scrollRef}
162
- className={cn(
163
- `flex flex-col`,
164
- `flex-grow rounded-2xl md:rounded-3xl`,
165
- `backdrop-blur-lg bg-white/40`,
166
- `border-2 border-white/10`,
167
- `items-center`,
168
- `space-y-6 md:space-y-8 lg:space-y-12 xl:space-y-16`,
169
- `px-3 py-6 md:px-6 md:py-12 xl:px-8 xl:py-16`,
170
- `overflow-y-scroll`,
171
- )}
172
- style={{
173
- boxShadow: "inset 0 2px 4px 0 rgb(0 0 0 / 0.05)" // TODO: convert to tailwind
174
- }}>
175
 
 
176
  {assetUrl ? <div
177
  className={cn(
178
  `flex flex-col`,
@@ -185,6 +278,7 @@ export function Generate() {
185
  autoPlay
186
  loop
187
  src={assetUrl}
 
188
  /> :
189
  <img
190
  src={assetUrl}
@@ -198,11 +292,11 @@ export function Generate() {
198
  <div className={cn(
199
  `flex flex-col md:flex-row`,
200
  `space-y-3 md:space-y-0 md:space-x-3`,
201
- ` w-full max-w-[1024px]`,
202
  `items-center justify-between`
203
  )}>
204
  <div className={cn(
205
- `flex flex-row flex-grow`
206
  )}>
207
  <input
208
  type="text"
@@ -280,21 +374,35 @@ export function Generate() {
280
  </div>
281
  </div>
282
 
283
- <div className="flex flex-col">
 
 
 
 
 
 
 
 
 
 
 
284
  <div className="flex flex-row">
285
  <h3 className={cn(
286
  headingFont.className,
287
- "text-2xl text-sky-600 mb-4"
288
  )}>{models.length ? "Pick a style:" : "Loading styles.."}</h3>
289
  </div>
290
- <div className="grid grid-cols-4 sm:grid-cols-6 md:grid-cols-10 lg:grid-cols-11 xl:grid-cols-12 gap-2">
291
  {models.map(model =>
292
- <div key={model.repo}
293
- className={isLocked ? '' : `cursor-pointer`}
294
- onClick={() => {
295
- if (!isLocked) { setSelectedModel(model) }
296
- }}>
297
- <img
 
 
 
298
  src={
299
  model.image.startsWith("http")
300
  ? model.image
@@ -304,29 +412,75 @@ export function Generate() {
304
  `transition-all duration-150 ease-in-out`,
305
  `w-20 h-20 object-cover rounded-lg overflow-hidden`,
306
  `border-4 border-transparent`,
307
- `hover:border-yellow-50 hover:scale-110`,
308
  selectedModel?.repo === model.repo
309
  ? `scale-110 border-4 border-yellow-300 hover:border-yellow-300`
310
  : ``
311
  )}
312
  ></img>
313
- </div>)}
 
 
 
 
 
 
 
314
  </div>
315
  </div>
316
 
317
- {/*<div>
318
- <p>Community creations</p>
319
- <div>
320
- <div>A</div>
321
- <div>B</div>
322
- <div>C</div>
323
- <div>D</div>
324
- <div>E</div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
325
  </div>
 
326
  </div>
327
- */}
328
- </div>
329
- </div>
330
  </div>
331
  )
332
  }
 
8
  import { headingFont } from "@/app/interface/fonts"
9
  import { useCharacterLimit } from "@/lib/useCharacterLimit"
10
  import { generateAnimation } from "@/app/server/actions/animation"
11
+ import { getLatestPosts, getPost, postToCommunity } from "@/app/server/actions/community"
12
  import { useCountdown } from "@/lib/useCountdown"
13
  import { Countdown } from "../countdown"
14
  import { getSDXLModels } from "@/app/server/actions/models"
15
+ import { Post, SDXLModel } from "@/types"
16
+ import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"
17
+ import { TooltipProvider } from "@radix-ui/react-tooltip"
18
 
19
  export function Generate() {
20
  const router = useRouter()
 
37
  const [showModels, setShowModels] = useState(true)
38
  // useEffect(() => { runsRef.current = runs }, [runs])
39
 
40
+ const [communityRoll, setCommunityRoll] = useState<Post[]>([])
41
+
42
  console.log("runs:", runs)
43
  const { progressPercent, remainingTimeInSec } = useCountdown({
44
  isActive: isLocked,
45
  timerId: runs, // everytime we change this, the timer will reset
46
+ durationInSec: 50, // it usually takes 40 seconds, but there might be lag
47
  onEnd: () => {}
48
  })
49
 
 
82
  })
83
 
84
  startTransition(async () => {
85
+ const huggingFaceLora = selectedModel ? selectedModel.repo.trim() : "KappaNeuro/studio-ghibli-style"
86
  const triggerWord = selectedModel ? selectedModel.trigger_word : "Studio Ghibli Style"
87
 
88
  // now you got a read/write object
 
116
 
117
  // now you got a read/write object
118
  const current = new URLSearchParams(Array.from(searchParams.entries()))
119
+ current.set("postId", post.postId.trim())
120
+ current.set("prompt", post.prompt.trim())
121
+ current.set("model", post.model.trim())
122
  const search = current.toString()
123
  router.push(`${pathname}${search ? `?${search}` : ""}`)
124
  } catch (err) {
 
137
  const models = await getSDXLModels()
138
  setModels(models)
139
 
140
+ const defaultModel = models.find(model => model.title.toLowerCase().includes("ghibli")) || models[0]
141
+
142
  if (defaultModel) {
143
  setSelectedModel(defaultModel)
144
  }
145
+
146
+ // now we load URL params
147
+ const current = new URLSearchParams(Array.from(searchParams.entries()))
148
+
149
+ // URL query params
150
+ const existingPostId = current.get("postId") || ""
151
+ const existingPrompt = current.get("prompt")?.trim() || ""
152
+ const existingModelName = current.get("model")?.toLowerCase().trim() || ""
153
+
154
+ // if and only if we don't have a post id, then we look at the other query params
155
+ if (existingPrompt) {
156
+ setPromptDraft(existingPrompt)
157
+ }
158
+
159
+ if (existingModelName) {
160
+ let existingModel = models.find(model => model.title.toLowerCase().trim().includes(existingModelName))
161
+ if (existingModel) {
162
+ setSelectedModel(existingModel)
163
+ }
164
+ }
165
+
166
+ // if we have a post id, then we use that to override all the previous values
167
+ if (existingPostId) {
168
+ try {
169
+ const post = await getPost(existingPostId)
170
+
171
+ if (post.assetUrl) {
172
+ setAssetUrl(post.assetUrl)
173
+ }
174
+ if (post.prompt) {
175
+ setPromptDraft(post.prompt)
176
+ }
177
+
178
+ if (post.model) {
179
+ const existingModel = models.find(model => model.title.toLowerCase().trim().includes(post.model.toLowerCase().trim()))
180
+
181
+ if (existingModel) {
182
+ setSelectedModel(existingModel)
183
+ }
184
+ }
185
+ } catch (err) {
186
+ console.error(`failed to load the community post (${err})`)
187
+ }
188
+ }
189
+ })
190
+ }, [])
191
+
192
+ useEffect(() => {
193
+ startTransition(async () => {
194
+ const posts = await getLatestPosts({
195
+ maxNbPosts: 16
196
+ })
197
+ if (posts?.length) {
198
+ setCommunityRoll(posts)
199
+ }
200
  })
201
  }, [])
202
 
203
+ const handleSelectCommunityPost = (post: Post) => {
204
+ if (isLocked) { return }
205
+
206
+ scrollRef.current?.scroll({
207
+ top: 0,
208
+ behavior: 'smooth'
209
+ })
210
+
211
+ // now you got a read/write object
212
+ const current = new URLSearchParams(Array.from(searchParams.entries()))
213
+ current.set("postId", post.postId.trim())
214
+ current.set("prompt", post.prompt.trim())
215
+ current.set("model", post.model.trim())
216
+ const search = current.toString()
217
+ router.push(`${pathname}${search ? `?${search}` : ""}`)
218
+
219
+ if (post.assetUrl) {
220
+ setAssetUrl(post.assetUrl)
221
+ }
222
+ if (post.prompt) {
223
+ setPromptDraft(post.prompt)
224
+ }
225
+
226
+ if (post.model) {
227
+ const existingModel = models.find(model => model.title.toLowerCase().trim().includes(post.model.toLowerCase().trim()))
228
+
229
+ if (existingModel) {
230
+ setSelectedModel(existingModel)
231
+ }
232
+ }
233
+ }
234
+
235
  return (
236
+ <div
237
+ ref={scrollRef}
238
+ className={cn(
239
  `fixed inset-0 w-screen h-screen`,
240
  `flex flex-col items-center justify-center`,
241
+ // `transition-all duration-300 ease-in-out`,
242
+ `overflow-y-scroll`,
243
+ )}>
244
+ <TooltipProvider>
245
  {isLocked ? <Countdown
246
  progressPercent={progressPercent}
247
  remainingTimeInSec={remainingTimeInSec}
248
  /> : null}
249
+ <div
250
+
251
+ className={cn(
252
+ `flex flex-col`,
253
  `w-full md:max-w-4xl lg:max-w-5xl xl:max-w-6xl max-h-[80vh]`,
254
+ `space-y-8`,
255
+ // `transition-all duration-300 ease-in-out`,
256
  )}>
257
+
258
+ <div
259
+ className={cn(
260
+ `flex flex-col`,
261
+ `flex-grow rounded-2xl md:rounded-3xl`,
262
+ `backdrop-blur-lg bg-white/40`,
263
+ `border-2 border-white/10`,
264
+ `items-center`,
265
+ `space-y-6 md:space-y-8 lg:space-y-12 xl:space-y-16`,
266
+ `px-3 py-6 md:px-6 md:py-12 xl:px-8 xl:py-16`,
 
 
 
 
 
267
 
268
+ )}>
269
  {assetUrl ? <div
270
  className={cn(
271
  `flex flex-col`,
 
278
  autoPlay
279
  loop
280
  src={assetUrl}
281
+ className="rounded-md overflow-hidden"
282
  /> :
283
  <img
284
  src={assetUrl}
 
292
  <div className={cn(
293
  `flex flex-col md:flex-row`,
294
  `space-y-3 md:space-y-0 md:space-x-3`,
295
+ ` w-full md:max-w-[1024px]`,
296
  `items-center justify-between`
297
  )}>
298
  <div className={cn(
299
+ `flex flex-row flex-grow w-full`
300
  )}>
301
  <input
302
  type="text"
 
374
  </div>
375
  </div>
376
 
377
+ </div>
378
+
379
+ <div
380
+ className={cn(
381
+ `flex flex-col`,
382
+ `flex-grow rounded-2xl md:rounded-3xl`,
383
+ `backdrop-blur-lg bg-white/40`,
384
+ `border-2 border-white/10`,
385
+ `items-center`,
386
+ `space-y-2 md:space-y-3 lg:space-y-4 xl:space-y-6`,
387
+ `px-3 py-6 md:px-6 md:py-12 xl:px-8 xl:py-16`,
388
+ )}>
389
  <div className="flex flex-row">
390
  <h3 className={cn(
391
  headingFont.className,
392
+ "text-4xl text-sky-600 mb-4"
393
  )}>{models.length ? "Pick a style:" : "Loading styles.."}</h3>
394
  </div>
395
+ <div className="grid grid-cols-4 sm:grid-cols-6 md:grid-cols-8 lg:grid-cols-10 xl:grid-cols-12 gap-2">
396
  {models.map(model =>
397
+ <div key={model.repo}>
398
+ <Tooltip>
399
+ <TooltipTrigger asChild>
400
+ <div
401
+ className={isLocked ? 'cursor-not-allowed' : `cursor-pointer`}
402
+ onClick={() => {
403
+ if (!isLocked) { setSelectedModel(model) }
404
+ }}>
405
+ <img
406
  src={
407
  model.image.startsWith("http")
408
  ? model.image
 
412
  `transition-all duration-150 ease-in-out`,
413
  `w-20 h-20 object-cover rounded-lg overflow-hidden`,
414
  `border-4 border-transparent`,
415
+ isLocked ? '' : `hover:border-yellow-50 hover:scale-110`,
416
  selectedModel?.repo === model.repo
417
  ? `scale-110 border-4 border-yellow-300 hover:border-yellow-300`
418
  : ``
419
  )}
420
  ></img>
421
+ </div>
422
+ </TooltipTrigger>
423
+ {!isLocked && <TooltipContent>
424
+ <p className="w-full max-w-xl">{model.title}</p>
425
+ </TooltipContent>}
426
+ </Tooltip>
427
+ </div>
428
+ )}
429
  </div>
430
  </div>
431
 
432
+
433
+ <div
434
+ className={cn(
435
+ `flex flex-col`,
436
+ `flex-grow rounded-2xl md:rounded-3xl`,
437
+ `backdrop-blur-lg bg-white/40`,
438
+ `border-2 border-white/10`,
439
+ `items-center`,
440
+ `space-y-2 md:space-y-3 lg:space-y-4 xl:space-y-6`,
441
+ `px-3 py-6 md:px-6 md:py-12 xl:px-8 xl:py-16`,
442
+ )}>
443
+ <div className="flex flex-row">
444
+ <h3 className={cn(
445
+ headingFont.className,
446
+ "text-4xl text-sky-600 mb-4"
447
+ )}>{communityRoll.length ? "Community Roll:" : "Loading community toll.."}</h3>
448
+ </div>
449
+ <div className="grid grid-cols-1 sm:grid-cols-3 md:grid-cols-4 lg:grid-cols-6 xl:grid-cols-8 gap-2">
450
+ {communityRoll.map(post =>
451
+ <div key={post.postId}>
452
+ <Tooltip>
453
+ <TooltipTrigger asChild>
454
+ <div
455
+ className={isLocked ? 'cursor-not-allowed' : `cursor-pointer`}
456
+ onClick={() => { handleSelectCommunityPost(post) }}>
457
+ <video
458
+ muted
459
+ autoPlay
460
+ loop
461
+ src={post.assetUrl}
462
+ className={cn(
463
+ `rounded-md overflow-hidden`,
464
+ `transition-all duration-150 ease-in-out`,
465
+ `w-40 h-30 object-cover rounded-lg overflow-hidden`,
466
+ `border-4 border-transparent`,
467
+ isLocked ? '' : `hover:border-yellow-50 hover:scale-110`,
468
+ )}
469
+ />
470
+ </div>
471
+ </TooltipTrigger>
472
+ {!isLocked && <TooltipContent>
473
+ <p className="w-full max-w-xl">{post.prompt}</p>
474
+ </TooltipContent>}
475
+ </Tooltip>
476
+ </div>
477
+ )}
478
+ </div>
479
  </div>
480
+
481
  </div>
482
+
483
+ </TooltipProvider>
 
484
  </div>
485
  )
486
  }